summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.github/ISSUE_TEMPLATE/bug_report.md2
-rw-r--r--.github/workflows/build.yml7
-rw-r--r--.github/workflows/go.yml1
-rw-r--r--.github/workflows/issue_reviver.yml2
-rw-r--r--Makefile37
-rw-r--r--g3doc/architecture_guide/performance.md4
-rw-r--r--g3doc/architecture_guide/resources.md4
-rw-r--r--g3doc/architecture_guide/security.md2
-rw-r--r--g3doc/user_guide/FAQ.md29
-rw-r--r--g3doc/user_guide/install.md184
-rw-r--r--g3doc/user_guide/quick_start/docker.md27
-rw-r--r--g3doc/user_guide/tutorials/BUILD13
-rw-r--r--g3doc/user_guide/tutorials/cni.md4
-rw-r--r--g3doc/user_guide/tutorials/docker-compose.md100
-rw-r--r--g3doc/user_guide/tutorials/docker.md8
-rw-r--r--g3doc/user_guide/tutorials/kubernetes.md186
-rw-r--r--images/Makefile8
-rw-r--r--images/benchmarks/httpd/Dockerfile2
-rw-r--r--images/benchmarks/nginx/Dockerfile11
-rw-r--r--images/benchmarks/nginx/nginx.conf19
-rw-r--r--images/benchmarks/nginx/nginx_gofer.conf19
-rw-r--r--images/jekyll/Dockerfile7
-rwxr-xr-ximages/jekyll/build.sh (renamed from scripts/fuse_tests.sh)8
-rw-r--r--images/packetdrill/Dockerfile2
-rw-r--r--images/packetimpact/Dockerfile10
-rw-r--r--pkg/abi/linux/BUILD4
-rw-r--r--pkg/abi/linux/aio.go3
-rw-r--r--pkg/abi/linux/bpf.go1
-rw-r--r--pkg/abi/linux/capability.go4
-rw-r--r--pkg/abi/linux/dev.go6
-rw-r--r--pkg/abi/linux/fcntl.go4
-rw-r--r--pkg/abi/linux/fuse.go710
-rw-r--r--pkg/abi/linux/ioctl.go29
-rw-r--r--pkg/abi/linux/ipc.go2
-rw-r--r--pkg/abi/linux/linux.go9
-rw-r--r--pkg/abi/linux/netfilter.go43
-rw-r--r--pkg/abi/linux/netfilter_ipv6.go35
-rw-r--r--pkg/abi/linux/netlink.go2
-rw-r--r--pkg/abi/linux/poll.go2
-rw-r--r--pkg/abi/linux/rusage.go2
-rw-r--r--pkg/abi/linux/seccomp.go23
-rw-r--r--pkg/abi/linux/sem.go4
-rw-r--r--pkg/abi/linux/shm.go6
-rw-r--r--pkg/abi/linux/signal.go2
-rw-r--r--pkg/abi/linux/socket.go13
-rw-r--r--pkg/abi/linux/time.go16
-rw-r--r--pkg/abi/linux/tty.go7
-rw-r--r--pkg/abi/linux/utsname.go2
-rw-r--r--pkg/amutex/BUILD5
-rw-r--r--pkg/amutex/amutex.go30
-rw-r--r--pkg/bpf/decoder.go13
-rw-r--r--pkg/bpf/decoder_test.go4
-rw-r--r--pkg/bpf/program_builder.go23
-rw-r--r--pkg/bpf/program_builder_test.go42
-rw-r--r--pkg/buffer/BUILD7
-rw-r--r--pkg/buffer/buffer.go35
-rw-r--r--pkg/buffer/pool.go83
-rw-r--r--pkg/buffer/pool_test.go51
-rw-r--r--pkg/buffer/safemem.go4
-rw-r--r--pkg/buffer/safemem_test.go2
-rw-r--r--pkg/buffer/view.go13
-rw-r--r--pkg/buffer/view_test.go77
-rw-r--r--pkg/context/BUILD1
-rw-r--r--pkg/context/context.go57
-rw-r--r--pkg/coverage/coverage.go39
-rw-r--r--pkg/fd/fd.go42
-rw-r--r--pkg/lisafs/README.md363
-rw-r--r--pkg/marshal/BUILD (renamed from tools/go_marshal/marshal/BUILD)4
-rw-r--r--pkg/marshal/marshal.go (renamed from tools/go_marshal/marshal/marshal.go)21
-rw-r--r--pkg/marshal/marshal_impl_util.go (renamed from tools/go_marshal/marshal/marshal_impl_util.go)8
-rw-r--r--pkg/marshal/primitive/BUILD (renamed from tools/go_marshal/primitive/BUILD)3
-rw-r--r--pkg/marshal/primitive/primitive.go (renamed from tools/go_marshal/primitive/primitive.go)164
-rw-r--r--pkg/merkletree/merkletree.go25
-rw-r--r--pkg/merkletree/merkletree_test.go42
-rw-r--r--pkg/p9/client_file.go38
-rw-r--r--pkg/p9/file.go24
-rw-r--r--pkg/p9/handlers.go31
-rw-r--r--pkg/p9/messages.go60
-rw-r--r--pkg/p9/messages_test.go24
-rw-r--r--pkg/p9/p9.go162
-rw-r--r--pkg/p9/p9test/client_test.go23
-rw-r--r--pkg/p9/version.go8
-rw-r--r--pkg/refs/refcounter.go4
-rw-r--r--pkg/safemem/BUILD4
-rw-r--r--pkg/seccomp/seccomp.go177
-rw-r--r--pkg/seccomp/seccomp_rules.go75
-rw-r--r--pkg/seccomp/seccomp_test.go603
-rw-r--r--pkg/seccomp/seccomp_test_victim.go2
-rw-r--r--pkg/sentry/arch/BUILD4
-rw-r--r--pkg/sentry/arch/arch.go7
-rw-r--r--pkg/sentry/arch/arch_amd64.go12
-rw-r--r--pkg/sentry/arch/arch_arm64.go12
-rw-r--r--pkg/sentry/arch/signal_act.go2
-rw-r--r--pkg/sentry/arch/signal_amd64.go27
-rw-r--r--pkg/sentry/arch/signal_arm64.go30
-rw-r--r--pkg/sentry/arch/signal_stack.go2
-rw-r--r--pkg/sentry/arch/stack.go179
-rw-r--r--pkg/sentry/arch/stack_unsafe.go69
-rw-r--r--pkg/sentry/control/BUILD1
-rw-r--r--pkg/sentry/control/proc.go28
-rw-r--r--pkg/sentry/devices/memdev/BUILD5
-rw-r--r--pkg/sentry/devices/memdev/full.go4
-rw-r--r--pkg/sentry/devices/memdev/null.go4
-rw-r--r--pkg/sentry/devices/memdev/random.go4
-rw-r--r--pkg/sentry/devices/memdev/zero.go28
-rw-r--r--pkg/sentry/devices/ttydev/ttydev.go2
-rw-r--r--pkg/sentry/fdimport/BUILD1
-rw-r--r--pkg/sentry/fdimport/fdimport.go22
-rw-r--r--pkg/sentry/fs/g3doc/fuse.md99
-rw-r--r--pkg/sentry/fs/host/BUILD2
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go4
-rw-r--r--pkg/sentry/fs/host/tty.go4
-rw-r--r--pkg/sentry/fs/inode.go2
-rw-r--r--pkg/sentry/fs/inode_overlay.go11
-rw-r--r--pkg/sentry/fs/proc/BUILD1
-rw-r--r--pkg/sentry/fs/proc/sys_net.go120
-rw-r--r--pkg/sentry/fs/proc/sys_net_state.go15
-rw-r--r--pkg/sentry/fs/proc/sys_net_test.go73
-rw-r--r--pkg/sentry/fs/proc/task.go46
-rw-r--r--pkg/sentry/fs/tty/BUILD3
-rw-r--r--pkg/sentry/fs/tty/dir.go46
-rw-r--r--pkg/sentry/fs/tty/fs.go4
-rw-r--r--pkg/sentry/fs/tty/line_discipline.go55
-rw-r--r--pkg/sentry/fs/tty/master.go37
-rw-r--r--pkg/sentry/fs/tty/queue.go14
-rw-r--r--pkg/sentry/fs/tty/replica.go (renamed from pkg/sentry/fs/tty/slave.go)88
-rw-r--r--pkg/sentry/fs/tty/terminal.go39
-rw-r--r--pkg/sentry/fs/tty/tty_test.go4
-rw-r--r--pkg/sentry/fsimpl/devpts/BUILD4
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go52
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts_test.go4
-rw-r--r--pkg/sentry/fsimpl/devpts/line_discipline.go55
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go47
-rw-r--r--pkg/sentry/fsimpl/devpts/queue.go14
-rw-r--r--pkg/sentry/fsimpl/devpts/replica.go204
-rw-r--r--pkg/sentry/fsimpl/devpts/slave.go198
-rw-r--r--pkg/sentry/fsimpl/devpts/terminal.go37
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs.go6
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go2
-rw-r--r--pkg/sentry/fsimpl/eventfd/eventfd.go10
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD4
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_file.go32
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_test.go46
-rw-r--r--pkg/sentry/fsimpl/ext/dentry.go2
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go11
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/BUILD3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_32.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_64.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent.go3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_new.go4
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_old.go4
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/disklayout.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent.go12
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent_test.go9
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode.go3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_new.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_old.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_32.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_64.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_old.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_test.go9
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/test_utils.go6
-rw-r--r--pkg/sentry/fsimpl/ext/ext.go2
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/extent_file.go7
-rw-r--r--pkg/sentry/fsimpl/ext/extent_test.go19
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go22
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go2
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go6
-rw-r--r--pkg/sentry/fsimpl/ext/symlink.go6
-rw-r--r--pkg/sentry/fsimpl/ext/utils.go8
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD18
-rw-r--r--pkg/sentry/fsimpl/fuse/connection.go365
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_control.go (renamed from pkg/sentry/fsimpl/fuse/init.go)171
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_test.go117
-rw-r--r--pkg/sentry/fsimpl/fuse/dev.go208
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go105
-rw-r--r--pkg/sentry/fsimpl/fuse/directory.go105
-rw-r--r--pkg/sentry/fsimpl/fuse/file.go133
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go573
-rw-r--r--pkg/sentry/fsimpl/fuse/read_write.go242
-rw-r--r--pkg/sentry/fsimpl/fuse/regular_file.go230
-rw-r--r--pkg/sentry/fsimpl/fuse/request_response.go229
-rw-r--r--pkg/sentry/fsimpl/fuse/utils_test.go132
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go3
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go48
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go181
-rw-r--r--pkg/sentry/fsimpl/gofer/handle.go2
-rw-r--r--pkg/sentry/fsimpl/gofer/p9file.go7
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go34
-rw-r--r--pkg/sentry/fsimpl/gofer/socket.go4
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go31
-rw-r--r--pkg/sentry/fsimpl/host/BUILD2
-rw-r--r--pkg/sentry/fsimpl/host/host.go89
-rw-r--r--pkg/sentry/fsimpl/host/mmap.go6
-rw-r--r--pkg/sentry/fsimpl/host/socket_unsafe.go4
-rw-r--r--pkg/sentry/fsimpl/host/tty.go14
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD1
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go10
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go15
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go276
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go117
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go98
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go30
-rw-r--r--pkg/sentry/fsimpl/kernfs/symlink.go6
-rw-r--r--pkg/sentry/fsimpl/kernfs/synthetic_directory.go102
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go166
-rw-r--r--pkg/sentry/fsimpl/overlay/directory.go14
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go452
-rw-r--r--pkg/sentry/fsimpl/overlay/non_directory.go118
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go176
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go8
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD1
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go7
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go23
-rw-r--r--pkg/sentry/fsimpl/proc/task.go17
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go50
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go138
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go7
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go27
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go24
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go69
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys_test.go71
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go3
-rw-r--r--pkg/sentry/fsimpl/signalfd/signalfd.go10
-rw-r--r--pkg/sentry/fsimpl/sockfs/sockfs.go11
-rw-r--r--pkg/sentry/fsimpl/sys/kcov.go7
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go13
-rw-r--r--pkg/sentry/fsimpl/sys/sys_test.go2
-rw-r--r--pkg/sentry/fsimpl/timerfd/timerfd.go8
-rw-r--r--pkg/sentry/fsimpl/tmpfs/benchmark_test.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/device_file.go1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go50
-rw-r--r--pkg/sentry/fsimpl/tmpfs/named_pipe.go1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/pipe_test.go2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go80
-rw-r--r--pkg/sentry/fsimpl/tmpfs/socket_file.go2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/symlink.go1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go143
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs_test.go2
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD25
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go264
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go291
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go349
-rw-r--r--pkg/sentry/inet/BUILD1
-rw-r--r--pkg/sentry/inet/inet.go11
-rw-r--r--pkg/sentry/inet/test_stack.go17
-rw-r--r--pkg/sentry/kernel/BUILD8
-rw-r--r--pkg/sentry/kernel/auth/BUILD1
-rw-r--r--pkg/sentry/kernel/auth/id.go4
-rw-r--r--pkg/sentry/kernel/fd_table.go162
-rw-r--r--pkg/sentry/kernel/fd_table_test.go8
-rw-r--r--pkg/sentry/kernel/fd_table_unsafe.go61
-rw-r--r--pkg/sentry/kernel/kcov.go4
-rw-r--r--pkg/sentry/kernel/kernel.go103
-rw-r--r--pkg/sentry/kernel/pipe/BUILD1
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go3
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go12
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go5
-rw-r--r--pkg/sentry/kernel/ptrace.go23
-rw-r--r--pkg/sentry/kernel/ptrace_amd64.go2
-rw-r--r--pkg/sentry/kernel/syscalls.go10
-rw-r--r--pkg/sentry/kernel/task_clone.go6
-rw-r--r--pkg/sentry/kernel/task_context.go6
-rw-r--r--pkg/sentry/kernel/task_exit.go3
-rw-r--r--pkg/sentry/kernel/task_futex.go7
-rw-r--r--pkg/sentry/kernel/task_run.go2
-rw-r--r--pkg/sentry/kernel/task_signals.go6
-rw-r--r--pkg/sentry/kernel/task_syscall.go7
-rw-r--r--pkg/sentry/kernel/task_usermem.go43
-rw-r--r--pkg/sentry/kernel/threads.go2
-rw-r--r--pkg/sentry/loader/loader.go10
-rw-r--r--pkg/sentry/mm/mm.go2
-rw-r--r--pkg/sentry/mm/mm_test.go3
-rw-r--r--pkg/sentry/mm/special_mappable.go9
-rw-r--r--pkg/sentry/mm/syscalls.go13
-rw-r--r--pkg/sentry/platform/kvm/BUILD10
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.s7
-rw-r--r--pkg/sentry/platform/kvm/kvm.go10
-rw-r--r--pkg/sentry/platform/kvm/machine.go21
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go41
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go13
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go4
-rw-r--r--pkg/sentry/platform/ptrace/BUILD1
-rw-r--r--pkg/sentry/platform/ptrace/filters.go9
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go5
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go10
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go61
-rw-r--r--pkg/sentry/platform/ring0/defs_amd64.go39
-rw-r--r--pkg/sentry/platform/ring0/defs_arm64.go3
-rw-r--r--pkg/sentry/platform/ring0/entry_amd64.go7
-rw-r--r--pkg/sentry/platform/ring0/entry_amd64.s193
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s52
-rw-r--r--pkg/sentry/platform/ring0/gen_offsets/BUILD5
-rw-r--r--pkg/sentry/platform/ring0/kernel.go22
-rw-r--r--pkg/sentry/platform/ring0/kernel_amd64.go64
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go9
-rw-r--r--pkg/sentry/platform/ring0/lib_amd64.go12
-rw-r--r--pkg/sentry/platform/ring0/lib_amd64.s47
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go9
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.s14
-rw-r--r--pkg/sentry/platform/ring0/offsets_amd64.go11
-rw-r--r--pkg/sentry/platform/ring0/offsets_arm64.go1
-rw-r--r--pkg/sentry/platform/ring0/x86.go2
-rw-r--r--pkg/sentry/socket/BUILD2
-rw-r--r--pkg/sentry/socket/hostinet/BUILD7
-rw-r--r--pkg/sentry/socket/hostinet/socket.go4
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go13
-rw-r--r--pkg/sentry/socket/hostinet/stack.go30
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go72
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go23
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go23
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go26
-rw-r--r--pkg/sentry/socket/netfilter/targets.go472
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go32
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go32
-rw-r--r--pkg/sentry/socket/netlink/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/provider_vfs2.go1
-rw-r--r--pkg/sentry/socket/netlink/socket.go4
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go7
-rw-r--r--pkg/sentry/socket/netstack/BUILD4
-rw-r--r--pkg/sentry/socket/netstack/netstack.go95
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go27
-rw-r--r--pkg/sentry/socket/netstack/stack.go43
-rw-r--r--pkg/sentry/socket/socket.go2
-rw-r--r--pkg/sentry/socket/unix/BUILD18
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go14
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go4
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go23
-rw-r--r--pkg/sentry/socket/unix/unix.go60
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go47
-rw-r--r--pkg/sentry/state/state.go6
-rw-r--r--pkg/sentry/strace/BUILD1
-rw-r--r--pkg/sentry/strace/epoll.go4
-rw-r--r--pkg/sentry/strace/socket.go23
-rw-r--r--pkg/sentry/strace/strace.go37
-rw-r--r--pkg/sentry/syscalls/linux/BUILD5
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go33
-rw-r--r--pkg/sentry/syscalls/linux/sys_capability.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go38
-rw-r--r--pkg/sentry/syscalls/linux/sys_futex.go14
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go31
-rw-r--r--pkg/sentry/syscalls/linux/sys_identity.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go11
-rw-r--r--pkg/sentry/syscalls/linux/sys_pipe.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_poll.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_prctl.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_rlimit.go19
-rw-r--r--pkg/sentry/syscalls/linux/sys_rusage.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_sched.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_seccomp.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go11
-rw-r--r--pkg/sentry/syscalls/linux/sys_shm.go11
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go67
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go11
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_sysinfo.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go15
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_timer.go101
-rw-r--r--pkg/sentry/syscalls/linux/sys_timerfd.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_tls_amd64.go13
-rw-r--r--pkg/sentry/syscalls/linux/sys_utsname.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/aio.go27
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go49
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go10
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go13
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mmap.go12
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/pipe.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go24
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go67
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go106
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/timerfd.go6
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go16
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/xattr.go32
-rw-r--r--pkg/sentry/vfs/anonfs.go21
-rw-r--r--pkg/sentry/vfs/dentry.go2
-rw-r--r--pkg/sentry/vfs/device.go3
-rw-r--r--pkg/sentry/vfs/epoll.go25
-rw-r--r--pkg/sentry/vfs/file_description.go63
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go52
-rw-r--r--pkg/sentry/vfs/filesystem.go30
-rw-r--r--pkg/sentry/vfs/filesystem_type.go9
-rw-r--r--pkg/sentry/vfs/genericfstree/genericfstree.go2
-rw-r--r--pkg/sentry/vfs/inotify.go2
-rw-r--r--pkg/sentry/vfs/lock.go2
-rw-r--r--pkg/sentry/vfs/memxattr/xattr.go16
-rw-r--r--pkg/sentry/vfs/mount.go85
-rw-r--r--pkg/sentry/vfs/mount_test.go26
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go11
-rw-r--r--pkg/sentry/vfs/options.go52
-rw-r--r--pkg/sentry/vfs/permissions.go40
-rw-r--r--pkg/sentry/vfs/resolving_path.go5
-rw-r--r--pkg/sentry/vfs/vfs.go66
-rw-r--r--pkg/state/types.go14
-rw-r--r--pkg/sync/BUILD1
-rw-r--r--pkg/sync/seqatomic_unsafe.go40
-rw-r--r--pkg/sync/seqcount.go30
-rw-r--r--pkg/sync/spin_unsafe.go24
-rw-r--r--pkg/syserror/syserror.go1
-rw-r--r--pkg/syserror/syserror_test.go20
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go4
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go4
-rw-r--r--pkg/tcpip/buffer/BUILD4
-rw-r--r--pkg/tcpip/checker/checker.go230
-rw-r--r--pkg/tcpip/faketime/BUILD24
-rw-r--r--pkg/tcpip/faketime/faketime.go (renamed from pkg/tcpip/stack/fake_time_test.go)157
-rw-r--r--pkg/tcpip/faketime/faketime_test.go95
-rw-r--r--pkg/tcpip/header/icmpv4.go41
-rw-r--r--pkg/tcpip/header/icmpv6.go49
-rw-r--r--pkg/tcpip/header/ipv4.go15
-rw-r--r--pkg/tcpip/header/ipv6.go9
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go113
-rw-r--r--pkg/tcpip/header/parse/BUILD15
-rw-r--r--pkg/tcpip/header/parse/parse.go168
-rw-r--r--pkg/tcpip/header/udp.go5
-rw-r--r--pkg/tcpip/link/rawfile/BUILD13
-rw-r--r--pkg/tcpip/link/rawfile/errors.go8
-rw-r--r--pkg/tcpip/link/rawfile/errors_test.go53
-rw-r--r--pkg/tcpip/link/sniffer/BUILD1
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go60
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go105
-rw-r--r--pkg/tcpip/network/arp/arp_test.go6
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD1
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go48
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go136
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go12
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go4
-rw-r--r--pkg/tcpip/network/ip_test.go552
-rw-r--r--pkg/tcpip/network/ipv4/BUILD5
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go188
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go407
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go1012
-rw-r--r--pkg/tcpip/network/ipv6/BUILD5
-rw-r--r--pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go (renamed from pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go)2
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go229
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go68
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go1063
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go644
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go (renamed from pkg/tcpip/stack/ndp.go)576
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go105
-rw-r--r--pkg/tcpip/network/testutil/BUILD20
-rw-r--r--pkg/tcpip/network/testutil/testutil.go144
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go4
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go6
-rw-r--r--pkg/tcpip/stack/BUILD8
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go758
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go77
-rw-r--r--pkg/tcpip/stack/conntrack.go46
-rw-r--r--pkg/tcpip/stack/forwarder_test.go59
-rw-r--r--pkg/tcpip/stack/iptables.go146
-rw-r--r--pkg/tcpip/stack/iptables_targets.go147
-rw-r--r--pkg/tcpip/stack/iptables_types.go108
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go9
-rw-r--r--pkg/tcpip/stack/ndp_test.go831
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go67
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go20
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go67
-rw-r--r--pkg/tcpip/stack/nic.go1369
-rw-r--r--pkg/tcpip/stack/nic_test.go146
-rw-r--r--pkg/tcpip/stack/nud_test.go20
-rw-r--r--pkg/tcpip/stack/packet_buffer.go44
-rw-r--r--pkg/tcpip/stack/registration.go335
-rw-r--r--pkg/tcpip/stack/route.go153
-rw-r--r--pkg/tcpip/stack/stack.go415
-rw-r--r--pkg/tcpip/stack/stack_test.go354
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go24
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go8
-rw-r--r--pkg/tcpip/stack/transport_test.go114
-rw-r--r--pkg/tcpip/tcpip.go155
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go129
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go138
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go26
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go33
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go43
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go31
-rw-r--r--pkg/tcpip/transport/tcp/BUILD2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go22
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go51
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go230
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go8
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go229
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go101
-rw-r--r--pkg/tcpip/transport/tcp/segment.go45
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go52
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go15
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go1247
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go29
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go158
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go64
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go2
-rw-r--r--pkg/tcpip/transport/udp/protocol.go145
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go116
-rw-r--r--pkg/test/dockerutil/container.go12
-rw-r--r--pkg/test/testutil/testutil.go4
-rw-r--r--pkg/usermem/usermem.go46
-rw-r--r--pkg/usermem/usermem_test.go18
-rw-r--r--runsc/boot/BUILD5
-rw-r--r--runsc/boot/controller.go12
-rw-r--r--runsc/boot/filter/config.go491
-rw-r--r--runsc/boot/filter/config_amd64.go41
-rw-r--r--runsc/boot/filter/config_arm64.go25
-rw-r--r--runsc/boot/filter/config_profile.go6
-rw-r--r--runsc/boot/fs.go27
-rw-r--r--runsc/boot/loader.go151
-rw-r--r--runsc/boot/loader_test.go13
-rw-r--r--runsc/boot/vfs.go189
-rw-r--r--runsc/cmd/boot.go2
-rw-r--r--runsc/cmd/checkpoint.go2
-rw-r--r--runsc/cmd/create.go2
-rw-r--r--runsc/cmd/exec.go4
-rw-r--r--runsc/cmd/gofer.go21
-rw-r--r--runsc/cmd/restore.go2
-rw-r--r--runsc/cmd/run.go2
-rw-r--r--runsc/config/config.go6
-rw-r--r--runsc/config/config_test.go87
-rw-r--r--runsc/config/flags.go37
-rw-r--r--runsc/console/console.go18
-rw-r--r--runsc/container/console_test.go10
-rw-r--r--runsc/container/container.go3
-rw-r--r--runsc/container/container_test.go219
-rw-r--r--runsc/container/multi_container_test.go32
-rw-r--r--runsc/container/shared_volume_test.go12
-rw-r--r--runsc/fsgofer/filter/config.go169
-rw-r--r--runsc/fsgofer/filter/config_amd64.go37
-rw-r--r--runsc/fsgofer/filter/config_arm64.go21
-rw-r--r--runsc/fsgofer/filter/extra_filters_race.go1
-rw-r--r--runsc/fsgofer/fsgofer.go5
-rw-r--r--runsc/fsgofer/fsgofer_test.go23
-rw-r--r--runsc/sandbox/sandbox.go4
-rw-r--r--runsc/specutils/BUILD1
-rw-r--r--runsc/specutils/seccomp/BUILD34
-rw-r--r--runsc/specutils/seccomp/audit_amd64.go25
-rw-r--r--runsc/specutils/seccomp/audit_arm64.go25
-rw-r--r--runsc/specutils/seccomp/seccomp.go229
-rw-r--r--runsc/specutils/seccomp/seccomp_test.go414
-rw-r--r--runsc/specutils/specutils.go26
-rwxr-xr-xscripts/common.sh86
-rwxr-xr-xscripts/common_build.sh116
-rwxr-xr-xscripts/dev.sh75
-rwxr-xr-xscripts/do_tests.sh27
-rwxr-xr-xscripts/docker_tests.sh25
-rwxr-xr-xscripts/go.sh45
-rwxr-xr-xscripts/hostnet_tests.sh23
-rwxr-xr-xscripts/iptables_tests.sh26
-rwxr-xr-xscripts/kvm_tests.sh30
-rwxr-xr-xscripts/make_tests.sh20
-rwxr-xr-xscripts/overlay_tests.sh23
-rwxr-xr-xscripts/packetdrill_tests.sh23
-rwxr-xr-xscripts/packetimpact_tests.sh23
-rwxr-xr-xscripts/root_tests.sh25
-rwxr-xr-xscripts/runtime_tests.sh29
-rwxr-xr-xscripts/simple_tests.sh20
-rwxr-xr-xscripts/swgso_tests.sh23
-rwxr-xr-xscripts/syscall_kvm_tests.sh20
-rwxr-xr-xscripts/syscall_tests.sh20
-rw-r--r--test/README.md4
-rw-r--r--test/benchmarks/base/size_test.go1
-rw-r--r--test/benchmarks/base/startup_test.go3
-rw-r--r--test/benchmarks/network/BUILD11
-rw-r--r--test/benchmarks/network/httpd_test.go110
-rw-r--r--test/benchmarks/network/nginx_test.go134
-rw-r--r--test/benchmarks/network/static_server.go87
-rw-r--r--test/benchmarks/tcp/tcp_proxy.go23
-rw-r--r--test/e2e/integration_test.go44
-rw-r--r--test/fuse/BUILD66
-rw-r--r--test/fuse/README.md207
-rw-r--r--test/fuse/linux/BUILD198
-rw-r--r--test/fuse/linux/create_test.cc128
-rw-r--r--test/fuse/linux/fuse_base.cc447
-rw-r--r--test/fuse/linux/fuse_base.h240
-rw-r--r--test/fuse/linux/fuse_fd_util.cc61
-rw-r--r--test/fuse/linux/fuse_fd_util.h48
-rw-r--r--test/fuse/linux/mkdir_test.cc88
-rw-r--r--test/fuse/linux/mknod_test.cc107
-rw-r--r--test/fuse/linux/open_test.cc128
-rw-r--r--test/fuse/linux/read_test.cc390
-rw-r--r--test/fuse/linux/readdir_test.cc193
-rw-r--r--test/fuse/linux/readlink_test.cc85
-rw-r--r--test/fuse/linux/release_test.cc74
-rw-r--r--test/fuse/linux/rmdir_test.cc100
-rw-r--r--test/fuse/linux/setstat_test.cc338
-rw-r--r--test/fuse/linux/stat_test.cc226
-rw-r--r--test/fuse/linux/symlink_test.cc88
-rw-r--r--test/fuse/linux/unlink_test.cc107
-rw-r--r--test/fuse/linux/write_test.cc303
-rw-r--r--test/image/image_test.go64
-rw-r--r--test/iptables/iptables_test.go9
-rw-r--r--test/packetimpact/dut/BUILD10
-rw-r--r--test/packetimpact/dut/posix_server.cc7
-rw-r--r--test/packetimpact/runner/BUILD26
-rw-r--r--test/packetimpact/runner/defs.bzl5
-rw-r--r--test/packetimpact/runner/dut.go442
-rw-r--r--test/packetimpact/runner/packetimpact_test.go359
-rw-r--r--test/packetimpact/testbench/connections.go2
-rw-r--r--test/packetimpact/testbench/rawsockets.go2
-rw-r--r--test/packetimpact/testbench/testbench.go10
-rw-r--r--test/packetimpact/tests/BUILD36
-rw-r--r--test/packetimpact/tests/tcp_close_wait_ack_test.go109
-rw-r--r--test/packetimpact/tests/tcp_linger_test.go17
-rw-r--r--test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go133
-rw-r--r--test/packetimpact/tests/tcp_rcv_buf_space_test.go80
-rw-r--r--test/packetimpact/tests/tcp_timewait_reset_test.go68
-rw-r--r--test/packetimpact/tests/tcp_unacc_seq_ack_test.go234
-rw-r--r--test/perf/linux/BUILD16
-rw-r--r--test/perf/linux/getdents_benchmark.cc2
-rw-r--r--test/perf/linux/open_read_close_benchmark.cc61
-rw-r--r--test/root/root.go2
-rw-r--r--test/runtimes/BUILD10
-rw-r--r--test/runtimes/README.md62
-rw-r--r--test/runtimes/exclude/go1.12.csv (renamed from test/runtimes/exclude_go1.12.csv)0
-rw-r--r--test/runtimes/exclude/java11.csv (renamed from test/runtimes/exclude_java11.csv)2
-rw-r--r--test/runtimes/exclude/nodejs12.4.0.csv (renamed from test/runtimes/exclude_nodejs12.4.0.csv)6
-rw-r--r--test/runtimes/exclude/php7.3.6.csv (renamed from test/runtimes/exclude_php7.3.6.csv)1
-rw-r--r--test/runtimes/exclude/python3.7.3.csv (renamed from test/runtimes/exclude_python3.7.3.csv)0
-rw-r--r--test/runtimes/proctor/BUILD24
-rw-r--r--test/runtimes/proctor/lib/BUILD24
-rw-r--r--test/runtimes/proctor/lib/go.go (renamed from test/runtimes/proctor/go.go)4
-rw-r--r--test/runtimes/proctor/lib/java.go (renamed from test/runtimes/proctor/java.go)2
-rw-r--r--test/runtimes/proctor/lib/lib.go (renamed from test/runtimes/proctor/proctor.go)78
-rw-r--r--test/runtimes/proctor/lib/lib_test.go (renamed from test/runtimes/proctor/proctor_test.go)6
-rw-r--r--test/runtimes/proctor/lib/nodejs.go (renamed from test/runtimes/proctor/nodejs.go)4
-rw-r--r--test/runtimes/proctor/lib/php.go (renamed from test/runtimes/proctor/php.go)4
-rw-r--r--test/runtimes/proctor/lib/python.go (renamed from test/runtimes/proctor/python.go)2
-rw-r--r--test/runtimes/proctor/main.go85
-rw-r--r--test/runtimes/runner/BUILD15
-rw-r--r--test/runtimes/runner/lib/BUILD22
-rw-r--r--test/runtimes/runner/lib/exclude_test.go (renamed from test/runtimes/runner/exclude_test.go)6
-rw-r--r--test/runtimes/runner/lib/lib.go185
-rw-r--r--test/runtimes/runner/main.go169
-rw-r--r--test/syscalls/BUILD2
-rw-r--r--test/syscalls/linux/BUILD4
-rw-r--r--test/syscalls/linux/fallocate.cc6
-rw-r--r--test/syscalls/linux/inotify.cc126
-rw-r--r--test/syscalls/linux/ip6tables.cc30
-rw-r--r--test/syscalls/linux/iptables.cc26
-rw-r--r--test/syscalls/linux/kcov.cc5
-rw-r--r--test/syscalls/linux/mkdir.cc7
-rw-r--r--test/syscalls/linux/mknod.cc17
-rw-r--r--test/syscalls/linux/mount.cc22
-rw-r--r--test/syscalls/linux/open_create.cc116
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc21
-rw-r--r--test/syscalls/linux/prctl.cc2
-rw-r--r--test/syscalls/linux/proc.cc84
-rw-r--r--test/syscalls/linux/proc_net.cc41
-rw-r--r--test/syscalls/linux/pty.cc243
-rw-r--r--test/syscalls/linux/pty_root.cc4
-rw-r--r--test/syscalls/linux/raw_socket_icmp.cc22
-rw-r--r--test/syscalls/linux/rename.cc3
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc38
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc12
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc151
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc153
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.cc87
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.h4
-rw-r--r--test/syscalls/linux/socket_test_util.cc14
-rw-r--r--test/syscalls/linux/socket_test_util.h4
-rw-r--r--test/syscalls/linux/socket_unix_stream.cc18
-rw-r--r--test/syscalls/linux/stat.cc28
-rw-r--r--test/syscalls/linux/statfs.cc7
-rw-r--r--test/syscalls/linux/symlink.cc10
-rw-r--r--test/syscalls/linux/tcp_socket.cc30
-rw-r--r--test/syscalls/linux/truncate.cc20
-rw-r--r--test/syscalls/linux/udp_socket.cc30
-rw-r--r--test/syscalls/linux/unlink.cc14
-rw-r--r--test/syscalls/linux/vdso_clock_gettime.cc69
-rw-r--r--test/syscalls/linux/xattr.cc62
-rw-r--r--test/util/BUILD7
-rw-r--r--test/util/fs_util.cc14
-rw-r--r--test/util/fs_util.h7
-rw-r--r--test/util/fuse_util.cc63
-rw-r--r--test/util/fuse_util.h75
-rw-r--r--test/util/pty_util.cc10
-rw-r--r--test/util/pty_util.h8
-rw-r--r--tools/bazel.mk7
-rw-r--r--tools/defs.bzl5
-rw-r--r--tools/github/BUILD15
-rw-r--r--tools/github/main.go162
-rw-r--r--tools/github/nogo/BUILD16
-rw-r--r--tools/github/nogo/nogo.go126
-rw-r--r--tools/github/reviver/BUILD27
-rw-r--r--tools/github/reviver/github.go (renamed from tools/issue_reviver/github/github.go)38
-rw-r--r--tools/github/reviver/github_test.go (renamed from tools/issue_reviver/github/github_test.go)2
-rw-r--r--tools/github/reviver/reviver.go (renamed from tools/issue_reviver/reviver/reviver.go)0
-rw-r--r--tools/github/reviver/reviver_test.go (renamed from tools/issue_reviver/reviver/reviver_test.go)0
-rw-r--r--tools/go_generics/go_merge/main.go6
-rw-r--r--tools/go_generics/imports.go10
-rw-r--r--tools/go_marshal/README.md30
-rw-r--r--tools/go_marshal/defs.bzl11
-rw-r--r--tools/go_marshal/gomarshal/generator.go85
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go4
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go12
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go20
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go51
-rw-r--r--tools/go_marshal/gomarshal/util.go20
-rw-r--r--tools/go_marshal/main.go11
-rw-r--r--tools/go_marshal/test/BUILD2
-rw-r--r--tools/go_marshal/test/escape/BUILD2
-rw-r--r--tools/go_marshal/test/escape/escape.go28
-rw-r--r--tools/go_marshal/test/marshal_test.go116
-rw-r--r--tools/go_marshal/test/test.go24
-rw-r--r--tools/issue_reviver/BUILD13
-rw-r--r--tools/issue_reviver/github/BUILD25
-rw-r--r--tools/issue_reviver/main.go100
-rw-r--r--tools/issue_reviver/reviver/BUILD18
-rwxr-xr-xtools/make_apt.sh19
-rw-r--r--tools/nogo/build.go3
-rw-r--r--tools/nogo/matchers.go24
-rw-r--r--tools/nogo/nogo.go45
-rw-r--r--tools/nogo/util/BUILD9
-rw-r--r--tools/nogo/util/util.go85
-rw-r--r--website/BUILD1
-rw-r--r--website/_config.yml3
-rw-r--r--website/_sass/front.scss2
-rw-r--r--website/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.pngbin0 -> 48602 bytes
-rw-r--r--website/assets/images/background_1080p.jpgbin0 -> 344285 bytes
-rw-r--r--website/blog/2019-11-18-security-basics.md4
-rw-r--r--website/blog/2020-09-18-containing-a-real-vulnerability.md223
-rw-r--r--website/blog/BUILD10
-rw-r--r--website/css/main.scss15
-rw-r--r--website/index.md2
735 files changed, 31180 insertions, 12697 deletions
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index 49a1ba697..50d187633 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -23,7 +23,7 @@ reproduced with software that is publicly available.
Please include the following details of your environment:
-* `runsc -v`
+* `runsc -version`
* `docker version` or `docker info` (if available)
* `kubectl version` and `kubectl get nodes` (if using Kubernetes)
* `uname -a`
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index cf782a580..e28e46352 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -3,9 +3,11 @@ on:
push:
branches:
- master
+ - feature/**
pull_request:
branches:
- master
+ - feature/**
jobs:
default:
@@ -19,3 +21,8 @@ jobs:
restore-keys: |
${{ runner.os }}-bazel-
- run: make
+ - run: make build OPTIONS="--build_tag_filters nogo" TARGETS="//..."
+ - run: make run TARGETS="//tools/github" ARGS="-path=bazel-bin/ nogo"
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ GITHUB_REPOSITORY: ${{ github.repository }}
diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index b51c22158..3a6a592d1 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -6,6 +6,7 @@ on:
pull_request:
branches:
- master
+ - feature/**
jobs:
generate:
diff --git a/.github/workflows/issue_reviver.yml b/.github/workflows/issue_reviver.yml
index 2b399a3f2..c53185620 100644
--- a/.github/workflows/issue_reviver.yml
+++ b/.github/workflows/issue_reviver.yml
@@ -9,7 +9,7 @@ jobs:
steps:
- uses: actions/checkout@v2
if: github.repository == 'google/gvisor'
- - run: make run TARGETS="//tools/issue_reviver"
+ - run: make run TARGETS="//tools/github" ARGS="revive"
if: github.repository == 'google/gvisor'
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/Makefile b/Makefile
index 43a243c90..e36a3baed 100644
--- a/Makefile
+++ b/Makefile
@@ -121,18 +121,22 @@ smoke-tests: ## Runs a simple smoke test after build runsc.
@$(call submake,run DOCKER_PRIVILEGED="" ARGS="--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true")
.PHONY: smoke-tests
+fuse-tests:
+ @$(call submake,test OPTIONS="--test_tag_filters fuse" TARGETS="test/fuse/...")
+.PHONY: fuse-tests
+
unit-tests: ## Local package unit tests in pkg/..., runsc/, tools/.., etc.
@$(call submake,test TARGETS="pkg/... runsc/... tools/...")
+.PHONY: unit-tests
tests: ## Runs all unit tests and syscall tests.
tests: unit-tests
@$(call submake,test TARGETS="test/syscalls/...")
.PHONY: tests
-
integration-tests: ## Run all standard integration tests.
integration-tests: docker-tests overlay-tests hostnet-tests swgso-tests
-integration-tests: do-tests kvm-tests root-tests containerd-tests
+integration-tests: do-tests kvm-tests containerd-test-1.3.4
.PHONY: integration-tests
network-tests: ## Run all networking integration tests.
@@ -157,6 +161,10 @@ syscall-tests: syscall-ptrace-tests syscall-kvm-tests syscall-native-tests
@$(call submake,install-test-runtime)
@$(call submake,test-runtime OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*")
+%-runtime-tests_vfs2: load-runtimes_%
+ @$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2")
+ @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*")
+
do-tests: runsc
@$(call submake,run TARGETS="//runsc" ARGS="--rootless do true")
@$(call submake,run TARGETS="//runsc" ARGS="--rootless -network=none do true")
@@ -182,6 +190,7 @@ swgso-tests: load-basic-images
@$(call submake,install-test-runtime RUNTIME="swgso" ARGS="--software-gso=true --gso=false")
@$(call submake,test-runtime RUNTIME="swgso" TARGETS="$(INTEGRATION_TARGETS)")
.PHONY: swgso-tests
+
hostnet-tests: load-basic-images
@$(call submake,install-test-runtime RUNTIME="hostnet" ARGS="--network=host")
@$(call submake,test-runtime RUNTIME="hostnet" OPTIONS="--test_arg=-checkpoint=false" TARGETS="$(INTEGRATION_TARGETS)")
@@ -196,6 +205,8 @@ kvm-tests: load-basic-images
.PHONY: kvm-tests
iptables-tests: load-iptables
+ @sudo modprobe iptable_filter
+ @sudo modprobe ip6table_filter
@$(call submake,test-runtime RUNTIME="runc" TARGETS="//test/iptables:iptables_test")
@$(call submake,install-test-runtime RUNTIME="iptables" ARGS="--net-raw")
@$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test")
@@ -207,21 +218,18 @@ packetdrill-tests: load-packetdrill
.PHONY: packetdrill-tests
packetimpact-tests: load-packetimpact
- @sudo modprobe iptable_filter ip6table_filter
+ @sudo modprobe iptable_filter
+ @sudo modprobe ip6table_filter
@$(call submake,install-test-runtime RUNTIME="packetimpact")
@$(call submake,test-runtime OPTIONS="--jobs=HOST_CPUS*3 --local_test_jobs=HOST_CPUS*3" RUNTIME="packetimpact" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetimpact, tests(//...))')")
.PHONY: packetimpact-tests
-root-tests: load-basic-images
- @$(call submake,install-test-runtime)
- @$(call submake,sudo TARGETS="//test/root:root_test" ARGS="-test.v")
-.PHONY: root-tests
-
# Specific containerd version tests.
-containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd install-test-runtime
+containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd load-basic_ubuntu
+ @$(call submake,install-test-runtime RUNTIME="root")
@CONTAINERD_VERSION=$* $(MAKE) sudo TARGETS="tools/installers:containerd"
@$(MAKE) sudo TARGETS="tools/installers:shim"
- @$(MAKE) sudo TARGETS="test/root:root_test" ARGS="-test.v"
+ @$(MAKE) sudo TARGETS="test/root:root_test" ARGS="--runtime=root -test.v"
# Note that we can't run containerd-test-1.1.8 tests here.
#
@@ -294,8 +302,8 @@ $(RELEASE_KEY):
echo Name-Email: test@example.com >> $$C && \
echo Expire-Date: 0 >> $$C && \
echo %commit >> $$C && \
- gpg --batch $(GPG_TEST_OPTIONS) --passphrase '' --no-default-keyring --keyring $$T --no-tty --gen-key $$C && \
- gpg --batch $(GPG_TEST_OPTIONS) --export-secret-keys --no-default-keyring --keyring $$T --secret-keyring $$T > $@; \
+ gpg --batch $(GPG_TEST_OPTIONS) --passphrase '' --no-default-keyring --secret-keyring $$T --no-tty --gen-key $$C && \
+ gpg --batch $(GPG_TEST_OPTIONS) --export-secret-keys --no-default-keyring --secret-keyring $$T > $@; \
rc=$$?; rm -f $$T $$C; exit $$rc
release: $(RELEASE_KEY) ## Builds a release.
@@ -371,3 +379,8 @@ configure: ## Configures a single runtime. Requires sudo. Typically called from
test-runtime: ## A convenient wrapper around test that provides the runtime argument. Target must still be provided.
@$(call submake,test OPTIONS="$(OPTIONS) --test_arg=--runtime=$(RUNTIME)")
.PHONY: test-runtime
+
+nogo: ## Surfaces all nogo findings.
+ @$(call submake,build OPTIONS="--build_tag_filters nogo" TARGETS="//...")
+ @$(call submake,run TARGETS="//tools/github" ARGS="-path=$(BUILD_ROOT) -dry-run nogo")
+.PHONY: nogo
diff --git a/g3doc/architecture_guide/performance.md b/g3doc/architecture_guide/performance.md
index 39dbb0045..b981f0c01 100644
--- a/g3doc/architecture_guide/performance.md
+++ b/g3doc/architecture_guide/performance.md
@@ -30,7 +30,7 @@ is distinct from **structural costs**. Improvements here are ongoing and driven
by the workloads that matter to gVisor users and contributors.
This page provides a guide for understanding baseline performance, and calls out
-distint **structural costs** and **implementation costs**, highlighting where
+distinct **structural costs** and **implementation costs**, highlighting where
improvements are possible and not possible.
While we include a variety of workloads here, it’s worth emphasizing that gVisor
@@ -211,7 +211,7 @@ url="/performance/applications.csv" title="perf.py http.(node|ruby)
The above figure shows the result of simple `node` and `ruby` web services that
render a template upon receiving a request. Because these synthetic benchmarks
-do minimal work per request, must like the `redis` case, they suffer from high
+do minimal work per request, much like the `redis` case, they suffer from high
overheads. In practice, the more work an application does the smaller the impact
of **structural costs** become.
diff --git a/g3doc/architecture_guide/resources.md b/g3doc/architecture_guide/resources.md
index 1dec37bd1..fc997d40c 100644
--- a/g3doc/architecture_guide/resources.md
+++ b/g3doc/architecture_guide/resources.md
@@ -19,12 +19,12 @@ sandboxed process:
Much like a Virtual Machine (VM), a gVisor sandbox appears as an opaque process
on the system. Processes within the sandbox do not manifest as processes on the
-host system, and process-level interactions within the sandbox requires entering
+host system, and process-level interactions within the sandbox require entering
the sandbox (e.g. via a [Docker exec][exec]).
## Networking
-The sandbox attaches a network endpoint to the system, but runs it's own network
+The sandbox attaches a network endpoint to the system, but runs its own network
stack. All network resources, other than packets in flight on the host, exist
only inside the sandbox, bound by relevant resource limits.
diff --git a/g3doc/architecture_guide/security.md b/g3doc/architecture_guide/security.md
index b99b86332..9363d834c 100644
--- a/g3doc/architecture_guide/security.md
+++ b/g3doc/architecture_guide/security.md
@@ -104,7 +104,7 @@ interactions with a guest operating system and a set of virtualized hardware
devices. These hardware devices are then implemented via the host System API by
a Virtual Machine Monitor (VMM). The Sentry similarly prevents direct
interactions by providing its own implementation of the System API that the
-application must interact with. Applications are not able to to directly craft
+application must interact with. Applications are not able to directly craft
specific arguments or flags for the host System API, or interact directly with
host primitives.
diff --git a/g3doc/user_guide/FAQ.md b/g3doc/user_guide/FAQ.md
index 89df65e99..514fe3918 100644
--- a/g3doc/user_guide/FAQ.md
+++ b/g3doc/user_guide/FAQ.md
@@ -74,11 +74,10 @@ directories.
### I'm getting an error like: `panic: unable to attach: operation not permitted` or `fork/exec /proc/self/exe: invalid argument: unknown` {#runsc-perms}
-Make sure that permissions and the owner is correct on the `runsc` binary.
+Make sure that permissions is correct on the `runsc` binary.
```bash
-sudo chown root:root /usr/local/bin/runsc
-sudo chmod 0755 /usr/local/bin/runsc
+sudo chmod a+rx /usr/local/bin/runsc
```
### I'm getting an error like `mount submount "/etc/hostname": creating mount with source ".../hostname": input/output error: unknown.` {#memlock}
@@ -96,6 +95,30 @@ containerd.
See [issue #1765](https://gvisor.dev/issue/1765) for more details.
+### I'm getting an error like `RuntimeHandler "runsc" not supported` {#runtime-handler}
+
+This error indicates that the Kubernetes CRI runtime was not set up to handle
+`runsc` as a runtime handler. Please ensure that containerd configuration has
+been created properly and containerd has been restarted. See the
+[containerd quick start](containerd/quick_start.md) for more details.
+
+If you have ensured that containerd has been set up properly and you used
+kubeadm to create your cluster please check if Docker is also installed on that
+system. Kubeadm prefers using Docker if both Docker and containerd are
+installed.
+
+Please recreate your cluster and set the `--cni-socket` option on kubeadm
+commands. For example:
+
+```bash
+kubeadm init --cni-socket=/var/run/containerd/containerd.sock` ...
+```
+
+To fix an existing cluster edit the `/var/lib/kubelet/kubeadm-flags.env` file
+and set the `--container-runtime` flag to `remote` and set the
+`--container-runtime-endpoint` flag to point to the containerd socket. e.g.
+`/var/run/containerd/containerd.sock`.
+
### My container cannot resolve another container's name when using Docker user defined bridge {#docker-bridge}
This is normally indicated by errors like `bad address 'container-name'` when
diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md
index 9afdd264d..abb9e8582 100644
--- a/g3doc/user_guide/install.md
+++ b/g3doc/user_guide/install.md
@@ -5,6 +5,68 @@
> Note: gVisor supports only x86\_64 and requires Linux 4.14.77+
> ([older Linux](./networking.md#gso)).
+## Install latest release {#install-latest}
+
+To download and install the latest release manually follow these steps:
+
+```bash
+(
+ set -e
+ URL=https://storage.googleapis.com/gvisor/releases/release/latest
+ wget ${URL}/runsc ${URL}/runsc.sha512
+ sha512sum -c runsc.sha512
+ rm -f runsc.sha512
+ sudo mv runsc /usr/local/bin
+ sudo chmod a+rx /usr/local/bin/runsc
+)
+```
+
+To install gVisor with Docker, run the following commands:
+
+```bash
+/usr/local/bin/runsc install
+sudo systemctl restart docker
+docker run --rm --runtime=runsc hello-world
+```
+
+For more details about using gVisor with Docker, see
+[Docker Quick Start](./quick_start/docker.md)
+
+Note: It is important to copy `runsc` to a location that is readable and
+executable to all users, since `runsc` executes itself as user `nobody` to avoid
+unnecessary privileges. The `/usr/local/bin` directory is a good place to put
+the `runsc` binary.
+
+## Install from an `apt` repository
+
+First, appropriate dependencies must be installed to allow `apt` to install
+packages via https:
+
+```bash
+sudo apt-get update && \
+sudo apt-get install -y \
+ apt-transport-https \
+ ca-certificates \
+ curl \
+ gnupg-agent \
+ software-properties-common
+```
+
+Next, the configure the key used to sign archives and the repository:
+
+```bash
+curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
+sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
+```
+
+Now the runsc package can be installed:
+
+```bash
+sudo apt-get update && sudo apt-get install -y runsc
+```
+
+If you have Docker installed, it will be automatically configured.
+
## Versions
The `runsc` binaries and repositories are available in multiple versions and
@@ -21,12 +83,16 @@ Binaries are available for every commit on the `master` branch, and are
available at the following URL:
`https://storage.googleapis.com/gvisor/releases/master/latest/runsc`
+`https://storage.googleapis.com/gvisor/releases/master/latest/runsc.sha512`
-Checksums for the release binary are at:
+You can use this link with the steps described in
+[Install latest release](#install-latest).
-`https://storage.googleapis.com/gvisor/releases/master/latest/runsc.sha512`
+For `apt` installation, use the `master` to configure the repository:
-For `apt` installation, use the `master` as the `${DIST}` below.
+```bash
+sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases master main"
+```
### Nightly
@@ -34,18 +100,22 @@ Nightly releases are built most nights from the master branch, and are available
at the following URL:
`https://storage.googleapis.com/gvisor/releases/nightly/latest/runsc`
-
-Checksums for the release binary are at:
-
`https://storage.googleapis.com/gvisor/releases/nightly/latest/runsc.sha512`
+You can use this link with the steps described in
+[Install latest release](#install-latest).
+
Specific nightly releases can be found at:
`https://storage.googleapis.com/gvisor/releases/nightly/${yyyy-mm-dd}/runsc`
Note that a release may not be available for every day.
-For `apt` installation, use the `nightly` as the `${DIST}` below.
+For `apt` installation, use the `nightly` to configure the repository:
+
+```bash
+sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases nightly main"
+```
### Latest release
@@ -53,105 +123,47 @@ The latest official release is available at the following URL:
`https://storage.googleapis.com/gvisor/releases/release/latest`
-For `apt` installation, use the `release` as the `${DIST}` below.
-
-### Specific release
-
-A given release release is available at the following URL:
-
-`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}`
-
-See the [releases][releases] page for information about specific releases.
-
-For `apt` installation of a specific release, which may include point updates,
-use the date of the release, e.g. `${yyyymmdd}`, as the `${DIST}` below.
-
-> Note: only newer releases may be available as `apt` repositories.
-
-### Point release
-
-A given point release is available at the following URL:
-
-`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}.${rc}`
-
-Note that `apt` installation of a specific point release is not supported.
-
-## Install from an `apt` repository
+You can use this link with the steps described in
+[Install latest release](#install-latest).
-First, appropriate dependencies must be installed to allow `apt` to install
-packages via https:
+For `apt` installation, use the `release` to configure the repository:
```bash
-sudo apt-get update && \
-sudo apt-get install -y \
- apt-transport-https \
- ca-certificates \
- curl \
- gnupg-agent \
- software-properties-common
+sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
```
-Next, the key used to sign archives should be added to your `apt` keychain:
-
-```bash
-curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
-```
+### Specific release
-Based on the release type, you will need to substitute `${DIST}` below, using
-one of:
+A given release release is available at the following URL:
-* `master`: For HEAD.
-* `nightly`: For nightly releases.
-* `release`: For the latest release.
-* `${yyyymmdd}`: For a specific releases (see above).
+`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}`
-The repository for the release you wish to install should be added:
+You can use this link with the steps described in
+[Install latest release](#install-latest).
-```bash
-sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases ${DIST} main"
-```
+See the [releases](https://github.com/google/gvisor/releases) page for
+information about specific releases.
-For example, to install the latest official release, you can use:
+For `apt` installation of a specific release, which may include point updates,
+use the date of the release for repository, e.g. `${yyyymmdd}`.
```bash
-sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
+sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases yyyymmdd main"
```
-Now the runsc package can be installed:
-
-```bash
-sudo apt-get update && sudo apt-get install -y runsc
-```
+> Note: only newer releases may be available as `apt` repositories.
-If you have Docker installed, it will be automatically configured.
+### Point release
-## Install directly
+A given point release is available at the following URL:
-The binary URLs provided above can be used to install directly. For example, the
-latest nightly binary can be downloaded, validated, and placed in an appropriate
-location by running:
+`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}.${rc}`
-```bash
-(
- set -e
- URL=https://storage.googleapis.com/gvisor/releases/nightly/latest
- wget ${URL}/runsc
- wget ${URL}/runsc.sha512
- sha512sum -c runsc.sha512
- rm -f runsc.sha512
- sudo mv runsc /usr/local/bin
- sudo chown root:root /usr/local/bin/runsc
- sudo chmod 0755 /usr/local/bin/runsc
-)
-```
+You can use this link with the steps described in
+[Install latest release](#install-latest).
-**It is important to copy this binary to a location that is accessible to all
-users, and ensure it is executable by all users**, since `runsc` executes itself
-as user `nobody` to avoid unnecessary privileges. The `/usr/local/bin` directory
-is a good place to put the `runsc` binary.
+Note that `apt` installation of a specific point release is not supported.
After installation, try out `runsc` by following the
[Docker Quick Start](./quick_start/docker.md) or
[OCI Quick Start](./quick_start/oci.md).
-
-[releases]: https://github.com/google/gvisor/releases
diff --git a/g3doc/user_guide/quick_start/docker.md b/g3doc/user_guide/quick_start/docker.md
index 6ad594ecc..ee842e453 100644
--- a/g3doc/user_guide/quick_start/docker.md
+++ b/g3doc/user_guide/quick_start/docker.md
@@ -22,18 +22,6 @@ named "runsc" by default.
sudo runsc install
```
-You may also wish to install a runtime entry for debugging. The `runsc install`
-command can accept options that will be passed to the runtime when it is invoked
-by Docker.
-
-```bash
-sudo runsc install --runtime runsc-debug -- \
- --debug \
- --debug-log=/tmp/runsc-debug.log \
- --strace \
- --log-packets
-```
-
You must restart the Docker daemon after installing the runtime. Typically this
is done via `systemd`:
@@ -85,6 +73,21 @@ $ docker run --runtime=runsc -it ubuntu dmesg
Note that this is easily replicated by an attacker so applications should never
use `dmesg` to verify the runtime in a security sensitive context.
+## Options
+
+You may also wish to install a runtime entry with different options. The `runsc
+install` command can accept flags that will be passed to the runtime when it is
+invoked by Docker. For example, to install a runtime with debugging enabled, run
+the following:
+
+```bash
+sudo runsc install --runtime runsc-debug -- \
+ --debug \
+ --debug-log=/tmp/runsc-debug.log \
+ --strace \
+ --log-packets
+```
+
Next, look at the different options available for gVisor: [platform][platforms],
[network][networking], [filesystem][filesystem].
diff --git a/g3doc/user_guide/tutorials/BUILD b/g3doc/user_guide/tutorials/BUILD
index 405026a33..f405349b3 100644
--- a/g3doc/user_guide/tutorials/BUILD
+++ b/g3doc/user_guide/tutorials/BUILD
@@ -15,6 +15,15 @@ doc(
)
doc(
+ name = "docker_compose",
+ src = "docker-compose.md",
+ category = "User Guide",
+ permalink = "/docs/tutorials/docker-compose/",
+ subcategory = "Tutorials",
+ weight = "20",
+)
+
+doc(
name = "kubernetes",
src = "kubernetes.md",
category = "User Guide",
@@ -24,7 +33,7 @@ doc(
],
permalink = "/docs/tutorials/kubernetes/",
subcategory = "Tutorials",
- weight = "20",
+ weight = "30",
)
doc(
@@ -33,5 +42,5 @@ doc(
category = "User Guide",
permalink = "/docs/tutorials/cni/",
subcategory = "Tutorials",
- weight = "30",
+ weight = "40",
)
diff --git a/g3doc/user_guide/tutorials/cni.md b/g3doc/user_guide/tutorials/cni.md
index ce2fd09a8..a3507c25b 100644
--- a/g3doc/user_guide/tutorials/cni.md
+++ b/g3doc/user_guide/tutorials/cni.md
@@ -47,7 +47,7 @@ sudo mkdir -p /etc/cni/net.d
sudo sh -c 'cat > /etc/cni/net.d/10-bridge.conf << EOF
{
- "cniVersion": "0.4.0",
+ "cniVersion": "0.3.1",
"name": "mynet",
"type": "bridge",
"bridge": "cni0",
@@ -65,7 +65,7 @@ EOF'
sudo sh -c 'cat > /etc/cni/net.d/99-loopback.conf << EOF
{
- "cniVersion": "0.4.0",
+ "cniVersion": "0.3.1",
"name": "lo",
"type": "loopback"
}
diff --git a/g3doc/user_guide/tutorials/docker-compose.md b/g3doc/user_guide/tutorials/docker-compose.md
new file mode 100644
index 000000000..3284231f8
--- /dev/null
+++ b/g3doc/user_guide/tutorials/docker-compose.md
@@ -0,0 +1,100 @@
+# Wordpress with Docker Compose
+
+This page shows you how to deploy a sample [WordPress][wordpress] site using
+[Docker Compose][docker-compose].
+
+### Before you begin
+
+[Follow these instructions][docker-install] to install runsc with Docker. This
+document assumes that Docker and Docker Compose are installed and the runtime
+name chosen for gVisor is `runsc`.
+
+### Configuration
+
+We'll start by creating the `docker-compose.yaml` file to specify our services.
+We will specify two services, a `wordpress` service for the Wordpress Apache
+server, and a `db` service for MySQL. We will configure Wordpress to connect to
+MySQL via the `db` service host name.
+
+> **Note:** Docker Compose uses it's own network by default and allows services
+> to communicate using their service name. Docker Compose does this by setting
+> up a DNS server at IP address 127.0.0.11 and configuring containers to use it
+> via [resolv.conf][resolv.conf]. This IP is not addressable inside a gVisor
+> sandbox so it's important that we set the DNS IP address to the alternative
+> `8.8.8.8` and use a network that allows routing to it. See
+> [Networking in Compose][compose-networking] for more details.
+
+> **Note:** The `runtime` field was removed from services in the 3.x version of
+> the API in versions of docker-compose < 1.27.0. You will need to write your
+> `docker-compose.yaml` file using the 2.x format or use docker-compose >=
+> 1.27.0. See this [issue](https://github.com/docker/compose/issues/6239) for
+> more details.
+
+```yaml
+version: '2.3'
+
+services:
+ db:
+ image: mysql:5.7
+ volumes:
+ - db_data:/var/lib/mysql
+ restart: always
+ environment:
+ MYSQL_ROOT_PASSWORD: somewordpress
+ MYSQL_DATABASE: wordpress
+ MYSQL_USER: wordpress
+ MYSQL_PASSWORD: wordpress
+ # All services must be on the same network to communicate.
+ network_mode: "bridge"
+
+ wordpress:
+ depends_on:
+ - db
+ # When using the "bridge" network specify links.
+ links:
+ - db
+ image: wordpress:latest
+ ports:
+ - "8080:80"
+ restart: always
+ environment:
+ WORDPRESS_DB_HOST: db:3306
+ WORDPRESS_DB_USER: wordpress
+ WORDPRESS_DB_PASSWORD: wordpress
+ WORDPRESS_DB_NAME: wordpress
+ # Specify the dns address if needed.
+ dns:
+ - 8.8.8.8
+ # All services must be on the same network to communicate.
+ network_mode: "bridge"
+ # Specify the runtime used by Docker. Must be set up in
+ # /etc/docker/daemon.json.
+ runtime: "runsc"
+
+volumes:
+ db_data: {}
+```
+
+Once you have a `docker-compose.yaml` in the current directory you can start the
+containers:
+
+```bash
+docker-compose up
+```
+
+Once the containers have started you can access wordpress at
+http://localhost:8080.
+
+Congrats! You now how a working wordpress site up and running using Docker
+Compose.
+
+### What's next
+
+Learn how to deploy [WordPress with Kubernetes][wordpress-k8s].
+
+[docker-compose]: https://docs.docker.com/compose/
+[docker-install]: ../quick_start/docker.md
+[wordpress]: https://wordpress.com/
+[resolv.conf]: https://man7.org/linux/man-pages/man5/resolv.conf.5.html
+[wordpress-k8s]: kubernetes.md
+[compose-networking]: https://docs.docker.com/compose/networking/
diff --git a/g3doc/user_guide/tutorials/docker.md b/g3doc/user_guide/tutorials/docker.md
index 705560038..9ca01da2a 100644
--- a/g3doc/user_guide/tutorials/docker.md
+++ b/g3doc/user_guide/tutorials/docker.md
@@ -60,9 +60,11 @@ Congratulations! You have just deployed a WordPress site using Docker.
### What's next
-[Learn how to deploy WordPress with Kubernetes][wordpress-k8s].
+Learn how to deploy WordPress with [Kubernetes][wordpress-k8s] or
+[Docker Compose][wordpress-compose].
[docker]: https://www.docker.com/
-[docker-install]: /docs/user_guide/quick_start/docker/
+[docker-install]: ../quick_start/docker.md
[wordpress]: https://wordpress.com/
-[wordpress-k8s]: /docs/tutorials/kubernetes/
+[wordpress-k8s]: kubernetes.md
+[wordpress-compose]: docker-compose.md
diff --git a/g3doc/user_guide/tutorials/kubernetes.md b/g3doc/user_guide/tutorials/kubernetes.md
index d2a94b1b7..1ec6e71e9 100644
--- a/g3doc/user_guide/tutorials/kubernetes.md
+++ b/g3doc/user_guide/tutorials/kubernetes.md
@@ -23,12 +23,12 @@ gcloud beta container node-pools create sandbox-pool --cluster=${CLUSTER_NAME} -
If you prefer to use the console, select your cluster and select the **ADD NODE
POOL** button:
-![+ ADD NODE POOL](./node-pool-button.png)
+![+ ADD NODE POOL](node-pool-button.png)
Then select the **Image type** with **Containerd** and select **Enable sandbox
with gVisor** option. Select other options as you like:
-![+ NODE POOL](./add-node-pool.png)
+![+ NODE POOL](add-node-pool.png)
### Check that gVisor is enabled
@@ -57,47 +57,149 @@ curl -LO https://k8s.io/examples/application/wordpress/mysql-deployment.yaml
Add a **spec.template.spec.runtimeClassName** set to **gvisor** to both files,
as shown below:
-**wordpress-deployment.yaml:** ```yaml apiVersion: v1 kind: Service metadata:
-name: wordpress labels: app: wordpress spec: ports: - port: 80 selector: app:
-wordpress tier: frontend
-
-## type: LoadBalancer
-
-apiVersion: v1 kind: PersistentVolumeClaim metadata: name: wp-pv-claim labels:
-app: wordpress spec: accessModes: - ReadWriteOnce resources: requests:
-
-## storage: 20Gi
-
-apiVersion: apps/v1 kind: Deployment metadata: name: wordpress labels: app:
-wordpress spec: selector: matchLabels: app: wordpress tier: frontend strategy:
-type: Recreate template: metadata: labels: app: wordpress tier: frontend spec:
-runtimeClassName: gvisor # ADD THIS LINE containers: - image:
-wordpress:4.8-apache name: wordpress env: - name: WORDPRESS_DB_HOST value:
-wordpress-mysql - name: WORDPRESS_DB_PASSWORD valueFrom: secretKeyRef: name:
-mysql-pass key: password ports: - containerPort: 80 name: wordpress
-volumeMounts: - name: wordpress-persistent-storage mountPath: /var/www/html
-volumes: - name: wordpress-persistent-storage persistentVolumeClaim: claimName:
-wp-pv-claim ```
-
-**mysql-deployment.yaml:** ```yaml apiVersion: v1 kind: Service metadata: name:
-wordpress-mysql labels: app: wordpress spec: ports: - port: 3306 selector: app:
-wordpress tier: mysql
-
-## clusterIP: None
-
-apiVersion: v1 kind: PersistentVolumeClaim metadata: name: mysql-pv-claim
-labels: app: wordpress spec: accessModes: - ReadWriteOnce resources: requests:
-
-## storage: 20Gi
+**wordpress-deployment.yaml:**
+
+```yaml
+apiVersion: v1
+kind: Service
+metadata:
+ name: wordpress
+ labels:
+ app: wordpress
+spec:
+ ports:
+ - port: 80
+ selector:
+ app: wordpress
+ tier: frontend
+ type: LoadBalancer
+---
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+ name: wp-pv-claim
+ labels:
+ app: wordpress
+spec:
+ accessModes:
+ - ReadWriteOnce
+ resources:
+ requests:
+ storage: 20Gi
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ name: wordpress
+ labels:
+ app: wordpress
+spec:
+ selector:
+ matchLabels:
+ app: wordpress
+ tier: frontend
+ strategy:
+ type: Recreate
+ template:
+ metadata:
+ labels:
+ app: wordpress
+ tier: frontend
+ spec:
+ runtimeClassName: gvisor # ADD THIS LINE
+ containers:
+ - image: wordpress:4.8-apache
+ name: wordpress
+ env:
+ - name: WORDPRESS_DB_HOST
+ value: wordpress-mysql
+ - name: WORDPRESS_DB_PASSWORD
+ valueFrom:
+ secretKeyRef:
+ name: mysql-pass
+ key: password
+ ports:
+ - containerPort: 80
+ name: wordpress
+ volumeMounts:
+ - name: wordpress-persistent-storage
+ mountPath: /var/www/html
+ volumes:
+ - name: wordpress-persistent-storage
+ persistentVolumeClaim:
+ claimName: wp-pv-claim
+```
-apiVersion: apps/v1 kind: Deployment metadata: name: wordpress-mysql labels:
-app: wordpress spec: selector: matchLabels: app: wordpress tier: mysql strategy:
-type: Recreate template: metadata: labels: app: wordpress tier: mysql spec:
-runtimeClassName: gvisor # ADD THIS LINE containers: - image: mysql:5.6 name:
-mysql env: - name: MYSQL_ROOT_PASSWORD valueFrom: secretKeyRef: name: mysql-pass
-key: password ports: - containerPort: 3306 name: mysql volumeMounts: - name:
-mysql-persistent-storage mountPath: /var/lib/mysql volumes: - name:
-mysql-persistent-storage persistentVolumeClaim: claimName: mysql-pv-claim ```
+**mysql-deployment.yaml:**
+
+```yaml
+apiVersion: v1
+kind: Service
+metadata:
+ name: wordpress-mysql
+ labels:
+ app: wordpress
+spec:
+ ports:
+ - port: 3306
+ selector:
+ app: wordpress
+ tier: mysql
+ clusterIP: None
+---
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+ name: mysql-pv-claim
+ labels:
+ app: wordpress
+spec:
+ accessModes:
+ - ReadWriteOnce
+ resources:
+ requests:
+ storage: 20Gi
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ name: wordpress-mysql
+ labels:
+ app: wordpress
+spec:
+ selector:
+ matchLabels:
+ app: wordpress
+ tier: mysql
+ strategy:
+ type: Recreate
+ template:
+ metadata:
+ labels:
+ app: wordpress
+ tier: mysql
+ spec:
+ runtimeClassName: gvisor # ADD THIS LINE
+ containers:
+ - image: mysql:5.6
+ name: mysql
+ env:
+ - name: MYSQL_ROOT_PASSWORD
+ valueFrom:
+ secretKeyRef:
+ name: mysql-pass
+ key: password
+ ports:
+ - containerPort: 3306
+ name: mysql
+ volumeMounts:
+ - name: mysql-persistent-storage
+ mountPath: /var/lib/mysql
+ volumes:
+ - name: mysql-persistent-storage
+ persistentVolumeClaim:
+ claimName: mysql-pv-claim
+```
Note that apart from `runtimeClassName: gvisor`, nothing else about the
Deployment has is changed.
diff --git a/images/Makefile b/images/Makefile
index 278dec02f..d4b6524ba 100644
--- a/images/Makefile
+++ b/images/Makefile
@@ -59,9 +59,9 @@ local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1))
# we need to explicitly repull the base layer in order to ensure that the
# architecture is correct. Note that we use the term "rebuild" here to avoid
# conflicting with the bazel "build" terminology, which is used elsewhere.
-rebuild-%: FROM=$(shell grep FROM $(call path,$*)/Dockerfile } cut -d' ' -f2)
+rebuild-%: FROM=$(shell grep FROM $(call path,$*)/Dockerfile | cut -d' ' -f2)
rebuild-%: register-cross
- $(foreach IMAGE,$(FROM),docker $(DOCKER_PLATFORM_ARGS) $(IMAGE); &&) true
+ $(foreach IMAGE,$(FROM),docker pull $(DOCKER_PLATFORM_ARGS) $(IMAGE) &&) \
T=$$(mktemp -d) && cp -a $(call path,$*)/* $$T && \
docker build $(DOCKER_PLATFORM_ARGS) -t $(call remote_image,$*) $$T && \
rm -rf $$T
@@ -73,10 +73,10 @@ pull-%:
docker pull $(DOCKER_PLATFORM_ARGS) $(call remote_image,$*)
# load will either pull the "remote" or build it locally. This is the preferred
-# entrypoint, as it should never file. The local tag should always be set after
+# entrypoint, as it should never fail. The local tag should always be set after
# this returns (either by the pull or the build).
load-%:
- docker inspect $(call remote_image,$*) >/dev/null 2>&1 || $(MAKE) pull-$* || $(MAKE) rebuild-$*
+ $(MAKE) pull-$* || $(MAKE) rebuild-$*
docker tag $(call remote_image,$*) $(call local_image,$*)
# push pushes the remote image, after either pulling (to validate that the tag
diff --git a/images/benchmarks/httpd/Dockerfile b/images/benchmarks/httpd/Dockerfile
index b72406012..e95538a40 100644
--- a/images/benchmarks/httpd/Dockerfile
+++ b/images/benchmarks/httpd/Dockerfile
@@ -8,7 +8,7 @@ RUN set -x \
# Generate a bunch of relevant files.
RUN mkdir -p /local && \
- for size in 1 10 100 1000 1024 10240; do \
+ for size in 1 10 100 1024 10240; do \
dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \
done
diff --git a/images/benchmarks/nginx/Dockerfile b/images/benchmarks/nginx/Dockerfile
index b64eb52ae..c8e3330d0 100644
--- a/images/benchmarks/nginx/Dockerfile
+++ b/images/benchmarks/nginx/Dockerfile
@@ -1 +1,12 @@
FROM nginx:1.15.10
+
+# Generate a bunch of relevant files.
+RUN mkdir -p /local && \
+ for size in 1 10 100 1024 10240; do \
+ dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \
+ done
+
+RUN touch /local/index.html
+
+COPY ./nginx.conf /etc/nginx/nginx.conf
+COPY ./nginx_gofer.conf /etc/nginx/nginx_gofer.conf
diff --git a/images/benchmarks/nginx/nginx.conf b/images/benchmarks/nginx/nginx.conf
new file mode 100644
index 000000000..2c43c0cda
--- /dev/null
+++ b/images/benchmarks/nginx/nginx.conf
@@ -0,0 +1,19 @@
+user nginx;
+worker_processes 1;
+daemon off;
+
+error_log /var/log/nginx/error.log warn;
+pid /var/run/nginx.pid;
+
+events {
+ worker_connections 1024;
+}
+
+
+http {
+ server {
+ location / {
+ root /tmp/html;
+ }
+ }
+}
diff --git a/images/benchmarks/nginx/nginx_gofer.conf b/images/benchmarks/nginx/nginx_gofer.conf
new file mode 100644
index 000000000..dbba2a575
--- /dev/null
+++ b/images/benchmarks/nginx/nginx_gofer.conf
@@ -0,0 +1,19 @@
+user nginx;
+worker_processes 1;
+daemon off;
+
+error_log /var/log/nginx/error.log warn;
+pid /var/run/nginx.pid;
+
+events {
+ worker_connections 1024;
+}
+
+
+http {
+ server {
+ location / {
+ root /local;
+ }
+ }
+}
diff --git a/images/jekyll/Dockerfile b/images/jekyll/Dockerfile
index ba039ba15..ae19f3bfc 100644
--- a/images/jekyll/Dockerfile
+++ b/images/jekyll/Dockerfile
@@ -1,5 +1,6 @@
FROM jekyll/jekyll:4.0.0
USER root
+
RUN gem install \
html-proofer:3.10.2 \
nokogiri:1.10.1 \
@@ -10,5 +11,9 @@ RUN gem install \
jekyll-relative-links:0.6.1 \
jekyll-feed:0.13.0 \
jekyll-sitemap:1.4.0
+
+# checks.rb is used with html-proofer for presubmit checks.
COPY checks.rb /checks.rb
-CMD ["/usr/gem/gems/jekyll-4.0.0/exe/jekyll", "build", "-t", "-s", "/input", "-d", "/output"]
+
+COPY build.sh /build.sh
+CMD ["/build.sh"]
diff --git a/scripts/fuse_tests.sh b/images/jekyll/build.sh
index bbaaa99fc..010972ea6 100755
--- a/scripts/fuse_tests.sh
+++ b/images/jekyll/build.sh
@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-source $(dirname $0)/common.sh
+set -euxo pipefail
-# Run all vfs2_fuse system call tests.
-test --test_tag_filters=fuse //test/fuse/...
+# Generate the syntax highlighting css file.
+/usr/gem/bin/rougify style github >/input/_sass/syntax.css
+# Build website including pages irrespective of date.
+/usr/gem/bin/jekyll build --future -t -s /input -d /output
diff --git a/images/packetdrill/Dockerfile b/images/packetdrill/Dockerfile
index 01296dbaf..b4cd73006 100644
--- a/images/packetdrill/Dockerfile
+++ b/images/packetdrill/Dockerfile
@@ -1,8 +1,8 @@
FROM ubuntu:bionic
RUN apt-get update && apt-get install -y net-tools git iptables iputils-ping \
netcat tcpdump jq tar bison flex make
+# Pick up updated git.
RUN hash -r
RUN git clone --depth 1 --branch packetdrill-v2.0 \
https://github.com/google/packetdrill.git
RUN cd packetdrill/gtests/net/packetdrill && ./configure && make
-CMD /bin/bash
diff --git a/images/packetimpact/Dockerfile b/images/packetimpact/Dockerfile
index 87aa99ef2..906d5cdd6 100644
--- a/images/packetimpact/Dockerfile
+++ b/images/packetimpact/Dockerfile
@@ -1,4 +1,4 @@
-FROM ubuntu:bionic
+FROM ubuntu:focal
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
# iptables to disable OS native packet processing.
iptables \
@@ -11,6 +11,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
# tshark to log verbose packet sniffing.
tshark \
# killall for cleanup.
- psmisc
-RUN hash -r
-CMD /bin/bash
+ psmisc \
+ # qemu-system-x86 to emulate fuchsia.
+ qemu-system-x86 \
+ # sha1sum to generate entropy.
+ libdigest-sha-perl
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index b5c5cc20b..cdcaa8c73 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -74,9 +74,9 @@ go_library(
"//pkg/abi",
"//pkg/binary",
"//pkg/bits",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/usermem",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/abi/linux/aio.go b/pkg/abi/linux/aio.go
index 86ee3f8b5..5fc099892 100644
--- a/pkg/abi/linux/aio.go
+++ b/pkg/abi/linux/aio.go
@@ -42,6 +42,8 @@ const (
//
// The priority field is currently ignored in the implementation below. Also
// note that the IOCB_FLAG_RESFD feature is not supported.
+//
+// +marshal
type IOCallback struct {
Data uint64
Key uint32
@@ -64,6 +66,7 @@ type IOCallback struct {
// IOEvent describes an I/O result.
//
+// +marshal
// +stateify savable
type IOEvent struct {
Data uint64
diff --git a/pkg/abi/linux/bpf.go b/pkg/abi/linux/bpf.go
index aa3d3ce70..9422fcf69 100644
--- a/pkg/abi/linux/bpf.go
+++ b/pkg/abi/linux/bpf.go
@@ -16,6 +16,7 @@ package linux
// BPFInstruction is a raw BPF virtual machine instruction.
//
+// +marshal slice:BPFInstructionSlice
// +stateify savable
type BPFInstruction struct {
// OpCode is the operation to execute.
diff --git a/pkg/abi/linux/capability.go b/pkg/abi/linux/capability.go
index 965f74663..afd16cc27 100644
--- a/pkg/abi/linux/capability.go
+++ b/pkg/abi/linux/capability.go
@@ -177,12 +177,16 @@ const (
)
// CapUserHeader is equivalent to Linux's cap_user_header_t.
+//
+// +marshal
type CapUserHeader struct {
Version uint32
Pid int32
}
// CapUserData is equivalent to Linux's cap_user_data_t.
+//
+// +marshal slice:CapUserDataSlice
type CapUserData struct {
Effective uint32
Permitted uint32
diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go
index 192e2093b..7771650b3 100644
--- a/pkg/abi/linux/dev.go
+++ b/pkg/abi/linux/dev.go
@@ -54,9 +54,9 @@ const (
// Unix98 PTY masters.
UNIX98_PTY_MASTER_MAJOR = 128
- // UNIX98_PTY_SLAVE_MAJOR is the initial major device number for
- // Unix98 PTY slaves.
- UNIX98_PTY_SLAVE_MAJOR = 136
+ // UNIX98_PTY_REPLICA_MAJOR is the initial major device number for
+ // Unix98 PTY replicas.
+ UNIX98_PTY_REPLICA_MAJOR = 136
)
// Minor device numbers for TTYAUX_MAJOR.
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
index 9242e80a5..cc3571fad 100644
--- a/pkg/abi/linux/fcntl.go
+++ b/pkg/abi/linux/fcntl.go
@@ -45,6 +45,8 @@ const (
)
// Flock is the lock structure for F_SETLK.
+//
+// +marshal
type Flock struct {
Type int16
Whence int16
@@ -63,6 +65,8 @@ const (
)
// FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX.
+//
+// +marshal
type FOwnerEx struct {
Type int32
PID int32
diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go
index 7e30483ee..d91c97a64 100644
--- a/pkg/abi/linux/fuse.go
+++ b/pkg/abi/linux/fuse.go
@@ -14,12 +14,20 @@
package linux
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
// +marshal
type FUSEOpcode uint32
// +marshal
type FUSEOpID uint64
+// FUSE_ROOT_ID is the id of root inode.
+const FUSE_ROOT_ID = 1
+
// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h.
const (
FUSE_LOOKUP FUSEOpcode = 1
@@ -116,61 +124,28 @@ type FUSEHeaderOut struct {
Unique FUSEOpID
}
-// FUSEWriteIn is the header written by a daemon when it makes a
-// write request to the FUSE filesystem.
-//
-// +marshal
-type FUSEWriteIn struct {
- // Fh specifies the file handle that is being written to.
- Fh uint64
-
- // Offset is the offset of the write.
- Offset uint64
-
- // Size is the size of data being written.
- Size uint32
-
- // WriteFlags is the flags used during the write.
- WriteFlags uint32
-
- // LockOwner is the ID of the lock owner.
- LockOwner uint64
-
- // Flags is the flags for the request.
- Flags uint32
-
- _ uint32
-}
-
// FUSE_INIT flags, consistent with the ones in include/uapi/linux/fuse.h.
+// Our taget version is 7.23 but we have few implemented in advance.
const (
- FUSE_ASYNC_READ = 1 << 0
- FUSE_POSIX_LOCKS = 1 << 1
- FUSE_FILE_OPS = 1 << 2
- FUSE_ATOMIC_O_TRUNC = 1 << 3
- FUSE_EXPORT_SUPPORT = 1 << 4
- FUSE_BIG_WRITES = 1 << 5
- FUSE_DONT_MASK = 1 << 6
- FUSE_SPLICE_WRITE = 1 << 7
- FUSE_SPLICE_MOVE = 1 << 8
- FUSE_SPLICE_READ = 1 << 9
- FUSE_FLOCK_LOCKS = 1 << 10
- FUSE_HAS_IOCTL_DIR = 1 << 11
- FUSE_AUTO_INVAL_DATA = 1 << 12
- FUSE_DO_READDIRPLUS = 1 << 13
- FUSE_READDIRPLUS_AUTO = 1 << 14
- FUSE_ASYNC_DIO = 1 << 15
- FUSE_WRITEBACK_CACHE = 1 << 16
- FUSE_NO_OPEN_SUPPORT = 1 << 17
- FUSE_PARALLEL_DIROPS = 1 << 18
- FUSE_HANDLE_KILLPRIV = 1 << 19
- FUSE_POSIX_ACL = 1 << 20
- FUSE_ABORT_ERROR = 1 << 21
- FUSE_MAX_PAGES = 1 << 22
- FUSE_CACHE_SYMLINKS = 1 << 23
- FUSE_NO_OPENDIR_SUPPORT = 1 << 24
- FUSE_EXPLICIT_INVAL_DATA = 1 << 25
- FUSE_MAP_ALIGNMENT = 1 << 26
+ FUSE_ASYNC_READ = 1 << 0
+ FUSE_POSIX_LOCKS = 1 << 1
+ FUSE_FILE_OPS = 1 << 2
+ FUSE_ATOMIC_O_TRUNC = 1 << 3
+ FUSE_EXPORT_SUPPORT = 1 << 4
+ FUSE_BIG_WRITES = 1 << 5
+ FUSE_DONT_MASK = 1 << 6
+ FUSE_SPLICE_WRITE = 1 << 7
+ FUSE_SPLICE_MOVE = 1 << 8
+ FUSE_SPLICE_READ = 1 << 9
+ FUSE_FLOCK_LOCKS = 1 << 10
+ FUSE_HAS_IOCTL_DIR = 1 << 11
+ FUSE_AUTO_INVAL_DATA = 1 << 12
+ FUSE_DO_READDIRPLUS = 1 << 13
+ FUSE_READDIRPLUS_AUTO = 1 << 14
+ FUSE_ASYNC_DIO = 1 << 15
+ FUSE_WRITEBACK_CACHE = 1 << 16
+ FUSE_NO_OPEN_SUPPORT = 1 << 17
+ FUSE_MAX_PAGES = 1 << 22 // From FUSE 7.28
)
// currently supported FUSE protocol version numbers.
@@ -179,6 +154,13 @@ const (
FUSE_KERNEL_MINOR_VERSION = 31
)
+// Constants relevant to FUSE operations.
+const (
+ FUSE_NAME_MAX = 1024
+ FUSE_PAGE_SIZE = 4096
+ FUSE_DIRENT_ALIGN = 8
+)
+
// FUSEInitIn is the request sent by the kernel to the daemon,
// to negotiate the version and flags.
//
@@ -199,7 +181,7 @@ type FUSEInitIn struct {
}
// FUSEInitOut is the reply sent by the daemon to the kernel
-// for FUSEInitIn.
+// for FUSEInitIn. We target FUSE 7.23; this struct supports 7.28.
//
// +marshal
type FUSEInitOut struct {
@@ -240,13 +222,16 @@ type FUSEInitOut struct {
// if the value from daemon is too large.
MaxPages uint16
- // MapAlignment is an unknown field and not used by this package at this moment.
- // Use as a placeholder to be consistent with the FUSE protocol.
- MapAlignment uint16
+ _ uint16
_ [8]uint32
}
+// FUSE_GETATTR_FH is currently the only flag of FUSEGetAttrIn.GetAttrFlags.
+// If it is set, the file handle (FUSEGetAttrIn.Fh) is used to indicate the
+// object instead of the node id attribute in the request header.
+const FUSE_GETATTR_FH = (1 << 0)
+
// FUSEGetAttrIn is the request sent by the kernel to the daemon,
// to get the attribute of a inode.
//
@@ -267,22 +252,52 @@ type FUSEGetAttrIn struct {
//
// +marshal
type FUSEAttr struct {
- Ino uint64
- Size uint64
- Blocks uint64
- Atime uint64
- Mtime uint64
- Ctime uint64
+ // Ino is the inode number of this file.
+ Ino uint64
+
+ // Size is the size of this file.
+ Size uint64
+
+ // Blocks is the number of the 512B blocks allocated by this file.
+ Blocks uint64
+
+ // Atime is the time of last access.
+ Atime uint64
+
+ // Mtime is the time of last modification.
+ Mtime uint64
+
+ // Ctime is the time of last status change.
+ Ctime uint64
+
+ // AtimeNsec is the nano second part of Atime.
AtimeNsec uint32
+
+ // MtimeNsec is the nano second part of Mtime.
MtimeNsec uint32
+
+ // CtimeNsec is the nano second part of Ctime.
CtimeNsec uint32
- Mode uint32
- Nlink uint32
- UID uint32
- GID uint32
- Rdev uint32
- BlkSize uint32
- _ uint32
+
+ // Mode contains the file type and mode.
+ Mode uint32
+
+ // Nlink is the number of the hard links.
+ Nlink uint32
+
+ // UID is user ID of the owner.
+ UID uint32
+
+ // GID is group ID of the owner.
+ GID uint32
+
+ // Rdev is the device ID if this is a special file.
+ Rdev uint32
+
+ // BlkSize is the block size for filesystem I/O.
+ BlkSize uint32
+
+ _ uint32
}
// FUSEGetAttrOut is the reply sent by the daemon to the kernel
@@ -301,3 +316,558 @@ type FUSEGetAttrOut struct {
// Attr contains the metadata returned from the FUSE server
Attr FUSEAttr
}
+
+// FUSEEntryOut is the reply sent by the daemon to the kernel
+// for FUSE_MKNOD, FUSE_MKDIR, FUSE_SYMLINK, FUSE_LINK and
+// FUSE_LOOKUP.
+//
+// +marshal
+type FUSEEntryOut struct {
+ // NodeID is the ID for current inode.
+ NodeID uint64
+
+ // Generation is the generation number of inode.
+ // Used to identify an inode that have different ID at different time.
+ Generation uint64
+
+ // EntryValid indicates timeout for an entry.
+ EntryValid uint64
+
+ // AttrValid indicates timeout for an entry's attributes.
+ AttrValid uint64
+
+ // EntryValidNsec indicates timeout for an entry in nanosecond.
+ EntryValidNSec uint32
+
+ // AttrValidNsec indicates timeout for an entry's attributes in nanosecond.
+ AttrValidNSec uint32
+
+ // Attr contains the attributes of an entry.
+ Attr FUSEAttr
+}
+
+// FUSELookupIn is the request sent by the kernel to the daemon
+// to look up a file name.
+//
+// Dynamically-sized objects cannot be marshalled.
+type FUSELookupIn struct {
+ marshal.StubMarshallable
+
+ // Name is a file name to be looked up.
+ Name string
+}
+
+// MarshalBytes serializes r.name to the dst buffer.
+func (r *FUSELookupIn) MarshalBytes(buf []byte) {
+ copy(buf, r.Name)
+}
+
+// SizeBytes is the size of the memory representation of FUSELookupIn.
+// 1 extra byte for null-terminated string.
+func (r *FUSELookupIn) SizeBytes() int {
+ return len(r.Name) + 1
+}
+
+// MAX_NON_LFS indicates the maximum offset without large file support.
+const MAX_NON_LFS = ((1 << 31) - 1)
+
+// flags returned by OPEN request.
+const (
+ // FOPEN_DIRECT_IO indicates bypassing page cache for this opened file.
+ FOPEN_DIRECT_IO = 1 << 0
+ // FOPEN_KEEP_CACHE avoids invalidate of data cache on open.
+ FOPEN_KEEP_CACHE = 1 << 1
+ // FOPEN_NONSEEKABLE indicates the file cannot be seeked.
+ FOPEN_NONSEEKABLE = 1 << 2
+)
+
+// FUSEOpenIn is the request sent by the kernel to the daemon,
+// to negotiate flags and get file handle.
+//
+// +marshal
+type FUSEOpenIn struct {
+ // Flags of this open request.
+ Flags uint32
+
+ _ uint32
+}
+
+// FUSEOpenOut is the reply sent by the daemon to the kernel
+// for FUSEOpenIn.
+//
+// +marshal
+type FUSEOpenOut struct {
+ // Fh is the file handler for opened file.
+ Fh uint64
+
+ // OpenFlag for the opened file.
+ OpenFlag uint32
+
+ _ uint32
+}
+
+// FUSE_READ flags, consistent with the ones in include/uapi/linux/fuse.h.
+const (
+ FUSE_READ_LOCKOWNER = 1 << 1
+)
+
+// FUSEReadIn is the request sent by the kernel to the daemon
+// for FUSE_READ.
+//
+// +marshal
+type FUSEReadIn struct {
+ // Fh is the file handle in userspace.
+ Fh uint64
+
+ // Offset is the read offset.
+ Offset uint64
+
+ // Size is the number of bytes to read.
+ Size uint32
+
+ // ReadFlags for this FUSE_READ request.
+ // Currently only contains FUSE_READ_LOCKOWNER.
+ ReadFlags uint32
+
+ // LockOwner is the id of the lock owner if there is one.
+ LockOwner uint64
+
+ // Flags for the underlying file.
+ Flags uint32
+
+ _ uint32
+}
+
+// FUSEWriteIn is the first part of the payload of the
+// request sent by the kernel to the daemon
+// for FUSE_WRITE (struct for FUSE version >= 7.9).
+//
+// The second part of the payload is the
+// binary bytes of the data to be written.
+//
+// +marshal
+type FUSEWriteIn struct {
+ // Fh is the file handle in userspace.
+ Fh uint64
+
+ // Offset is the write offset.
+ Offset uint64
+
+ // Size is the number of bytes to write.
+ Size uint32
+
+ // ReadFlags for this FUSE_WRITE request.
+ WriteFlags uint32
+
+ // LockOwner is the id of the lock owner if there is one.
+ LockOwner uint64
+
+ // Flags for the underlying file.
+ Flags uint32
+
+ _ uint32
+}
+
+// FUSEWriteOut is the payload of the reply sent by the daemon to the kernel
+// for a FUSE_WRITE request.
+//
+// +marshal
+type FUSEWriteOut struct {
+ // Size is the number of bytes written.
+ Size uint32
+
+ _ uint32
+}
+
+// FUSEReleaseIn is the request sent by the kernel to the daemon
+// when there is no more reference to a file.
+//
+// +marshal
+type FUSEReleaseIn struct {
+ // Fh is the file handler for the file to be released.
+ Fh uint64
+
+ // Flags of the file.
+ Flags uint32
+
+ // ReleaseFlags of this release request.
+ ReleaseFlags uint32
+
+ // LockOwner is the id of the lock owner if there is one.
+ LockOwner uint64
+}
+
+// FUSECreateMeta contains all the static fields of FUSECreateIn,
+// which is used for FUSE_CREATE.
+//
+// +marshal
+type FUSECreateMeta struct {
+ // Flags of the creating file.
+ Flags uint32
+
+ // Mode is the mode of the creating file.
+ Mode uint32
+
+ // Umask is the current file mode creation mask.
+ Umask uint32
+ _ uint32
+}
+
+// FUSECreateIn contains all the arguments sent by the kernel to the daemon, to
+// atomically create and open a new regular file.
+//
+// Dynamically-sized objects cannot be marshalled.
+type FUSECreateIn struct {
+ marshal.StubMarshallable
+
+ // CreateMeta contains mode, rdev and umash field for FUSE_MKNODS.
+ CreateMeta FUSECreateMeta
+
+ // Name is the name of the node to create.
+ Name string
+}
+
+// MarshalBytes serializes r.CreateMeta and r.Name to the dst buffer.
+func (r *FUSECreateIn) MarshalBytes(buf []byte) {
+ r.CreateMeta.MarshalBytes(buf[:r.CreateMeta.SizeBytes()])
+ copy(buf[r.CreateMeta.SizeBytes():], r.Name)
+}
+
+// SizeBytes is the size of the memory representation of FUSECreateIn.
+// 1 extra byte for null-terminated string.
+func (r *FUSECreateIn) SizeBytes() int {
+ return r.CreateMeta.SizeBytes() + len(r.Name) + 1
+}
+
+// FUSEMknodMeta contains all the static fields of FUSEMknodIn,
+// which is used for FUSE_MKNOD.
+//
+// +marshal
+type FUSEMknodMeta struct {
+ // Mode of the inode to create.
+ Mode uint32
+
+ // Rdev encodes device major and minor information.
+ Rdev uint32
+
+ // Umask is the current file mode creation mask.
+ Umask uint32
+
+ _ uint32
+}
+
+// FUSEMknodIn contains all the arguments sent by the kernel
+// to the daemon, to create a new file node.
+//
+// Dynamically-sized objects cannot be marshalled.
+type FUSEMknodIn struct {
+ marshal.StubMarshallable
+
+ // MknodMeta contains mode, rdev and umash field for FUSE_MKNODS.
+ MknodMeta FUSEMknodMeta
+
+ // Name is the name of the node to create.
+ Name string
+}
+
+// MarshalBytes serializes r.MknodMeta and r.Name to the dst buffer.
+func (r *FUSEMknodIn) MarshalBytes(buf []byte) {
+ r.MknodMeta.MarshalBytes(buf[:r.MknodMeta.SizeBytes()])
+ copy(buf[r.MknodMeta.SizeBytes():], r.Name)
+}
+
+// SizeBytes is the size of the memory representation of FUSEMknodIn.
+// 1 extra byte for null-terminated string.
+func (r *FUSEMknodIn) SizeBytes() int {
+ return r.MknodMeta.SizeBytes() + len(r.Name) + 1
+}
+
+// FUSESymLinkIn is the request sent by the kernel to the daemon,
+// to create a symbolic link.
+//
+// Dynamically-sized objects cannot be marshalled.
+type FUSESymLinkIn struct {
+ marshal.StubMarshallable
+
+ // Name of symlink to create.
+ Name string
+
+ // Target of the symlink.
+ Target string
+}
+
+// MarshalBytes serializes r.Name and r.Target to the dst buffer.
+// Left null-termination at end of r.Name and r.Target.
+func (r *FUSESymLinkIn) MarshalBytes(buf []byte) {
+ copy(buf, r.Name)
+ copy(buf[len(r.Name)+1:], r.Target)
+}
+
+// SizeBytes is the size of the memory representation of FUSESymLinkIn.
+// 2 extra bytes for null-terminated string.
+func (r *FUSESymLinkIn) SizeBytes() int {
+ return len(r.Name) + len(r.Target) + 2
+}
+
+// FUSEEmptyIn is used by operations without request body.
+type FUSEEmptyIn struct{ marshal.StubMarshallable }
+
+// MarshalBytes do nothing for marshal.
+func (r *FUSEEmptyIn) MarshalBytes(buf []byte) {}
+
+// SizeBytes is 0 for empty request.
+func (r *FUSEEmptyIn) SizeBytes() int {
+ return 0
+}
+
+// FUSEMkdirMeta contains all the static fields of FUSEMkdirIn,
+// which is used for FUSE_MKDIR.
+//
+// +marshal
+type FUSEMkdirMeta struct {
+ // Mode of the directory of create.
+ Mode uint32
+
+ // Umask is the user file creation mask.
+ Umask uint32
+}
+
+// FUSEMkdirIn contains all the arguments sent by the kernel
+// to the daemon, to create a new directory.
+//
+// Dynamically-sized objects cannot be marshalled.
+type FUSEMkdirIn struct {
+ marshal.StubMarshallable
+
+ // MkdirMeta contains Mode and Umask of the directory to create.
+ MkdirMeta FUSEMkdirMeta
+
+ // Name of the directory to create.
+ Name string
+}
+
+// MarshalBytes serializes r.MkdirMeta and r.Name to the dst buffer.
+func (r *FUSEMkdirIn) MarshalBytes(buf []byte) {
+ r.MkdirMeta.MarshalBytes(buf[:r.MkdirMeta.SizeBytes()])
+ copy(buf[r.MkdirMeta.SizeBytes():], r.Name)
+}
+
+// SizeBytes is the size of the memory representation of FUSEMkdirIn.
+// 1 extra byte for null-terminated Name string.
+func (r *FUSEMkdirIn) SizeBytes() int {
+ return r.MkdirMeta.SizeBytes() + len(r.Name) + 1
+}
+
+// FUSERmDirIn is the request sent by the kernel to the daemon
+// when trying to remove a directory.
+//
+// Dynamically-sized objects cannot be marshalled.
+type FUSERmDirIn struct {
+ marshal.StubMarshallable
+
+ // Name is a directory name to be removed.
+ Name string
+}
+
+// MarshalBytes serializes r.name to the dst buffer.
+func (r *FUSERmDirIn) MarshalBytes(buf []byte) {
+ copy(buf, r.Name)
+}
+
+// SizeBytes is the size of the memory representation of FUSERmDirIn.
+func (r *FUSERmDirIn) SizeBytes() int {
+ return len(r.Name) + 1
+}
+
+// FUSEDirents is a list of Dirents received from the FUSE daemon server.
+// It is used for FUSE_READDIR.
+//
+// Dynamically-sized objects cannot be marshalled.
+type FUSEDirents struct {
+ marshal.StubMarshallable
+
+ Dirents []*FUSEDirent
+}
+
+// FUSEDirent is a Dirent received from the FUSE daemon server.
+// It is used for FUSE_READDIR.
+//
+// Dynamically-sized objects cannot be marshalled.
+type FUSEDirent struct {
+ marshal.StubMarshallable
+
+ // Meta contains all the static fields of FUSEDirent.
+ Meta FUSEDirentMeta
+
+ // Name is the filename of the dirent.
+ Name string
+}
+
+// FUSEDirentMeta contains all the static fields of FUSEDirent.
+// It is used for FUSE_READDIR.
+//
+// +marshal
+type FUSEDirentMeta struct {
+ // Inode of the dirent.
+ Ino uint64
+
+ // Offset of the dirent.
+ Off uint64
+
+ // NameLen is the length of the dirent name.
+ NameLen uint32
+
+ // Type of the dirent.
+ Type uint32
+}
+
+// SizeBytes is the size of the memory representation of FUSEDirents.
+func (r *FUSEDirents) SizeBytes() int {
+ var sizeBytes int
+ for _, dirent := range r.Dirents {
+ sizeBytes += dirent.SizeBytes()
+ }
+
+ return sizeBytes
+}
+
+// UnmarshalBytes deserializes FUSEDirents from the src buffer.
+func (r *FUSEDirents) UnmarshalBytes(src []byte) {
+ for {
+ if len(src) <= (*FUSEDirentMeta)(nil).SizeBytes() {
+ break
+ }
+
+ // Its unclear how many dirents there are in src. Each dirent is dynamically
+ // sized and so we can't make assumptions about how many dirents we can allocate.
+ if r.Dirents == nil {
+ r.Dirents = make([]*FUSEDirent, 0)
+ }
+
+ // We have to allocate a struct for each dirent - there must be a better way
+ // to do this. Linux allocates 1 page to store all the dirents and then
+ // simply reads them from the page.
+ var dirent FUSEDirent
+ dirent.UnmarshalBytes(src)
+ r.Dirents = append(r.Dirents, &dirent)
+
+ src = src[dirent.SizeBytes():]
+ }
+}
+
+// SizeBytes is the size of the memory representation of FUSEDirent.
+func (r *FUSEDirent) SizeBytes() int {
+ dataSize := r.Meta.SizeBytes() + len(r.Name)
+
+ // Each Dirent must be padded such that its size is a multiple
+ // of FUSE_DIRENT_ALIGN. Similar to the fuse dirent alignment
+ // in linux/fuse.h.
+ return (dataSize + (FUSE_DIRENT_ALIGN - 1)) & ^(FUSE_DIRENT_ALIGN - 1)
+}
+
+// UnmarshalBytes deserializes FUSEDirent from the src buffer.
+func (r *FUSEDirent) UnmarshalBytes(src []byte) {
+ r.Meta.UnmarshalBytes(src)
+ src = src[r.Meta.SizeBytes():]
+
+ if r.Meta.NameLen > FUSE_NAME_MAX {
+ // The name is too long and therefore invalid. We don't
+ // need to unmarshal the name since it'll be thrown away.
+ return
+ }
+
+ buf := make([]byte, r.Meta.NameLen)
+ name := primitive.ByteSlice(buf)
+ name.UnmarshalBytes(src[:r.Meta.NameLen])
+ r.Name = string(name)
+}
+
+// FATTR_* consts are the attribute flags defined in include/uapi/linux/fuse.h.
+// These should be or-ed together for setattr to know what has been changed.
+const (
+ FATTR_MODE = (1 << 0)
+ FATTR_UID = (1 << 1)
+ FATTR_GID = (1 << 2)
+ FATTR_SIZE = (1 << 3)
+ FATTR_ATIME = (1 << 4)
+ FATTR_MTIME = (1 << 5)
+ FATTR_FH = (1 << 6)
+ FATTR_ATIME_NOW = (1 << 7)
+ FATTR_MTIME_NOW = (1 << 8)
+ FATTR_LOCKOWNER = (1 << 9)
+ FATTR_CTIME = (1 << 10)
+)
+
+// FUSESetAttrIn is the request sent by the kernel to the daemon,
+// to set the attribute(s) of a file.
+//
+// +marshal
+type FUSESetAttrIn struct {
+ // Valid indicates which attributes are modified by this request.
+ Valid uint32
+
+ _ uint32
+
+ // Fh is used to identify the file if FATTR_FH is set in Valid.
+ Fh uint64
+
+ // Size is the size that the request wants to change to.
+ Size uint64
+
+ // LockOwner is the owner of the lock that the request wants to change to.
+ LockOwner uint64
+
+ // Atime is the access time that the request wants to change to.
+ Atime uint64
+
+ // Mtime is the modification time that the request wants to change to.
+ Mtime uint64
+
+ // Ctime is the status change time that the request wants to change to.
+ Ctime uint64
+
+ // AtimeNsec is the nano second part of Atime.
+ AtimeNsec uint32
+
+ // MtimeNsec is the nano second part of Mtime.
+ MtimeNsec uint32
+
+ // CtimeNsec is the nano second part of Ctime.
+ CtimeNsec uint32
+
+ // Mode is the file mode that the request wants to change to.
+ Mode uint32
+
+ _ uint32
+
+ // UID is the user ID of the owner that the request wants to change to.
+ UID uint32
+
+ // GID is the group ID of the owner that the request wants to change to.
+ GID uint32
+
+ _ uint32
+}
+
+// FUSEUnlinkIn is the request sent by the kernel to the daemon
+// when trying to unlink a node.
+//
+// Dynamically-sized objects cannot be marshalled.
+type FUSEUnlinkIn struct {
+ marshal.StubMarshallable
+
+ // Name of the node to unlink.
+ Name string
+}
+
+// MarshalBytes serializes r.name to the dst buffer, which should
+// have size len(r.Name) + 1 and last byte set to 0.
+func (r *FUSEUnlinkIn) MarshalBytes(buf []byte) {
+ copy(buf, r.Name)
+}
+
+// SizeBytes is the size of the memory representation of FUSEUnlinkIn.
+// 1 extra byte for null-terminated Name string.
+func (r *FUSEUnlinkIn) SizeBytes() int {
+ return len(r.Name) + 1
+}
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index d6dbedc3e..7df02dd6d 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -113,6 +113,35 @@ const (
_IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS
)
+// Constants from uapi/linux/fs.h.
+const (
+ FS_IOC_GETFLAGS = 2148034049
+ FS_VERITY_FL = 1048576
+)
+
+// Constants from uapi/linux/fsverity.h.
+const (
+ FS_IOC_ENABLE_VERITY = 1082156677
+ FS_IOC_MEASURE_VERITY = 3221513862
+)
+
+// DigestMetadata is a helper struct for VerityDigest.
+//
+// +marshal
+type DigestMetadata struct {
+ DigestAlgorithm uint16
+ DigestSize uint16
+}
+
+// SizeOfDigestMetadata is the size of struct DigestMetadata.
+const SizeOfDigestMetadata = 4
+
+// VerityDigest is struct from uapi/linux/fsverity.h.
+type VerityDigest struct {
+ Metadata DigestMetadata
+ Digest []byte
+}
+
// IOC outputs the result of _IOC macro in asm-generic/ioctl.h.
func IOC(dir, typ, nr, size uint32) uint32 {
return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT
diff --git a/pkg/abi/linux/ipc.go b/pkg/abi/linux/ipc.go
index 22acd2d43..c6e65df62 100644
--- a/pkg/abi/linux/ipc.go
+++ b/pkg/abi/linux/ipc.go
@@ -37,6 +37,8 @@ const IPC_PRIVATE = 0
// features like 32-bit UIDs.
// IPCPerm is equivalent to struct ipc64_perm.
+//
+// +marshal
type IPCPerm struct {
Key uint32
UID uint32
diff --git a/pkg/abi/linux/linux.go b/pkg/abi/linux/linux.go
index 281acdbde..3b4abece1 100644
--- a/pkg/abi/linux/linux.go
+++ b/pkg/abi/linux/linux.go
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package linux contains the constants and types needed to interface with a Linux kernel.
+// Package linux contains the constants and types needed to interface with a
+// Linux kernel.
package linux
// NumSoftIRQ is the number of software IRQs, exposed via /proc/stat.
@@ -21,6 +22,8 @@ package linux
const NumSoftIRQ = 10
// Sysinfo is the structure provided by sysinfo on linux versions > 2.3.48.
+//
+// +marshal
type Sysinfo struct {
Uptime int64
Loads [3]uint64
@@ -34,6 +37,6 @@ type Sysinfo struct {
_ [6]byte // Pad Procs to 64bits.
TotalHigh uint64
FreeHigh uint64
- Unit uint32
- /* The _f field in the glibc version of Sysinfo has size 0 on AMD64 */
+ Unit uint32 `marshal:"unaligned"` // Struct ends mid-64-bit-word.
+ // The _f field in the glibc version of Sysinfo has size 0 on AMD64.
}
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 91e35366f..b521144d9 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -17,9 +17,9 @@ package linux
import (
"io"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// This file contains structures required to support netfilter, specifically
@@ -265,6 +265,18 @@ type KernelXTEntryMatch struct {
Data []byte
}
+// XTGetRevision corresponds to xt_get_revision in
+// include/uapi/linux/netfilter/x_tables.h
+//
+// +marshal
+type XTGetRevision struct {
+ Name ExtensionName
+ Revision uint8
+}
+
+// SizeOfXTGetRevision is the size of an XTGetRevision.
+const SizeOfXTGetRevision = 30
+
// XTEntryTarget holds a target for a rule. For example, it can specify that
// packets matching the rule should DROP, ACCEPT, or use an extension target.
// iptables-extension(8) has a list of possible targets.
@@ -285,6 +297,13 @@ type XTEntryTarget struct {
// SizeOfXTEntryTarget is the size of an XTEntryTarget.
const SizeOfXTEntryTarget = 32
+// KernelXTEntryTarget is identical to XTEntryTarget, but contains a
+// variable-length Data field.
+type KernelXTEntryTarget struct {
+ XTEntryTarget
+ Data []byte
+}
+
// XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE,
// RETURN, or jump. It corresponds to struct xt_standard_target in
// include/uapi/linux/netfilter/x_tables.h.
@@ -450,9 +469,9 @@ func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) {
}
// CopyIn implements marshal.Marshallable.CopyIn.
-func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
- length, err := task.CopyInBytes(addr, buf) // escapes: okay.
+func (ke *KernelIPTGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
+ buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
+ length, err := cc.CopyInBytes(addr, buf) // escapes: okay.
// Unmarshal unconditionally. If we had a short copy-in, this results in a
// partially unmarshalled struct.
ke.UnmarshalBytes(buf) // escapes: fallback.
@@ -460,21 +479,21 @@ func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int
}
// CopyOut implements marshal.Marshallable.CopyOut.
-func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+func (ke *KernelIPTGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
// Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
// back to MarshalBytes.
- return task.CopyOutBytes(addr, ke.marshalAll(task))
+ return cc.CopyOutBytes(addr, ke.marshalAll(cc))
}
// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+func (ke *KernelIPTGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
// Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
// back to MarshalBytes.
- return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit])
+ return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit])
}
-func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte {
- buf := task.CopyScratchBuffer(ke.SizeBytes())
+func (ke *KernelIPTGetEntries) marshalAll(cc marshal.CopyContext) []byte {
+ buf := cc.CopyScratchBuffer(ke.SizeBytes())
ke.MarshalBytes(buf)
return buf
}
@@ -510,6 +529,8 @@ type IPTReplace struct {
const SizeOfIPTReplace = 96
// ExtensionName holds the name of a netfilter extension.
+//
+// +marshal
type ExtensionName [XT_EXTENSION_MAXNAMELEN]byte
// String implements fmt.Stringer.
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
index f6117024c..6d31eb5e3 100644
--- a/pkg/abi/linux/netfilter_ipv6.go
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -17,9 +17,9 @@ package linux
import (
"io"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// This file contains structures required to support IPv6 netfilter and
@@ -128,9 +128,9 @@ func (ke *KernelIP6TGetEntries) UnmarshalUnsafe(src []byte) {
}
// CopyIn implements marshal.Marshallable.CopyIn.
-func (ke *KernelIP6TGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
- length, err := task.CopyInBytes(addr, buf) // escapes: okay.
+func (ke *KernelIP6TGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
+ buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
+ length, err := cc.CopyInBytes(addr, buf) // escapes: okay.
// Unmarshal unconditionally. If we had a short copy-in, this results
// in a partially unmarshalled struct.
ke.UnmarshalBytes(buf) // escapes: fallback.
@@ -138,21 +138,21 @@ func (ke *KernelIP6TGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (in
}
// CopyOut implements marshal.Marshallable.CopyOut.
-func (ke *KernelIP6TGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+func (ke *KernelIP6TGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
// Type KernelIP6TGetEntries doesn't have a packed layout in memory,
// fall back to MarshalBytes.
- return task.CopyOutBytes(addr, ke.marshalAll(task))
+ return cc.CopyOutBytes(addr, ke.marshalAll(cc))
}
// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (ke *KernelIP6TGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+func (ke *KernelIP6TGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
// Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall
// back to MarshalBytes.
- return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit])
+ return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit])
}
-func (ke *KernelIP6TGetEntries) marshalAll(task marshal.Task) []byte {
- buf := task.CopyScratchBuffer(ke.SizeBytes())
+func (ke *KernelIP6TGetEntries) marshalAll(cc marshal.CopyContext) []byte {
+ buf := cc.CopyScratchBuffer(ke.SizeBytes())
ke.MarshalBytes(buf)
return buf
}
@@ -321,3 +321,16 @@ const (
// Enable all flags.
IP6T_INV_MASK = 0x7F
)
+
+// NFNATRange corresponds to struct nf_nat_range in
+// include/uapi/linux/netfilter/nf_nat.h.
+type NFNATRange struct {
+ Flags uint32
+ MinAddr Inet6Addr
+ MaxAddr Inet6Addr
+ MinProto uint16 // Network byte order.
+ MaxProto uint16 // Network byte order.
+}
+
+// SizeOfNFNATRange is the size of NFNATRange.
+const SizeOfNFNATRange = 40
diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go
index 0ba086c76..b41f94a69 100644
--- a/pkg/abi/linux/netlink.go
+++ b/pkg/abi/linux/netlink.go
@@ -40,6 +40,8 @@ const (
)
// SockAddrNetlink is struct sockaddr_nl, from uapi/linux/netlink.h.
+//
+// +marshal
type SockAddrNetlink struct {
Family uint16
_ uint16
diff --git a/pkg/abi/linux/poll.go b/pkg/abi/linux/poll.go
index c04d26e4c..3443a5768 100644
--- a/pkg/abi/linux/poll.go
+++ b/pkg/abi/linux/poll.go
@@ -15,6 +15,8 @@
package linux
// PollFD is struct pollfd, used by poll(2)/ppoll(2), from uapi/asm-generic/poll.h.
+//
+// +marshal slice:PollFDSlice
type PollFD struct {
FD int32
Events int16
diff --git a/pkg/abi/linux/rusage.go b/pkg/abi/linux/rusage.go
index d8302dc85..e29d0ac7e 100644
--- a/pkg/abi/linux/rusage.go
+++ b/pkg/abi/linux/rusage.go
@@ -26,6 +26,8 @@ const (
)
// Rusage represents the Linux struct rusage.
+//
+// +marshal
type Rusage struct {
UTime Timeval
STime Timeval
diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go
index d0607e256..b07cafe12 100644
--- a/pkg/abi/linux/seccomp.go
+++ b/pkg/abi/linux/seccomp.go
@@ -34,11 +34,11 @@ type BPFAction uint32
const (
SECCOMP_RET_KILL_PROCESS BPFAction = 0x80000000
- SECCOMP_RET_KILL_THREAD = 0x00000000
- SECCOMP_RET_TRAP = 0x00030000
- SECCOMP_RET_ERRNO = 0x00050000
- SECCOMP_RET_TRACE = 0x7ff00000
- SECCOMP_RET_ALLOW = 0x7fff0000
+ SECCOMP_RET_KILL_THREAD BPFAction = 0x00000000
+ SECCOMP_RET_TRAP BPFAction = 0x00030000
+ SECCOMP_RET_ERRNO BPFAction = 0x00050000
+ SECCOMP_RET_TRACE BPFAction = 0x7ff00000
+ SECCOMP_RET_ALLOW BPFAction = 0x7fff0000
)
func (a BPFAction) String() string {
@@ -64,6 +64,19 @@ func (a BPFAction) Data() uint16 {
return uint16(a & SECCOMP_RET_DATA)
}
+// WithReturnCode sets the lower 16 bits of the SECCOMP_RET_ERRNO or
+// SECCOMP_RET_TRACE actions to the provided return code, overwriting the previous
+// action, and returns a new BPFAction. If not SECCOMP_RET_ERRNO or
+// SECCOMP_RET_TRACE then this panics.
+func (a BPFAction) WithReturnCode(code uint16) BPFAction {
+ // mask out the previous return value
+ baseAction := a & SECCOMP_RET_ACTION_FULL
+ if baseAction == SECCOMP_RET_ERRNO || baseAction == SECCOMP_RET_TRACE {
+ return BPFAction(uint32(baseAction) | uint32(code))
+ }
+ panic("WithReturnCode only valid for SECCOMP_RET_ERRNO and SECCOMP_RET_TRACE")
+}
+
// SockFprog is sock_fprog taken from <linux/filter.h>.
type SockFprog struct {
Len uint16
diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go
index de422c519..487a626cc 100644
--- a/pkg/abi/linux/sem.go
+++ b/pkg/abi/linux/sem.go
@@ -35,6 +35,8 @@ const (
const SEM_UNDO = 0x1000
// SemidDS is equivalent to struct semid64_ds.
+//
+// +marshal
type SemidDS struct {
SemPerm IPCPerm
SemOTime TimeT
@@ -45,6 +47,8 @@ type SemidDS struct {
}
// Sembuf is equivalent to struct sembuf.
+//
+// +marshal slice:SembufSlice
type Sembuf struct {
SemNum uint16
SemOp int16
diff --git a/pkg/abi/linux/shm.go b/pkg/abi/linux/shm.go
index e45aadb10..274b1e847 100644
--- a/pkg/abi/linux/shm.go
+++ b/pkg/abi/linux/shm.go
@@ -51,6 +51,8 @@ const (
// ShmidDS is equivalent to struct shmid64_ds. Source:
// include/uapi/asm-generic/shmbuf.h
+//
+// +marshal
type ShmidDS struct {
ShmPerm IPCPerm
ShmSegsz uint64
@@ -66,6 +68,8 @@ type ShmidDS struct {
}
// ShmParams is equivalent to struct shminfo. Source: include/uapi/linux/shm.h
+//
+// +marshal
type ShmParams struct {
ShmMax uint64
ShmMin uint64
@@ -75,6 +79,8 @@ type ShmParams struct {
}
// ShmInfo is equivalent to struct shm_info. Source: include/uapi/linux/shm.h
+//
+// +marshal
type ShmInfo struct {
UsedIDs int32 // Number of currently existing segments.
_ [4]byte
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
index 1c330e763..6ca57ffbb 100644
--- a/pkg/abi/linux/signal.go
+++ b/pkg/abi/linux/signal.go
@@ -214,6 +214,8 @@ const (
)
// Sigevent represents struct sigevent.
+//
+// +marshal
type Sigevent struct {
Value uint64 // union sigval {int, void*}
Signo int32
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index e37c8727d..d156d41e4 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -14,7 +14,10 @@
package linux
-import "gvisor.dev/gvisor/pkg/binary"
+import (
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
+)
// Address families, from linux/socket.h.
const (
@@ -265,6 +268,8 @@ type InetMulticastRequestWithNIC struct {
type Inet6Addr [16]byte
// SockAddrInet6 is struct sockaddr_in6, from uapi/linux/in6.h.
+//
+// +marshal
type SockAddrInet6 struct {
Family uint16
Port uint16
@@ -274,6 +279,8 @@ type SockAddrInet6 struct {
}
// SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h.
+//
+// +marshal
type SockAddrLink struct {
Family uint16
Protocol uint16
@@ -290,6 +297,8 @@ type SockAddrLink struct {
const UnixPathMax = 108
// SockAddrUnix is struct sockaddr_un, from uapi/linux/un.h.
+//
+// +marshal
type SockAddrUnix struct {
Family uint16
Path [UnixPathMax]int8
@@ -299,6 +308,8 @@ type SockAddrUnix struct {
// equivalent to struct sockaddr. SockAddr ensures that a well-defined set of
// types can be used as socket addresses.
type SockAddr interface {
+ marshal.Marshallable
+
// implementsSockAddr exists purely to allow a type to indicate that they
// implement this interface. This method is a no-op and shouldn't be called.
implementsSockAddr()
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
index e6860ed49..206f5af7e 100644
--- a/pkg/abi/linux/time.go
+++ b/pkg/abi/linux/time.go
@@ -93,6 +93,8 @@ const (
const maxSecInDuration = math.MaxInt64 / int64(time.Second)
// TimeT represents time_t in <time.h>. It represents time in seconds.
+//
+// +marshal
type TimeT int64
// NsecToTimeT translates nanoseconds to TimeT (seconds).
@@ -102,7 +104,7 @@ func NsecToTimeT(nsec int64) TimeT {
// Timespec represents struct timespec in <time.h>.
//
-// +marshal
+// +marshal slice:TimespecSlice
type Timespec struct {
Sec int64
Nsec int64
@@ -158,7 +160,7 @@ const SizeOfTimeval = 16
// Timeval represents struct timeval in <time.h>.
//
-// +marshal
+// +marshal slice:TimevalSlice
type Timeval struct {
Sec int64
Usec int64
@@ -196,6 +198,8 @@ func DurationToTimeval(dur time.Duration) Timeval {
}
// Itimerspec represents struct itimerspec in <time.h>.
+//
+// +marshal
type Itimerspec struct {
Interval Timespec
Value Timespec
@@ -206,12 +210,16 @@ type Itimerspec struct {
// struct timeval it_interval; /* next value */
// struct timeval it_value; /* current value */
// };
+//
+// +marshal
type ItimerVal struct {
Interval Timeval
Value Timeval
}
// ClockT represents type clock_t.
+//
+// +marshal
type ClockT int64
// ClockTFromDuration converts time.Duration to clock_t.
@@ -220,6 +228,8 @@ func ClockTFromDuration(d time.Duration) ClockT {
}
// Tms represents struct tms, used by times(2).
+//
+// +marshal
type Tms struct {
UTime ClockT
STime ClockT
@@ -229,6 +239,8 @@ type Tms struct {
// TimerID represents type timer_t, which identifies a POSIX per-process
// interval timer.
+//
+// +marshal
type TimerID int32
// StatxTimestamp represents struct statx_timestamp.
diff --git a/pkg/abi/linux/tty.go b/pkg/abi/linux/tty.go
index e640969a6..47e65d9fb 100644
--- a/pkg/abi/linux/tty.go
+++ b/pkg/abi/linux/tty.go
@@ -325,9 +325,9 @@ var MasterTermios = KernelTermios{
OutputSpeed: 38400,
}
-// DefaultSlaveTermios is the default terminal configuration of the slave end
-// of a Unix98 pseudoterminal.
-var DefaultSlaveTermios = KernelTermios{
+// DefaultReplicaTermios is the default terminal configuration of the replica
+// end of a Unix98 pseudoterminal.
+var DefaultReplicaTermios = KernelTermios{
InputFlags: ICRNL | IXON,
OutputFlags: OPOST | ONLCR,
ControlFlags: B38400 | CS8 | CREAD,
@@ -341,6 +341,7 @@ var DefaultSlaveTermios = KernelTermios{
// include/uapi/asm-generic/termios.h.
//
// +stateify savable
+// +marshal
type WindowSize struct {
Rows uint16
Cols uint16
diff --git a/pkg/abi/linux/utsname.go b/pkg/abi/linux/utsname.go
index 60f220a67..cb7c95437 100644
--- a/pkg/abi/linux/utsname.go
+++ b/pkg/abi/linux/utsname.go
@@ -26,6 +26,8 @@ const (
)
// UtsName represents struct utsname, the struct returned by uname(2).
+//
+// +marshal
type UtsName struct {
Sysname [UTSLen + 1]byte
Nodename [UTSLen + 1]byte
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
index ffc918846..bd3a5cce9 100644
--- a/pkg/amutex/BUILD
+++ b/pkg/amutex/BUILD
@@ -6,7 +6,10 @@ go_library(
name = "amutex",
srcs = ["amutex.go"],
visibility = ["//:sandbox"],
- deps = ["//pkg/syserror"],
+ deps = [
+ "//pkg/context",
+ "//pkg/syserror",
+ ],
)
go_test(
diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go
index a078a31db..d7acc1d9f 100644
--- a/pkg/amutex/amutex.go
+++ b/pkg/amutex/amutex.go
@@ -19,41 +19,17 @@ package amutex
import (
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/syserror"
)
// Sleeper must be implemented by users of the abortable mutex to allow for
// cancellation of waits.
-type Sleeper interface {
- // SleepStart is called by the AbortableMutex.Lock() function when the
- // mutex is contended and the goroutine is about to sleep.
- //
- // A channel can be returned that causes the sleep to be canceled if
- // it's readable. If no cancellation is desired, nil can be returned.
- SleepStart() <-chan struct{}
-
- // SleepFinish is called by AbortableMutex.Lock() once a contended mutex
- // is acquired or the wait is aborted.
- SleepFinish(success bool)
-
- // Interrupted returns true if the wait is aborted.
- Interrupted() bool
-}
+type Sleeper = context.ChannelSleeper
// NoopSleeper is a stateless no-op implementation of Sleeper for anonymous
// embedding in other types that do not support cancelation.
-type NoopSleeper struct{}
-
-// SleepStart implements Sleeper.SleepStart.
-func (NoopSleeper) SleepStart() <-chan struct{} {
- return nil
-}
-
-// SleepFinish implements Sleeper.SleepFinish.
-func (NoopSleeper) SleepFinish(success bool) {}
-
-// Interrupted implements Sleeper.Interrupted.
-func (NoopSleeper) Interrupted() bool { return false }
+type NoopSleeper = context.Context
// Block blocks until either receiving from ch succeeds (in which case it
// returns nil) or sleeper is interrupted (in which case it returns
diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go
index c8ee0c3b1..069d0395d 100644
--- a/pkg/bpf/decoder.go
+++ b/pkg/bpf/decoder.go
@@ -21,10 +21,15 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
-// DecodeProgram translates an array of BPF instructions into text format.
-func DecodeProgram(program []linux.BPFInstruction) (string, error) {
+// DecodeProgram translates a compiled BPF program into text format.
+func DecodeProgram(p Program) (string, error) {
+ return DecodeInstructions(p.instructions)
+}
+
+// DecodeInstructions translates an array of BPF instructions into text format.
+func DecodeInstructions(instns []linux.BPFInstruction) (string, error) {
var ret bytes.Buffer
- for line, s := range program {
+ for line, s := range instns {
ret.WriteString(fmt.Sprintf("%v: ", line))
if err := decode(s, line, &ret); err != nil {
return "", err
@@ -34,7 +39,7 @@ func DecodeProgram(program []linux.BPFInstruction) (string, error) {
return ret.String(), nil
}
-// Decode translates BPF instruction into text format.
+// Decode translates a single BPF instruction into text format.
func Decode(inst linux.BPFInstruction) (string, error) {
var ret bytes.Buffer
err := decode(inst, -1, &ret)
diff --git a/pkg/bpf/decoder_test.go b/pkg/bpf/decoder_test.go
index 6a023f0c0..bb971ce21 100644
--- a/pkg/bpf/decoder_test.go
+++ b/pkg/bpf/decoder_test.go
@@ -93,7 +93,7 @@ func TestDecode(t *testing.T) {
}
}
-func TestDecodeProgram(t *testing.T) {
+func TestDecodeInstructions(t *testing.T) {
for _, test := range []struct {
name string
program []linux.BPFInstruction
@@ -126,7 +126,7 @@ func TestDecodeProgram(t *testing.T) {
program: []linux.BPFInstruction{Stmt(Ld+Abs+W, 10), Stmt(Ld+Len+Mem, 0)},
fail: true},
} {
- got, err := DecodeProgram(test.program)
+ got, err := DecodeInstructions(test.program)
if test.fail {
if err == nil {
t.Errorf("%s: Decode(...) failed, expected: 'error', got: %q", test.name, got)
diff --git a/pkg/bpf/program_builder.go b/pkg/bpf/program_builder.go
index 7992044d0..caaf99c83 100644
--- a/pkg/bpf/program_builder.go
+++ b/pkg/bpf/program_builder.go
@@ -32,13 +32,21 @@ type ProgramBuilder struct {
// Maps label names to label objects.
labels map[string]*label
+ // unusableLabels are labels that are added before being referenced in a
+ // jump. Any labels added this way cannot be referenced later in order to
+ // avoid backwards references.
+ unusableLabels map[string]bool
+
// Array of BPF instructions that makes up the program.
instructions []linux.BPFInstruction
}
// NewProgramBuilder creates a new ProgramBuilder instance.
func NewProgramBuilder() *ProgramBuilder {
- return &ProgramBuilder{labels: map[string]*label{}}
+ return &ProgramBuilder{
+ labels: map[string]*label{},
+ unusableLabels: map[string]bool{},
+ }
}
// label contains information to resolve a label to an offset.
@@ -108,9 +116,12 @@ func (b *ProgramBuilder) AddJumpLabels(code uint16, k uint32, jtLabel, jfLabel s
func (b *ProgramBuilder) AddLabel(name string) error {
l, ok := b.labels[name]
if !ok {
- // This is done to catch jump backwards cases, but it's not strictly wrong
- // to have unused labels.
- return fmt.Errorf("Adding a label that hasn't been used is not allowed: %v", name)
+ if _, ok = b.unusableLabels[name]; ok {
+ return fmt.Errorf("label %q already set", name)
+ }
+ // Mark the label as unusable. This is done to catch backwards jumps.
+ b.unusableLabels[name] = true
+ return nil
}
if l.target != -1 {
return fmt.Errorf("label %q target already set: %v", name, l.target)
@@ -141,6 +152,10 @@ func (b *ProgramBuilder) addLabelSource(labelName string, t jmpType) {
func (b *ProgramBuilder) resolveLabels() error {
for key, v := range b.labels {
+ if _, ok := b.unusableLabels[key]; ok {
+ return fmt.Errorf("backwards reference detected for label: %q", key)
+ }
+
if v.target == -1 {
return fmt.Errorf("label target not set: %v", key)
}
diff --git a/pkg/bpf/program_builder_test.go b/pkg/bpf/program_builder_test.go
index 92ca5f4c3..37f684f25 100644
--- a/pkg/bpf/program_builder_test.go
+++ b/pkg/bpf/program_builder_test.go
@@ -26,16 +26,16 @@ func validate(p *ProgramBuilder, expected []linux.BPFInstruction) error {
if err != nil {
return fmt.Errorf("Instructions() failed: %v", err)
}
- got, err := DecodeProgram(instructions)
+ got, err := DecodeInstructions(instructions)
if err != nil {
- return fmt.Errorf("DecodeProgram('instructions') failed: %v", err)
+ return fmt.Errorf("DecodeInstructions('instructions') failed: %v", err)
}
- expectedDecoded, err := DecodeProgram(expected)
+ expectedDecoded, err := DecodeInstructions(expected)
if err != nil {
- return fmt.Errorf("DecodeProgram('expected') failed: %v", err)
+ return fmt.Errorf("DecodeInstructions('expected') failed: %v", err)
}
if got != expectedDecoded {
- return fmt.Errorf("DecodeProgram() failed, expected: %q, got: %q", expectedDecoded, got)
+ return fmt.Errorf("DecodeInstructions() failed, expected: %q, got: %q", expectedDecoded, got)
}
return nil
}
@@ -124,10 +124,38 @@ func TestProgramBuilderLabelWithNoInstruction(t *testing.T) {
}
}
+// TestProgramBuilderUnusedLabel tests that adding an unused label doesn't
+// cause program generation to fail.
func TestProgramBuilderUnusedLabel(t *testing.T) {
p := NewProgramBuilder()
- if err := p.AddLabel("unused"); err == nil {
- t.Errorf("AddLabel(unused) should have failed")
+ p.AddStmt(Ld+Abs+W, 10)
+ p.AddJump(Jmp+Ja, 10, 0, 0)
+
+ expected := []linux.BPFInstruction{
+ Stmt(Ld+Abs+W, 10),
+ Jump(Jmp+Ja, 10, 0, 0),
+ }
+
+ if err := p.AddLabel("unused"); err != nil {
+ t.Errorf("AddLabel(unused) should have succeeded")
+ }
+
+ if err := validate(p, expected); err != nil {
+ t.Errorf("Validate() failed: %v", err)
+ }
+}
+
+// TestProgramBuilderBackwardsReference tests that including a backwards
+// reference to a label in a program causes a failure.
+func TestProgramBuilderBackwardsReference(t *testing.T) {
+ p := NewProgramBuilder()
+ if err := p.AddLabel("bw_label"); err != nil {
+ t.Errorf("failed to add label")
+ }
+ p.AddStmt(Ld+Abs+W, 10)
+ p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "bw_label", 0)
+ if _, err := p.Instructions(); err == nil {
+ t.Errorf("Instructions() should have failed")
}
}
diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD
index b03d46d18..1186f788e 100644
--- a/pkg/buffer/BUILD
+++ b/pkg/buffer/BUILD
@@ -20,6 +20,7 @@ go_library(
srcs = [
"buffer.go",
"buffer_list.go",
+ "pool.go",
"safemem.go",
"view.go",
"view_unsafe.go",
@@ -37,9 +38,13 @@ go_test(
name = "buffer_test",
size = "small",
srcs = [
+ "pool_test.go",
"safemem_test.go",
"view_test.go",
],
library = ":buffer",
- deps = ["//pkg/safemem"],
+ deps = [
+ "//pkg/safemem",
+ "//pkg/state",
+ ],
)
diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go
index c6d089fd9..311808ae9 100644
--- a/pkg/buffer/buffer.go
+++ b/pkg/buffer/buffer.go
@@ -14,36 +14,26 @@
// Package buffer provides the implementation of a buffer view.
//
-// A view is an flexible buffer, backed by a pool, supporting the safecopy
-// operations natively as well as the ability to grow via either prepend or
-// append, as well as shrink.
+// A view is an flexible buffer, supporting the safecopy operations natively as
+// well as the ability to grow via either prepend or append, as well as shrink.
package buffer
-import (
- "sync"
-)
-
-const bufferSize = 8144 // See below.
-
// buffer encapsulates a queueable byte buffer.
//
-// Note that the total size is slightly less than two pages. This is done
-// intentionally to ensure that the buffer object aligns with runtime
-// internals. We have no hard size or alignment requirements. This two page
-// size will effectively minimize internal fragmentation, but still have a
-// large enough chunk to limit excessive segmentation.
-//
// +stateify savable
type buffer struct {
- data [bufferSize]byte
+ data []byte
read int
write int
bufferEntry
}
-// reset resets internal data.
-//
-// This must be called before returning the buffer to the pool.
+// init performs in-place initialization for zero value.
+func (b *buffer) init(size int) {
+ b.data = make([]byte, size)
+}
+
+// Reset resets read and write locations, effectively emptying the buffer.
func (b *buffer) Reset() {
b.read = 0
b.write = 0
@@ -85,10 +75,3 @@ func (b *buffer) WriteMove(n int) {
func (b *buffer) WriteSlice() []byte {
return b.data[b.write:]
}
-
-// bufferPool is a pool for buffers.
-var bufferPool = sync.Pool{
- New: func() interface{} {
- return new(buffer)
- },
-}
diff --git a/pkg/buffer/pool.go b/pkg/buffer/pool.go
new file mode 100644
index 000000000..7ad6132ab
--- /dev/null
+++ b/pkg/buffer/pool.go
@@ -0,0 +1,83 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+const (
+ // embeddedCount is the number of buffer structures embedded in the pool. It
+ // is also the number for overflow allocations.
+ embeddedCount = 8
+
+ // defaultBufferSize is the default size for each underlying storage buffer.
+ //
+ // It is slightly less than two pages. This is done intentionally to ensure
+ // that the buffer object aligns with runtime internals. This two page size
+ // will effectively minimize internal fragmentation, but still have a large
+ // enough chunk to limit excessive segmentation.
+ defaultBufferSize = 8144
+)
+
+// pool allocates buffer.
+//
+// It contains an embedded buffer storage for fast path when the number of
+// buffers needed is small.
+//
+// +stateify savable
+type pool struct {
+ bufferSize int
+ avail []buffer `state:"nosave"`
+ embeddedStorage [embeddedCount]buffer `state:"wait"`
+}
+
+// get gets a new buffer from p.
+func (p *pool) get() *buffer {
+ if p.avail == nil {
+ p.avail = p.embeddedStorage[:]
+ }
+ if len(p.avail) == 0 {
+ p.avail = make([]buffer, embeddedCount)
+ }
+ if p.bufferSize <= 0 {
+ p.bufferSize = defaultBufferSize
+ }
+ buf := &p.avail[0]
+ buf.init(p.bufferSize)
+ p.avail = p.avail[1:]
+ return buf
+}
+
+// put releases buf.
+func (p *pool) put(buf *buffer) {
+ // Remove reference to the underlying storage, allowing it to be garbage
+ // collected.
+ buf.data = nil
+}
+
+// setBufferSize sets the size of underlying storage buffer for future
+// allocations. It can be called at any time.
+func (p *pool) setBufferSize(size int) {
+ p.bufferSize = size
+}
+
+// afterLoad is invoked by stateify.
+func (p *pool) afterLoad() {
+ // S/R does not save subslice into embeddedStorage correctly. Restore
+ // available portion of embeddedStorage manually. Restore as nil if none used.
+ for i := len(p.embeddedStorage); i > 0; i-- {
+ if p.embeddedStorage[i-1].data != nil {
+ p.avail = p.embeddedStorage[i:]
+ break
+ }
+ }
+}
diff --git a/pkg/buffer/pool_test.go b/pkg/buffer/pool_test.go
new file mode 100644
index 000000000..8584bac89
--- /dev/null
+++ b/pkg/buffer/pool_test.go
@@ -0,0 +1,51 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "testing"
+)
+
+func TestGetDefaultBufferSize(t *testing.T) {
+ var p pool
+ for i := 0; i < embeddedCount*2; i++ {
+ buf := p.get()
+ if got, want := len(buf.data), defaultBufferSize; got != want {
+ t.Errorf("#%d len(buf.data) = %d, want %d", i, got, want)
+ }
+ }
+}
+
+func TestGetCustomBufferSize(t *testing.T) {
+ const size = 100
+
+ var p pool
+ p.setBufferSize(size)
+ for i := 0; i < embeddedCount*2; i++ {
+ buf := p.get()
+ if got, want := len(buf.data), size; got != want {
+ t.Errorf("#%d len(buf.data) = %d, want %d", i, got, want)
+ }
+ }
+}
+
+func TestPut(t *testing.T) {
+ var p pool
+ buf := p.get()
+ p.put(buf)
+ if buf.data != nil {
+ t.Errorf("buf.data = %x, want nil", buf.data)
+ }
+}
diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go
index b789e56e9..8b42575b4 100644
--- a/pkg/buffer/safemem.go
+++ b/pkg/buffer/safemem.go
@@ -44,7 +44,7 @@ func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, e
// Need at least one buffer.
firstBuf := v.data.Back()
if firstBuf == nil {
- firstBuf = bufferPool.Get().(*buffer)
+ firstBuf = v.pool.get()
v.data.PushBack(firstBuf)
}
@@ -56,7 +56,7 @@ func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, e
count -= l
blocks = append(blocks, firstBuf.WriteBlock())
for count > 0 {
- emptyBuf := bufferPool.Get().(*buffer)
+ emptyBuf := v.pool.get()
v.data.PushBack(emptyBuf)
block := emptyBuf.WriteBlock().TakeFirst64(count)
count -= uint64(block.Len())
diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go
index 47f357e0c..721cc5934 100644
--- a/pkg/buffer/safemem_test.go
+++ b/pkg/buffer/safemem_test.go
@@ -23,6 +23,8 @@ import (
)
func TestSafemem(t *testing.T) {
+ const bufferSize = defaultBufferSize
+
testCases := []struct {
name string
input string
diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go
index e6901eadb..00652d675 100644
--- a/pkg/buffer/view.go
+++ b/pkg/buffer/view.go
@@ -27,6 +27,7 @@ import (
type View struct {
data bufferList
size int64
+ pool pool
}
// TrimFront removes the first count bytes from the buffer.
@@ -81,7 +82,7 @@ func (v *View) advanceRead(count int64) {
buf = buf.Next() // Iterate.
v.data.Remove(oldBuf)
oldBuf.Reset()
- bufferPool.Put(oldBuf)
+ v.pool.put(oldBuf)
// Update counts.
count -= sz
@@ -118,7 +119,7 @@ func (v *View) Truncate(length int64) {
// Drop the buffer completely; see above.
v.data.Remove(buf)
buf.Reset()
- bufferPool.Put(buf)
+ v.pool.put(buf)
v.size -= sz
}
}
@@ -137,7 +138,7 @@ func (v *View) Grow(length int64, zero bool) {
// Is there some space in the last buffer?
if buf == nil || buf.Full() {
- buf = bufferPool.Get().(*buffer)
+ buf = v.pool.get()
v.data.PushBack(buf)
}
@@ -181,7 +182,7 @@ func (v *View) Prepend(data []byte) {
for len(data) > 0 {
// Do we need an empty buffer?
- buf := bufferPool.Get().(*buffer)
+ buf := v.pool.get()
v.data.PushFront(buf)
// The buffer is empty; copy last chunk.
@@ -211,7 +212,7 @@ func (v *View) Append(data []byte) {
// Ensure there's a buffer with space.
if buf == nil || buf.Full() {
- buf = bufferPool.Get().(*buffer)
+ buf = v.pool.get()
v.data.PushBack(buf)
}
@@ -297,7 +298,7 @@ func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) {
// Ensure we have an empty buffer.
if buf == nil || buf.Full() {
- buf = bufferPool.Get().(*buffer)
+ buf = v.pool.get()
v.data.PushBack(buf)
}
diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go
index 3db1bc6ee..839af0223 100644
--- a/pkg/buffer/view_test.go
+++ b/pkg/buffer/view_test.go
@@ -16,11 +16,16 @@ package buffer
import (
"bytes"
+ "context"
"io"
"strings"
"testing"
+
+ "gvisor.dev/gvisor/pkg/state"
)
+const bufferSize = defaultBufferSize
+
func fillAppend(v *View, data []byte) {
v.Append(data)
}
@@ -50,6 +55,30 @@ var fillFuncs = map[string]func(*View, []byte){
"writeFromReaderEnd": fillWriteFromReaderEnd,
}
+func BenchmarkReadAt(b *testing.B) {
+ b.ReportAllocs()
+ var v View
+ v.Append(make([]byte, 100))
+
+ buf := make([]byte, 10)
+ for i := 0; i < b.N; i++ {
+ v.ReadAt(buf, 0)
+ }
+}
+
+func BenchmarkWriteRead(b *testing.B) {
+ b.ReportAllocs()
+ var v View
+ sz := 1000
+ wbuf := make([]byte, sz)
+ rbuf := bytes.NewBuffer(make([]byte, sz))
+ for i := 0; i < b.N; i++ {
+ v.Append(wbuf)
+ rbuf.Reset()
+ v.ReadToWriter(rbuf, int64(sz))
+ }
+}
+
func testReadAt(t *testing.T, v *View, offset int64, n int, wantStr string, wantErr error) {
t.Helper()
d := make([]byte, n)
@@ -465,3 +494,51 @@ func TestView(t *testing.T) {
}
}
}
+
+func doSaveAndLoad(t *testing.T, toSave, toLoad *View) {
+ t.Helper()
+ var buf bytes.Buffer
+ ctx := context.Background()
+ if _, err := state.Save(ctx, &buf, toSave); err != nil {
+ t.Fatal("state.Save:", err)
+ }
+ if _, err := state.Load(ctx, bytes.NewReader(buf.Bytes()), toLoad); err != nil {
+ t.Fatal("state.Load:", err)
+ }
+}
+
+func TestSaveRestoreViewEmpty(t *testing.T) {
+ var toSave View
+ var v View
+ doSaveAndLoad(t, &toSave, &v)
+
+ if got := v.pool.avail; got != nil {
+ t.Errorf("pool is not in zero state: v.pool.avail = %v, want nil", got)
+ }
+ if got := v.Flatten(); len(got) != 0 {
+ t.Errorf("v.Flatten() = %x, want []", got)
+ }
+}
+
+func TestSaveRestoreView(t *testing.T) {
+ // Create data that fits 2.5 slots.
+ data := bytes.Join([][]byte{
+ bytes.Repeat([]byte{1, 2}, defaultBufferSize),
+ bytes.Repeat([]byte{3}, defaultBufferSize/2),
+ }, nil)
+
+ var toSave View
+ toSave.Append(data)
+
+ var v View
+ doSaveAndLoad(t, &toSave, &v)
+
+ // Next available slot at index 3; 0-2 slot are used.
+ i := 3
+ if got, want := &v.pool.avail[0], &v.pool.embeddedStorage[i]; got != want {
+ t.Errorf("next available buffer points to %p, want %p (&v.pool.embeddedStorage[%d])", got, want, i)
+ }
+ if got := v.Flatten(); !bytes.Equal(got, data) {
+ t.Errorf("v.Flatten() = %x, want %x", got, data)
+ }
+}
diff --git a/pkg/context/BUILD b/pkg/context/BUILD
index 239f31149..f33e23bf7 100644
--- a/pkg/context/BUILD
+++ b/pkg/context/BUILD
@@ -7,7 +7,6 @@ go_library(
srcs = ["context.go"],
visibility = ["//:sandbox"],
deps = [
- "//pkg/amutex",
"//pkg/log",
],
)
diff --git a/pkg/context/context.go b/pkg/context/context.go
index 5319b6d8d..2613bc752 100644
--- a/pkg/context/context.go
+++ b/pkg/context/context.go
@@ -26,7 +26,6 @@ import (
"context"
"time"
- "gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/log"
)
@@ -68,9 +67,10 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) {
// In both cases, values extracted from the Context should be used instead.
type Context interface {
log.Logger
- amutex.Sleeper
context.Context
+ ChannelSleeper
+
// UninterruptibleSleepStart indicates the beginning of an uninterruptible
// sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate
// is true and the Context represents a Task, the Task's AddressSpace is
@@ -85,29 +85,60 @@ type Context interface {
UninterruptibleSleepFinish(activate bool)
}
-// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep
-// methods for anonymous embedding in other types that do not implement sleeps.
-type NoopSleeper struct {
- amutex.NoopSleeper
+// A ChannelSleeper represents a goroutine that may sleep interruptibly, where
+// interruption is indicated by a channel becoming readable.
+type ChannelSleeper interface {
+ // SleepStart is called before going to sleep interruptibly. If SleepStart
+ // returns a non-nil channel and that channel becomes ready for receiving
+ // while the goroutine is sleeping, the goroutine should be woken, and
+ // SleepFinish(false) should be called. Otherwise, SleepFinish(true) should
+ // be called after the goroutine stops sleeping.
+ SleepStart() <-chan struct{}
+
+ // SleepFinish is called after an interruptibly-sleeping goroutine stops
+ // sleeping, as documented by SleepStart.
+ SleepFinish(success bool)
+
+ // Interrupted returns true if the channel returned by SleepStart is
+ // ready for receiving.
+ Interrupted() bool
+}
+
+// NoopSleeper is a noop implementation of ChannelSleeper and
+// Context.UninterruptibleSleep* methods for anonymous embedding in other types
+// that do not implement special behavior around sleeps.
+type NoopSleeper struct{}
+
+// SleepStart implements ChannelSleeper.SleepStart.
+func (NoopSleeper) SleepStart() <-chan struct{} {
+ return nil
+}
+
+// SleepFinish implements ChannelSleeper.SleepFinish.
+func (NoopSleeper) SleepFinish(success bool) {}
+
+// Interrupted implements ChannelSleeper.Interrupted.
+func (NoopSleeper) Interrupted() bool {
+ return false
}
-// UninterruptibleSleepStart does nothing.
-func (NoopSleeper) UninterruptibleSleepStart(bool) {}
+// UninterruptibleSleepStart implements Context.UninterruptibleSleepStart.
+func (NoopSleeper) UninterruptibleSleepStart(deactivate bool) {}
-// UninterruptibleSleepFinish does nothing.
-func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
+// UninterruptibleSleepFinish implements Context.UninterruptibleSleepFinish.
+func (NoopSleeper) UninterruptibleSleepFinish(activate bool) {}
-// Deadline returns zero values, meaning no deadline.
+// Deadline implements context.Context.Deadline.
func (NoopSleeper) Deadline() (time.Time, bool) {
return time.Time{}, false
}
-// Done returns nil.
+// Done implements context.Context.Done.
func (NoopSleeper) Done() <-chan struct{} {
return nil
}
-// Err returns nil.
+// Err returns context.Context.Err.
func (NoopSleeper) Err() error {
return nil
}
diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go
index 6831adcce..a4f4b2c5e 100644
--- a/pkg/coverage/coverage.go
+++ b/pkg/coverage/coverage.go
@@ -100,12 +100,9 @@ var coveragePool = sync.Pool{
// instrumentation_filter.
//
// Note that we "consume", i.e. clear, coverdata when this function is run, to
-// ensure that each event is only reported once.
-//
-// TODO(b/160639712): evaluate whether it is ok to reset the global coverage
-// data every time this function is run. We could technically have each thread
-// store a local snapshot against which we compare the most recent coverdata so
-// that separate threads do not affect each other's view of the data.
+// ensure that each event is only reported once. Due to the limitations of Go
+// coverage tools, we reset the global coverage data every time this function is
+// run.
func ConsumeCoverageData(w io.Writer) int {
once.Do(initCoverageData)
@@ -117,23 +114,23 @@ func ConsumeCoverageData(w io.Writer) int {
for fileIndex, file := range globalData.files {
counters := coverdata.Cover.Counters[file]
for index := 0; index < len(counters); index++ {
- val := atomic.SwapUint32(&counters[index], 0)
- if val != 0 {
- // Calculate the synthetic PC.
- pc := globalData.syntheticPCs[fileIndex][index]
-
- usermem.ByteOrder.PutUint64(pcBuffer[:], pc)
- n, err := w.Write(pcBuffer[:])
- if err != nil {
- if err == io.EOF {
- // Simply stop writing if we encounter EOF; it's ok if we attempted to
- // write more than we can hold.
- return total + n
- }
- panic(fmt.Sprintf("Internal error writing PCs to kcov area: %v", err))
+ if atomic.LoadUint32(&counters[index]) == 0 {
+ continue
+ }
+ // Non-zero coverage data found; consume it and report as a PC.
+ atomic.StoreUint32(&counters[index], 0)
+ pc := globalData.syntheticPCs[fileIndex][index]
+ usermem.ByteOrder.PutUint64(pcBuffer[:], pc)
+ n, err := w.Write(pcBuffer[:])
+ if err != nil {
+ if err == io.EOF {
+ // Simply stop writing if we encounter EOF; it's ok if we attempted to
+ // write more than we can hold.
+ return total + n
}
- total += n
+ panic(fmt.Sprintf("Internal error writing PCs to kcov area: %v", err))
}
+ total += n
}
}
diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go
index 83bcfe220..cc6b0cdf1 100644
--- a/pkg/fd/fd.go
+++ b/pkg/fd/fd.go
@@ -49,7 +49,7 @@ func fixCount(n int, err error) (int, error) {
// Read implements io.Reader.
func (r *ReadWriter) Read(b []byte) (int, error) {
- c, err := fixCount(syscall.Read(int(atomic.LoadInt64(&r.fd)), b))
+ c, err := fixCount(syscall.Read(r.FD(), b))
if c == 0 && len(b) > 0 && err == nil {
return 0, io.EOF
}
@@ -62,7 +62,7 @@ func (r *ReadWriter) Read(b []byte) (int, error) {
func (r *ReadWriter) ReadAt(b []byte, off int64) (c int, err error) {
for len(b) > 0 {
var m int
- m, err = fixCount(syscall.Pread(int(atomic.LoadInt64(&r.fd)), b, off))
+ m, err = fixCount(syscall.Pread(r.FD(), b, off))
if m == 0 && err == nil {
return c, io.EOF
}
@@ -82,7 +82,7 @@ func (r *ReadWriter) Write(b []byte) (int, error) {
var n, remaining int
for remaining = len(b); remaining > 0; {
woff := len(b) - remaining
- n, err = syscall.Write(int(atomic.LoadInt64(&r.fd)), b[woff:])
+ n, err = syscall.Write(r.FD(), b[woff:])
if n > 0 {
// syscall.Write wrote some bytes. This is the common case.
@@ -110,7 +110,7 @@ func (r *ReadWriter) Write(b []byte) (int, error) {
func (r *ReadWriter) WriteAt(b []byte, off int64) (c int, err error) {
for len(b) > 0 {
var m int
- m, err = fixCount(syscall.Pwrite(int(atomic.LoadInt64(&r.fd)), b, off))
+ m, err = fixCount(syscall.Pwrite(r.FD(), b, off))
if err != nil {
break
}
@@ -121,6 +121,16 @@ func (r *ReadWriter) WriteAt(b []byte, off int64) (c int, err error) {
return
}
+// FD returns the owned file descriptor. Ownership remains unchanged.
+func (r *ReadWriter) FD() int {
+ return int(atomic.LoadInt64(&r.fd))
+}
+
+// String implements Stringer.String().
+func (r *ReadWriter) String() string {
+ return fmt.Sprintf("FD: %d", r.FD())
+}
+
// FD owns a host file descriptor.
//
// It is similar to os.File, with a few important distinctions:
@@ -167,6 +177,23 @@ func NewFromFile(file *os.File) (*FD, error) {
return New(fd), nil
}
+// NewFromFiles creates new FDs for each file in the slice.
+func NewFromFiles(files []*os.File) ([]*FD, error) {
+ rv := make([]*FD, 0, len(files))
+ for _, f := range files {
+ new, err := NewFromFile(f)
+ if err != nil {
+ // Cleanup on error.
+ for _, fd := range rv {
+ fd.Close()
+ }
+ return nil, err
+ }
+ rv = append(rv, new)
+ }
+ return rv, nil
+}
+
// Open is equivalent to open(2).
func Open(path string, openmode int, perm uint32) (*FD, error) {
f, err := syscall.Open(path, openmode|syscall.O_LARGEFILE, perm)
@@ -204,11 +231,6 @@ func (f *FD) Release() int {
return int(atomic.SwapInt64(&f.fd, -1))
}
-// FD returns the file descriptor owned by FD. FD retains ownership.
-func (f *FD) FD() int {
- return int(atomic.LoadInt64(&f.fd))
-}
-
// File converts the FD to an os.File.
//
// FD does not transfer ownership of the file descriptor (it will be
@@ -219,7 +241,7 @@ func (f *FD) FD() int {
// This operation is somewhat expensive, so care should be taken to minimize
// its use.
func (f *FD) File() (*os.File, error) {
- fd, err := syscall.Dup(int(atomic.LoadInt64(&f.fd)))
+ fd, err := syscall.Dup(f.FD())
if err != nil {
return nil, err
}
diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md
new file mode 100644
index 000000000..51d0d40e5
--- /dev/null
+++ b/pkg/lisafs/README.md
@@ -0,0 +1,363 @@
+# Replacing 9P
+
+## Background
+
+The Linux filesystem model consists of the following key aspects (modulo mounts,
+which are outside the scope of this discussion):
+
+- A `struct inode` represents a "filesystem object", such as a directory or a
+ regular file. "Filesystem object" is most precisely defined by the practical
+ properties of an inode, such as an immutable type (regular file, directory,
+ symbolic link, etc.) and its independence from the path originally used to
+ obtain it.
+
+- A `struct dentry` represents a node in a filesystem tree. Semantically, each
+ dentry is immutably associated with an inode representing the filesystem
+ object at that position. (Linux implements optimizations involving reuse of
+ unreferenced dentries, which allows their associated inodes to change, but
+ this is outside the scope of this discussion.)
+
+- A `struct file` represents an open file description (hereafter FD) and is
+ needed to perform I/O. Each FD is immutably associated with the dentry
+ through which it was opened.
+
+The current gVisor virtual filesystem implementation (hereafter VFS1) closely
+imitates the Linux design:
+
+- `struct inode` => `fs.Inode`
+
+- `struct dentry` => `fs.Dirent`
+
+- `struct file` => `fs.File`
+
+gVisor accesses most external filesystems through a variant of the 9P2000.L
+protocol, including extensions for performance (`walkgetattr`) and for features
+not supported by vanilla 9P2000.L (`flushf`, `lconnect`). The 9P protocol family
+is inode-based; 9P fids represent a file (equivalently "file system object"),
+and the protocol is structured around alternatively obtaining fids to represent
+files (with `walk` and, in gVisor, `walkgetattr`) and performing operations on
+those fids.
+
+In the sections below, a **shared** filesystem is a filesystem that is *mutably*
+accessible by multiple concurrent clients, such that a **non-shared** filesystem
+is a filesystem that is either read-only or accessible by only a single client.
+
+## Problems
+
+### Serialization of Path Component RPCs
+
+Broadly speaking, VFS1 traverses each path component in a pathname, alternating
+between verifying that each traversed dentry represents an inode that represents
+a searchable directory and moving to the next dentry in the path.
+
+In the context of a remote filesystem, the structure of this traversal means
+that - modulo caching - a path involving N components requires at least N-1
+*sequential* RPCs to obtain metadata for intermediate directories, incurring
+significant latency. (In vanilla 9P2000.L, 2(N-1) RPCs are required: N-1 `walk`
+and N-1 `getattr`. We added the `walkgetattr` RPC to reduce this overhead.) On
+non-shared filesystems, this overhead is primarily significant during
+application startup; caching mitigates much of this overhead at steady state. On
+shared filesystems, where correct caching requires revalidation (requiring RPCs
+for each revalidated directory anyway), this overhead is consistently ruinous.
+
+### Inefficient RPCs
+
+9P is not exceptionally economical with RPCs in general. In addition to the
+issue described above:
+
+- Opening an existing file in 9P involves at least 2 RPCs: `walk` to produce
+ an unopened fid representing the file, and `lopen` to open the fid.
+
+- Creating a file also involves at least 2 RPCs: `walk` to produce an unopened
+ fid representing the parent directory, and `lcreate` to create the file and
+ convert the fid to an open fid representing the created file. In practice,
+ both the Linux and gVisor 9P clients expect to have an unopened fid for the
+ created file (necessitating an additional `walk`), as well as attributes for
+ the created file (necessitating an additional `getattr`), for a total of 4
+ RPCs. (In a shared filesystem, where whether a file already exists can
+ change between RPCs, a correct implementation of `open(O_CREAT)` would have
+ to alternate between these two paths (plus `clunk`ing the temporary fid
+ between alternations, since the nature of the `fid` differs between the two
+ paths). Neither Linux nor gVisor implement the required alternation, so
+ `open(O_CREAT)` without `O_EXCL` can spuriously fail with `EEXIST` on both.)
+
+- Closing (`clunk`ing) a fid requires an RPC. VFS1 issues this RPC
+ asynchronously in an attempt to reduce critical path latency, but scheduling
+ overhead makes this not clearly advantageous in practice.
+
+- `read` and `readdir` can return partial reads without a way to indicate EOF,
+ necessitating an additional final read to detect EOF.
+
+- Operations that affect filesystem state do not consistently return updated
+ filesystem state. In gVisor, the client implementation attempts to handle
+ this by tracking what it thinks updated state "should" be; this is complex,
+ and especially brittle for timestamps (which are often not arbitrarily
+ settable). In Linux, the client implemtation invalidates cached metadata
+ whenever it performs such an operation, and reloads it when a dentry
+ corresponding to an inode with no valid cached metadata is revalidated; this
+ is simple, but necessitates an additional `getattr`.
+
+### Dentry/Inode Ambiguity
+
+As noted above, 9P's documentation tends to imply that unopened fids represent
+an inode. In practice, most filesystem APIs present very limited interfaces for
+working with inodes at best, such that the interpretation of unopened fids
+varies:
+
+- Linux's 9P client associates unopened fids with (dentry, uid) pairs. When
+ caching is enabled, it also associates each inode with the first fid opened
+ writably that references that inode, in order to support page cache
+ writeback.
+
+- gVisor's 9P client associates unopened fids with inodes, and also caches
+ opened fids in inodes in a manner similar to Linux.
+
+- The runsc fsgofer associates unopened fids with both "dentries" (host
+ filesystem paths) and "inodes" (host file descriptors); which is used
+ depends on the operation invoked on the fid.
+
+For non-shared filesystems, this confusion has resulted in correctness issues
+that are (in gVisor) currently handled by a number of coarse-grained locks that
+serialize renames with all other filesystem operations. For shared filesystems,
+this means inconsistent behavior in the presence of concurrent mutation.
+
+## Design
+
+Almost all Linux filesystem syscalls describe filesystem resources in one of two
+ways:
+
+- Path-based: A filesystem position is described by a combination of a
+ starting position and a sequence of path components relative to that
+ position, where the starting position is one of:
+
+ - The VFS root (defined by mount namespace and chroot), for absolute paths
+
+ - The VFS position of an existing FD, for relative paths passed to `*at`
+ syscalls (e.g. `statat`)
+
+ - The current working directory, for relative paths passed to non-`*at`
+ syscalls and `*at` syscalls with `AT_FDCWD`
+
+- File-description-based: A filesystem object is described by an existing FD,
+ passed to a `f*` syscall (e.g. `fstat`).
+
+Many of our issues with 9P arise from its (and VFS') interposition of a model
+based on inodes between the filesystem syscall API and filesystem
+implementations. We propose to replace 9P with a protocol that does not feature
+inodes at all, and instead closely follows the filesystem syscall API by
+featuring only path-based and FD-based operations, with minimal deviations as
+necessary to ameliorate deficiencies in the syscall interface (see below). This
+approach addresses the issues described above:
+
+- Even on shared filesystems, most application filesystem syscalls are
+ translated to a single RPC (possibly excepting special cases described
+ below), which is a logical lower bound.
+
+- The behavior of application syscalls on shared filesystems is
+ straightforwardly predictable: path-based syscalls are translated to
+ path-based RPCs, which will re-lookup the file at that path, and FD-based
+ syscalls are translated to FD-based RPCs, which use an existing open file
+ without performing another lookup. (This is at least true on gofers that
+ proxy the host local filesystem; other filesystems that lack support for
+ e.g. certain operations on FDs may have different behavior, but this
+ divergence is at least still predictable and inherent to the underlying
+ filesystem implementation.)
+
+Note that this approach is only feasible in gVisor's next-generation virtual
+filesystem (VFS2), which does not assume the existence of inodes and allows the
+remote filesystem client to translate whole path-based syscalls into RPCs. Thus
+one of the unavoidable tradeoffs associated with such a protocol vs. 9P is the
+inability to construct a Linux client that is performance-competitive with
+gVisor.
+
+### File Permissions
+
+Many filesystem operations are side-effectual, such that file permissions must
+be checked before such operations take effect. The simplest approach to file
+permission checking is for the sentry to obtain permissions from the remote
+filesystem, then apply permission checks in the sentry before performing the
+application-requested operation. However, this requires an additional RPC per
+application syscall (which can't be mitigated by caching on shared filesystems).
+Alternatively, we may delegate file permission checking to gofers. In general,
+file permission checks depend on the following properties of the accessor:
+
+- Filesystem UID/GID
+
+- Supplementary GIDs
+
+- Effective capabilities in the accessor's user namespace (i.e. the accessor's
+ effective capability set)
+
+- All UIDs and GIDs mapped in the accessor's user namespace (which determine
+ if the accessor's capabilities apply to accessed files)
+
+We may choose to delay implementation of file permission checking delegation,
+although this is potentially costly since it doubles the number of required RPCs
+for most operations on shared filesystems. We may also consider compromise
+options, such as only delegating file permission checks for accessors in the
+root user namespace.
+
+### Symbolic Links
+
+gVisor usually interprets symbolic link targets in its VFS rather than on the
+filesystem containing the symbolic link; thus e.g. a symlink to
+"/proc/self/maps" on a remote filesystem resolves to said file in the sentry's
+procfs rather than the host's. This implies that:
+
+- Remote filesystem servers that proxy filesystems supporting symlinks must
+ check if each path component is a symlink during path traversal.
+
+- Absolute symlinks require that the sentry restart the operation at its
+ contextual VFS root (which is task-specific and may not be on a remote
+ filesystem at all), so if a remote filesystem server encounters an absolute
+ symlink during path traversal on behalf of a path-based operation, it must
+ terminate path traversal and return the symlink target.
+
+- Relative symlinks begin target resolution in the parent directory of the
+ symlink, so in theory most relative symlinks can be handled automatically
+ during the path traversal that encounters the symlink, provided that said
+ traversal is supplied with the number of remaining symlinks before `ELOOP`.
+ However, the new path traversed by the symlink target may cross VFS mount
+ boundaries, such that it's only safe for remote filesystem servers to
+ speculatively follow relative symlinks for side-effect-free operations such
+ as `stat` (where the sentry can simply ignore results that are inapplicable
+ due to crossing mount boundaries). We may choose to delay implementation of
+ this feature, at the cost of an additional RPC per relative symlink (note
+ that even if the symlink target crosses a mount boundary, the sentry will
+ need to `stat` the path to the mount boundary to confirm that each traversed
+ component is an accessible directory); until it is implemented, relative
+ symlinks may be handled like absolute symlinks, by terminating path
+ traversal and returning the symlink target.
+
+The possibility of symlinks (and the possibility of a compromised sentry) means
+that the sentry may issue RPCs with paths that, in the absence of symlinks,
+would traverse beyond the root of the remote filesystem. For example, the sentry
+may issue an RPC with a path like "/foo/../..", on the premise that if "/foo" is
+a symlink then the resulting path may be elsewhere on the remote filesystem. To
+handle this, path traversal must also track its current depth below the remote
+filesystem root, and terminate path traversal if it would ascend beyond this
+point.
+
+### Path Traversal
+
+Since path-based VFS operations will translate to path-based RPCs, filesystem
+servers will need to handle path traversal. From the perspective of a given
+filesystem implementation in the server, there are two basic approaches to path
+traversal:
+
+- Inode-walk: For each path component, obtain a handle to the underlying
+ filesystem object (e.g. with `open(O_PATH)`), check if that object is a
+ symlink (as described above) and that that object is accessible by the
+ caller (e.g. with `fstat()`), then continue to the next path component (e.g.
+ with `openat()`). This ensures that the checked filesystem object is the one
+ used to obtain the next object in the traversal, which is intuitively
+ appealing. However, while this approach works for host local filesystems, it
+ requires features that are not widely supported by other filesystems.
+
+- Path-walk: For each path component, use a path-based operation to determine
+ if the filesystem object currently referred to by that path component is a
+ symlink / is accessible. This is highly portable, but suffers from quadratic
+ behavior (at the level of the underlying filesystem implementation, the
+ first path component will be traversed a number of times equal to the number
+ of path components in the path).
+
+The implementation should support either option by delegating path traversal to
+filesystem implementations within the server (like VFS and the remote filesystem
+protocol itself), as inode-walking is still safe, efficient, amenable to FD
+caching, and implementable on non-shared host local filesystems (a sufficiently
+common case as to be worth considering in the design).
+
+Both approaches are susceptible to race conditions that may permit sandboxed
+filesystem escapes:
+
+- Under inode-walk, a malicious application may cause a directory to be moved
+ (with `rename`) during path traversal, such that the filesystem
+ implementation incorrectly determines whether subsequent inodes are located
+ in paths that should be visible to sandboxed applications.
+
+- Under path-walk, a malicious application may cause a non-symlink file to be
+ replaced with a symlink during path traversal, such that following path
+ operations will incorrectly follow the symlink.
+
+Both race conditions can, to some extent, be mitigated in filesystem server
+implementations by synchronizing path traversal with the hazardous operations in
+question. However, shared filesystems are frequently used to share data between
+sandboxed and unsandboxed applications in a controlled way, and in some cases a
+malicious sandboxed application may be able to take advantage of a hazardous
+filesystem operation performed by an unsandboxed application. In some cases,
+filesystem features may be available to ensure safety even in such cases (e.g.
+[the new openat2() syscall](https://man7.org/linux/man-pages/man2/openat2.2.html)),
+but it is not clear how to solve this problem in general. (Note that this issue
+is not specific to our design; rather, it is a fundamental limitation of
+filesystem sandboxing.)
+
+### Filesystem Multiplexing
+
+A given sentry may need to access multiple distinct remote filesystems (e.g.
+different volumes for a given container). In many cases, there is no advantage
+to serving these filesystems from distinct filesystem servers, or accessing them
+through distinct connections (factors such as maximum RPC concurrency should be
+based on available host resources). Therefore, the protocol should support
+multiplexing of distinct filesystem trees within a single session. 9P supports
+this by allowing multiple calls to the `attach` RPC to produce fids representing
+distinct filesystem trees, but this is somewhat clunky; we propose a much
+simpler mechanism wherein each message that conveys a path also conveys a
+numeric filesystem ID that identifies a filesystem tree.
+
+## Alternatives Considered
+
+### Additional Extensions to 9P
+
+There are at least three conceptual aspects to 9P:
+
+- Wire format: messages with a 4-byte little-endian size prefix, strings with
+ a 2-byte little-endian size prefix, etc. Whether the wire format is worth
+ retaining is unclear; in particular, it's unclear that the 9P wire format
+ has a significant advantage over protobufs, which are substantially easier
+ to extend. Note that the official Go protobuf implementation is widely known
+ to suffer from a significant number of performance deficiencies, so if we
+ choose to switch to protobuf, we may need to use an alternative toolchain
+ such as `gogo/protobuf` (which is also widely used in the Go ecosystem, e.g.
+ by Kubernetes).
+
+- Filesystem model: fids, qids, etc. Discarding this is one of the motivations
+ for this proposal.
+
+- RPCs: Twalk, Tlopen, etc. In addition to previously-described
+ inefficiencies, most of these are dependent on the filesystem model and
+ therefore must be discarded.
+
+### FUSE
+
+The FUSE (Filesystem in Userspace) protocol is frequently used to provide
+arbitrary userspace filesystem implementations to a host Linux kernel.
+Unfortunately, FUSE is also inode-based, and therefore doesn't address any of
+the problems we have with 9P.
+
+### virtio-fs
+
+virtio-fs is an ongoing project aimed at improving Linux VM filesystem
+performance when accessing Linux host filesystems (vs. virtio-9p). In brief, it
+is based on:
+
+- Using a FUSE client in the guest that communicates over virtio with a FUSE
+ server in the host.
+
+- Using DAX to map the host page cache into the guest.
+
+- Using a file metadata table in shared memory to avoid VM exits for metadata
+ updates.
+
+None of these improvements seem applicable to gVisor:
+
+- As explained above, FUSE is still inode-based, so it is still susceptible to
+ most of the problems we have with 9P.
+
+- Our use of host file descriptors already allows us to leverage the host page
+ cache for file contents.
+
+- Our need for shared filesystem coherence is usually based on a user
+ requirement that an out-of-sandbox filesystem mutation is guaranteed to be
+ visible by all subsequent observations from within the sandbox, or vice
+ versa; it's not clear that this can be guaranteed without a synchronous
+ signaling mechanism like an RPC.
diff --git a/tools/go_marshal/marshal/BUILD b/pkg/marshal/BUILD
index 4aec98218..aac0161fa 100644
--- a/tools/go_marshal/marshal/BUILD
+++ b/pkg/marshal/BUILD
@@ -11,7 +11,5 @@ go_library(
visibility = [
"//:sandbox",
],
- deps = [
- "//pkg/usermem",
- ],
+ deps = ["//pkg/usermem"],
)
diff --git a/tools/go_marshal/marshal/marshal.go b/pkg/marshal/marshal.go
index 85b196f08..d8cb44b40 100644
--- a/tools/go_marshal/marshal/marshal.go
+++ b/pkg/marshal/marshal.go
@@ -26,9 +26,10 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// Task provides a subset of kernel.Task, used in marshalling. We don't import
-// the kernel package directly to avoid circular dependency.
-type Task interface {
+// CopyContext defines the memory operations required to marshal to and from
+// user memory. Typically, kernel.Task is used to provide implementations for
+// these operations.
+type CopyContext interface {
// CopyScratchBuffer provides a task goroutine-local scratch buffer. See
// kernel.CopyScratchBuffer.
CopyScratchBuffer(size int) []byte
@@ -107,7 +108,7 @@ type Marshallable interface {
// If the copy-in from the task memory is only partially successful, CopyIn
// should still attempt to deserialize as much data as possible. See comment
// for UnmarshalBytes.
- CopyIn(task Task, addr usermem.Addr) (int, error)
+ CopyIn(cc CopyContext, addr usermem.Addr) (int, error)
// CopyOut serializes a Marshallable type to a task's memory. This may only
// be called from a task goroutine. This is more efficient than calling
@@ -118,7 +119,7 @@ type Marshallable interface {
// The copy-out to the task memory may be partially successful, in which
// case CopyOut returns how much data was serialized. See comment for
// MarshalBytes for implications.
- CopyOut(task Task, addr usermem.Addr) (int, error)
+ CopyOut(cc CopyContext, addr usermem.Addr) (int, error)
// CopyOutN is like CopyOut, but explicitly requests a partial
// copy-out. Note that this may yield unexpected results for non-packed
@@ -126,7 +127,7 @@ type Marshallable interface {
// comment on MarshalBytes.
//
// The limit must be less than or equal to SizeBytes().
- CopyOutN(task Task, addr usermem.Addr, limit int) (int, error)
+ CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error)
}
// go-marshal generates additional functions for a type based on additional
@@ -156,10 +157,10 @@ type Marshallable interface {
// func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... }
//
// // CopyFooSliceIn copies in a slice of Foo objects from the task's memory.
-// func CopyFooSliceIn(task marshal.Task, addr usermem.Addr, dst []Foo) (int, error) { ... }
+// func CopyFooSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Foo) (int, error) { ... }
//
// // CopyFooSliceIn copies out a slice of Foo objects to the task's memory.
-// func CopyFooSliceOut(task marshal.Task, addr usermem.Addr, src []Foo) (int, error) { ... }
+// func CopyFooSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []Foo) (int, error) { ... }
//
// The name of the functions are of the format "Copy%sIn" and "Copy%sOut", where
// %s is the first argument to the slice clause. This directive is not supported
@@ -174,10 +175,10 @@ type Marshallable interface {
// This is only valid on newtypes on primitives, and causes the generated
// functions to accept slices of the inner type instead:
//
-// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []int32) (int, error) { ... }
+// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []int32) (int, error) { ... }
//
// Without "inner", they would instead be:
//
-// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []Int32) (int, error) { ... }
+// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Int32) (int, error) { ... }
//
// This may help avoid a cast depending on how the generated functions are used.
diff --git a/tools/go_marshal/marshal/marshal_impl_util.go b/pkg/marshal/marshal_impl_util.go
index 89c7d3575..ea75e09f2 100644
--- a/tools/go_marshal/marshal/marshal_impl_util.go
+++ b/pkg/marshal/marshal_impl_util.go
@@ -44,7 +44,7 @@ func (StubMarshallable) MarshalBytes(dst []byte) {
// UnmarshalBytes implements Marshallable.UnmarshalBytes.
func (StubMarshallable) UnmarshalBytes(src []byte) {
- panic("Please implement your own UnMarshalBytes function")
+ panic("Please implement your own UnmarshalBytes function")
}
// Packed implements Marshallable.Packed.
@@ -63,16 +63,16 @@ func (StubMarshallable) UnmarshalUnsafe(src []byte) {
}
// CopyIn implements Marshallable.CopyIn.
-func (StubMarshallable) CopyIn(task Task, addr usermem.Addr) (int, error) {
+func (StubMarshallable) CopyIn(cc CopyContext, addr usermem.Addr) (int, error) {
panic("Please implement your own CopyIn function")
}
// CopyOut implements Marshallable.CopyOut.
-func (StubMarshallable) CopyOut(task Task, addr usermem.Addr) (int, error) {
+func (StubMarshallable) CopyOut(cc CopyContext, addr usermem.Addr) (int, error) {
panic("Please implement your own CopyOut function")
}
// CopyOutN implements Marshallable.CopyOutN.
-func (StubMarshallable) CopyOutN(task Task, addr usermem.Addr, limit int) (int, error) {
+func (StubMarshallable) CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error) {
panic("Please implement your own CopyOutN function")
}
diff --git a/tools/go_marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD
index cc08ba63a..d77a11c79 100644
--- a/tools/go_marshal/primitive/BUILD
+++ b/pkg/marshal/primitive/BUILD
@@ -12,7 +12,8 @@ go_library(
"//:sandbox",
],
deps = [
+ "//pkg/context",
+ "//pkg/marshal",
"//pkg/usermem",
- "//tools/go_marshal/marshal",
],
)
diff --git a/tools/go_marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go
index d93edda8b..4b342de6b 100644
--- a/tools/go_marshal/primitive/primitive.go
+++ b/pkg/marshal/primitive/primitive.go
@@ -19,8 +19,9 @@ package primitive
import (
"io"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// Int8 is a marshal.Marshallable implementation for int8.
@@ -101,18 +102,18 @@ func (b *ByteSlice) UnmarshalUnsafe(src []byte) {
}
// CopyIn implements marshal.Marshallable.CopyIn.
-func (b *ByteSlice) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- return task.CopyInBytes(addr, *b)
+func (b *ByteSlice) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
+ return cc.CopyInBytes(addr, *b)
}
// CopyOut implements marshal.Marshallable.CopyOut.
-func (b *ByteSlice) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
- return task.CopyOutBytes(addr, *b)
+func (b *ByteSlice) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
+ return cc.CopyOutBytes(addr, *b)
}
// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (b *ByteSlice) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
- return task.CopyOutBytes(addr, (*b)[:limit])
+func (b *ByteSlice) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
+ return cc.CopyOutBytes(addr, (*b)[:limit])
}
// WriteTo implements io.WriterTo.WriteTo.
@@ -126,13 +127,53 @@ var _ marshal.Marshallable = (*ByteSlice)(nil)
// Below, we define some convenience functions for marshalling primitive types
// using the newtypes above, without requiring superfluous casts.
+// 8-bit integers
+
+// CopyInt8In is a convenient wrapper for copying in an int8 from the task's
+// memory.
+func CopyInt8In(cc marshal.CopyContext, addr usermem.Addr, dst *int8) (int, error) {
+ var buf Int8
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int8(buf)
+ return n, nil
+}
+
+// CopyInt8Out is a convenient wrapper for copying out an int8 to the task's
+// memory.
+func CopyInt8Out(cc marshal.CopyContext, addr usermem.Addr, src int8) (int, error) {
+ srcP := Int8(src)
+ return srcP.CopyOut(cc, addr)
+}
+
+// CopyUint8In is a convenient wrapper for copying in a uint8 from the task's
+// memory.
+func CopyUint8In(cc marshal.CopyContext, addr usermem.Addr, dst *uint8) (int, error) {
+ var buf Uint8
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint8(buf)
+ return n, nil
+}
+
+// CopyUint8Out is a convenient wrapper for copying out a uint8 to the task's
+// memory.
+func CopyUint8Out(cc marshal.CopyContext, addr usermem.Addr, src uint8) (int, error) {
+ srcP := Uint8(src)
+ return srcP.CopyOut(cc, addr)
+}
+
// 16-bit integers
// CopyInt16In is a convenient wrapper for copying in an int16 from the task's
// memory.
-func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error) {
+func CopyInt16In(cc marshal.CopyContext, addr usermem.Addr, dst *int16) (int, error) {
var buf Int16
- n, err := buf.CopyIn(task, addr)
+ n, err := buf.CopyIn(cc, addr)
if err != nil {
return n, err
}
@@ -142,16 +183,16 @@ func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error)
// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's
// memory.
-func CopyInt16Out(task marshal.Task, addr usermem.Addr, src int16) (int, error) {
+func CopyInt16Out(cc marshal.CopyContext, addr usermem.Addr, src int16) (int, error) {
srcP := Int16(src)
- return srcP.CopyOut(task, addr)
+ return srcP.CopyOut(cc, addr)
}
// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's
// memory.
-func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error) {
+func CopyUint16In(cc marshal.CopyContext, addr usermem.Addr, dst *uint16) (int, error) {
var buf Uint16
- n, err := buf.CopyIn(task, addr)
+ n, err := buf.CopyIn(cc, addr)
if err != nil {
return n, err
}
@@ -161,18 +202,18 @@ func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error
// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's
// memory.
-func CopyUint16Out(task marshal.Task, addr usermem.Addr, src uint16) (int, error) {
+func CopyUint16Out(cc marshal.CopyContext, addr usermem.Addr, src uint16) (int, error) {
srcP := Uint16(src)
- return srcP.CopyOut(task, addr)
+ return srcP.CopyOut(cc, addr)
}
// 32-bit integers
// CopyInt32In is a convenient wrapper for copying in an int32 from the task's
// memory.
-func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error) {
+func CopyInt32In(cc marshal.CopyContext, addr usermem.Addr, dst *int32) (int, error) {
var buf Int32
- n, err := buf.CopyIn(task, addr)
+ n, err := buf.CopyIn(cc, addr)
if err != nil {
return n, err
}
@@ -182,16 +223,16 @@ func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error)
// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's
// memory.
-func CopyInt32Out(task marshal.Task, addr usermem.Addr, src int32) (int, error) {
+func CopyInt32Out(cc marshal.CopyContext, addr usermem.Addr, src int32) (int, error) {
srcP := Int32(src)
- return srcP.CopyOut(task, addr)
+ return srcP.CopyOut(cc, addr)
}
// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's
// memory.
-func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error) {
+func CopyUint32In(cc marshal.CopyContext, addr usermem.Addr, dst *uint32) (int, error) {
var buf Uint32
- n, err := buf.CopyIn(task, addr)
+ n, err := buf.CopyIn(cc, addr)
if err != nil {
return n, err
}
@@ -201,18 +242,18 @@ func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error
// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's
// memory.
-func CopyUint32Out(task marshal.Task, addr usermem.Addr, src uint32) (int, error) {
+func CopyUint32Out(cc marshal.CopyContext, addr usermem.Addr, src uint32) (int, error) {
srcP := Uint32(src)
- return srcP.CopyOut(task, addr)
+ return srcP.CopyOut(cc, addr)
}
// 64-bit integers
// CopyInt64In is a convenient wrapper for copying in an int64 from the task's
// memory.
-func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error) {
+func CopyInt64In(cc marshal.CopyContext, addr usermem.Addr, dst *int64) (int, error) {
var buf Int64
- n, err := buf.CopyIn(task, addr)
+ n, err := buf.CopyIn(cc, addr)
if err != nil {
return n, err
}
@@ -222,16 +263,16 @@ func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error)
// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's
// memory.
-func CopyInt64Out(task marshal.Task, addr usermem.Addr, src int64) (int, error) {
+func CopyInt64Out(cc marshal.CopyContext, addr usermem.Addr, src int64) (int, error) {
srcP := Int64(src)
- return srcP.CopyOut(task, addr)
+ return srcP.CopyOut(cc, addr)
}
// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's
// memory.
-func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error) {
+func CopyUint64In(cc marshal.CopyContext, addr usermem.Addr, dst *uint64) (int, error) {
var buf Uint64
- n, err := buf.CopyIn(task, addr)
+ n, err := buf.CopyIn(cc, addr)
if err != nil {
return n, err
}
@@ -241,7 +282,68 @@ func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error
// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's
// memory.
-func CopyUint64Out(task marshal.Task, addr usermem.Addr, src uint64) (int, error) {
+func CopyUint64Out(cc marshal.CopyContext, addr usermem.Addr, src uint64) (int, error) {
srcP := Uint64(src)
- return srcP.CopyOut(task, addr)
+ return srcP.CopyOut(cc, addr)
+}
+
+// CopyByteSliceIn is a convenient wrapper for copying in a []byte from the
+// task's memory.
+func CopyByteSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst *[]byte) (int, error) {
+ var buf ByteSlice
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = []byte(buf)
+ return n, nil
+}
+
+// CopyByteSliceOut is a convenient wrapper for copying out a []byte to the
+// task's memory.
+func CopyByteSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []byte) (int, error) {
+ srcP := ByteSlice(src)
+ return srcP.CopyOut(cc, addr)
+}
+
+// CopyStringIn is a convenient wrapper for copying in a string from the
+// task's memory.
+func CopyStringIn(cc marshal.CopyContext, addr usermem.Addr, dst *string) (int, error) {
+ var buf ByteSlice
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = string(buf)
+ return n, nil
+}
+
+// CopyStringOut is a convenient wrapper for copying out a string to the task's
+// memory.
+func CopyStringOut(cc marshal.CopyContext, addr usermem.Addr, src string) (int, error) {
+ srcP := ByteSlice(src)
+ return srcP.CopyOut(cc, addr)
+}
+
+// IOCopyContext wraps an object implementing usermem.IO to implement
+// marshal.CopyContext.
+type IOCopyContext struct {
+ Ctx context.Context
+ IO usermem.IO
+ Opts usermem.IOOpts
+}
+
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (i *IOCopyContext) CopyScratchBuffer(size int) []byte {
+ return make([]byte, size)
+}
+
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
+func (i *IOCopyContext) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) {
+ return i.IO.CopyOut(i.Ctx, addr, b, i.Opts)
+}
+
+// CopyInBytes implements marshal.CopyContext.CopyInBytes.
+func (i *IOCopyContext) CopyInBytes(addr usermem.Addr, b []byte) (int, error) {
+ return i.IO.CopyIn(i.Ctx, addr, b, i.Opts)
}
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 36832ec86..4b4f9bd52 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -225,9 +225,9 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree
// Verify will modify the cursor for data, but always restores it to its
// original position upon exit. The cursor for tree is modified and not
// restored.
-func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) error {
+func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) (int64, error) {
if readSize <= 0 {
- return fmt.Errorf("Unexpected read size: %d", readSize)
+ return 0, fmt.Errorf("Unexpected read size: %d", readSize)
}
layout := InitLayout(int64(dataSize), dataAndTreeInSameFile)
@@ -240,29 +240,30 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
// finishes.
origOffset, err := data.Seek(0, io.SeekCurrent)
if err != nil {
- return fmt.Errorf("Find current data offset failed: %v", err)
+ return 0, fmt.Errorf("Find current data offset failed: %v", err)
}
defer data.Seek(origOffset, io.SeekStart)
// Move to the first block that contains target data.
if _, err := data.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil {
- return fmt.Errorf("Seek to datablock start failed: %v", err)
+ return 0, fmt.Errorf("Seek to datablock start failed: %v", err)
}
buf := make([]byte, layout.blockSize)
var readErr error
- bytesRead := 0
+ total := int64(0)
for i := firstDataBlock; i <= lastDataBlock; i++ {
// Read a block that includes all or part of target range in
// input data.
- bytesRead, readErr = data.Read(buf)
+ bytesRead, err := data.Read(buf)
+ readErr = err
// If at the end of input data and all previous blocks are
// verified, return the verified input data and EOF.
if readErr == io.EOF && bytesRead == 0 {
break
}
if readErr != nil && readErr != io.EOF {
- return fmt.Errorf("Read from data failed: %v", err)
+ return 0, fmt.Errorf("Read from data failed: %v", err)
}
// If this is the end of file, zero the remaining bytes in buf,
// otherwise they are still from the previous block.
@@ -274,7 +275,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
}
}
if err := verifyBlock(tree, layout, buf, i, expectedRoot); err != nil {
- return err
+ return 0, err
}
// startOff is the beginning of the read range within the
// current data block. Note that for all blocks other than the
@@ -298,10 +299,14 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
if endOff > int64(bytesRead) {
endOff = int64(bytesRead)
}
- w.Write(buf[startOff:endOff])
+ n, err := w.Write(buf[startOff:endOff])
+ if err != nil {
+ return total, err
+ }
+ total += int64(n)
}
- return readErr
+ return total, readErr
}
// verifyBlock verifies a block against tree. index is the number of block in
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index ad50ba5f6..daaca759a 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -67,17 +67,17 @@ func TestLayout(t *testing.T) {
t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile)
if l.blockSize != int64(usermem.PageSize) {
- t.Errorf("got blockSize %d, want %d", l.blockSize, usermem.PageSize)
+ t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize)
}
if l.digestSize != sha256DigestSize {
- t.Errorf("got digestSize %d, want %d", l.digestSize, sha256DigestSize)
+ t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize)
}
if l.numLevels() != len(tc.expectedLevelOffset) {
- t.Errorf("got levels %d, want %d", l.numLevels(), len(tc.expectedLevelOffset))
+ t.Errorf("Got levels %d, want %d", l.numLevels(), len(tc.expectedLevelOffset))
}
for i := 0; i < l.numLevels() && i < len(tc.expectedLevelOffset); i++ {
if l.levelOffset[i] != tc.expectedLevelOffset[i] {
- t.Errorf("got levelStart[%d] %d, want %d", i, l.levelOffset[i], tc.expectedLevelOffset[i])
+ t.Errorf("Got levelStart[%d] %d, want %d", i, l.levelOffset[i], tc.expectedLevelOffset[i])
}
}
})
@@ -169,11 +169,11 @@ func TestGenerate(t *testing.T) {
}, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile)
}
if err != nil {
- t.Fatalf("got err: %v, want nil", err)
+ t.Fatalf("Got err: %v, want nil", err)
}
if !bytes.Equal(root, tc.expectedRoot) {
- t.Errorf("got root: %v, want %v", root, tc.expectedRoot)
+ t.Errorf("Got root: %v, want %v", root, tc.expectedRoot)
}
}
})
@@ -334,14 +334,21 @@ func TestVerify(t *testing.T) {
var buf bytes.Buffer
data[tc.modifyByte] ^= 1
if tc.shouldSucceed {
- if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err != nil && err != io.EOF {
+ n, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile)
+ if err != nil && err != io.EOF {
t.Errorf("Verification failed when expected to succeed: %v", err)
}
- if int64(buf.Len()) != tc.verifySize || !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) {
- t.Errorf("Incorrect output from Verify")
+ if n != tc.verifySize {
+ t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize)
+ }
+ if int64(buf.Len()) != tc.verifySize {
+ t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize)
+ }
+ if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) {
+ t.Errorf("Incorrect output buf from Verify")
}
} else {
- if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil {
+ if _, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil {
t.Errorf("Verification succeeded when expected to fail")
}
}
@@ -382,14 +389,21 @@ func TestVerifyRandom(t *testing.T) {
var buf bytes.Buffer
// Checks that the random portion of data from the original data is
// verified successfully.
- if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err != nil && err != io.EOF {
+ n, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile)
+ if err != nil && err != io.EOF {
t.Errorf("Verification failed for correct data: %v", err)
}
if size > dataSize-start {
size = dataSize - start
}
- if int64(buf.Len()) != size || !bytes.Equal(data[start:start+size], buf.Bytes()) {
- t.Errorf("Incorrect output from Verify")
+ if n != size {
+ t.Errorf("Got Verify output size %d, want %d", n, size)
+ }
+ if int64(buf.Len()) != size {
+ t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size)
+ }
+ if !bytes.Equal(data[start:start+size], buf.Bytes()) {
+ t.Errorf("Incorrect output buf from Verify")
}
buf.Reset()
@@ -397,7 +411,7 @@ func TestVerifyRandom(t *testing.T) {
randBytePos := rand.Int63n(size)
data[start+randBytePos] ^= 1
- if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil {
+ if _, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil {
t.Errorf("Verification succeeded for modified data")
}
}
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
index 2ee07b664..28fe081d6 100644
--- a/pkg/p9/client_file.go
+++ b/pkg/p9/client_file.go
@@ -54,6 +54,8 @@ func (c *Client) newFile(fid FID) *clientFile {
//
// This proxies all of the interfaces found in file.go.
type clientFile struct {
+ DisallowServerCalls
+
// client is the originating client.
client *Client
@@ -283,6 +285,39 @@ func (c *clientFile) Close() error {
return nil
}
+// SetAttrClose implements File.SetAttrClose.
+func (c *clientFile) SetAttrClose(valid SetAttrMask, attr SetAttr) error {
+ if !versionSupportsTsetattrclunk(c.client.version) {
+ setAttrErr := c.SetAttr(valid, attr)
+
+ // Try to close file even in case of failure above. Since the state of the
+ // file is unknown to the caller, it will not attempt to close the file
+ // again.
+ if err := c.Close(); err != nil {
+ return err
+ }
+
+ return setAttrErr
+ }
+
+ // Avoid double close.
+ if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ return syscall.EBADF
+ }
+
+ // Send the message.
+ if err := c.client.sendRecv(&Tsetattrclunk{FID: c.fid, Valid: valid, SetAttr: attr}, &Rsetattrclunk{}); err != nil {
+ // If an error occurred, we toss away the FID. This isn't ideal,
+ // but I'm not sure what else makes sense in this context.
+ log.Warningf("Tsetattrclunk failed, losing FID %v: %v", c.fid, err)
+ return err
+ }
+
+ // Return the FID to the pool.
+ c.client.fidPool.Put(uint64(c.fid))
+ return nil
+}
+
// Open implements File.Open.
func (c *clientFile) Open(flags OpenFlags) (*fd.FD, QID, uint32, error) {
if atomic.LoadUint32(&c.closed) != 0 {
@@ -681,6 +716,3 @@ func (c *clientFile) Flush() error {
return c.client.sendRecv(&Tflushf{FID: c.fid}, &Rflushf{})
}
-
-// Renamed implements File.Renamed.
-func (c *clientFile) Renamed(newDir File, newName string) {}
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index cab35896f..c2e3a3f98 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -135,6 +135,14 @@ type File interface {
// On the server, Close has no concurrency guarantee.
Close() error
+ // SetAttrClose is the equivalent of calling SetAttr() followed by Close().
+ // This can be used to set file times before closing the file in a single
+ // operation.
+ //
+ // On the server, SetAttr has a write concurrency guarantee.
+ // On the server, Close has no concurrency guarantee.
+ SetAttrClose(valid SetAttrMask, attr SetAttr) error
+
// Open must be called prior to using Read, Write or Readdir. Once Open
// is called, some operations, such as Walk, will no longer work.
//
@@ -286,3 +294,19 @@ type DefaultWalkGetAttr struct{}
func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) {
return nil, nil, AttrMask{}, Attr{}, syscall.ENOSYS
}
+
+// DisallowClientCalls panics if a client-only function is called.
+type DisallowClientCalls struct{}
+
+// SetAttrClose implements File.SetAttrClose.
+func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error {
+ panic("SetAttrClose should not be called on the server")
+}
+
+// DisallowServerCalls panics if a server-only function is called.
+type DisallowServerCalls struct{}
+
+// Renamed implements File.Renamed.
+func (*clientFile) Renamed(File, string) {
+ panic("Renamed should not be called on the client")
+}
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 1db5797dd..abd237f46 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -123,6 +123,37 @@ func (t *Tclunk) handle(cs *connState) message {
return &Rclunk{}
}
+func (t *Tsetattrclunk) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ setAttrErr := ref.safelyWrite(func() error {
+ // We don't allow setattr on files that have been deleted.
+ // This might be technically incorrect, as it's possible that
+ // there were multiple links and you can still change the
+ // corresponding inode information.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+
+ // Set the attributes.
+ return ref.file.SetAttr(t.Valid, t.SetAttr)
+ })
+
+ // Try to delete FID even in case of failure above. Since the state of the
+ // file is unknown to the caller, it will not attempt to close the file again.
+ if !cs.DeleteFID(t.FID) {
+ return newErr(syscall.EBADF)
+ }
+ if setAttrErr != nil {
+ return newErr(setAttrErr)
+ }
+ return &Rsetattrclunk{}
+}
+
// handle implements handler.handle.
func (t *Tremove) handle(cs *connState) message {
ref, ok := cs.LookupFID(t.FID)
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index 2cb59f934..cf13cbb69 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -317,6 +317,64 @@ func (r *Rclunk) String() string {
return "Rclunk{}"
}
+// Tsetattrclunk is a setattr+close request.
+type Tsetattrclunk struct {
+ // FID is the FID to change.
+ FID FID
+
+ // Valid is the set of bits which will be used.
+ Valid SetAttrMask
+
+ // SetAttr is the set request.
+ SetAttr SetAttr
+}
+
+// decode implements encoder.decode.
+func (t *Tsetattrclunk) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Valid.decode(b)
+ t.SetAttr.decode(b)
+}
+
+// encode implements encoder.encode.
+func (t *Tsetattrclunk) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ t.Valid.encode(b)
+ t.SetAttr.encode(b)
+}
+
+// Type implements message.Type.
+func (*Tsetattrclunk) Type() MsgType {
+ return MsgTsetattrclunk
+}
+
+// String implements fmt.Stringer.
+func (t *Tsetattrclunk) String() string {
+ return fmt.Sprintf("Tsetattrclunk{FID: %d, Valid: %v, SetAttr: %s}", t.FID, t.Valid, t.SetAttr)
+}
+
+// Rsetattrclunk is a setattr+close response.
+type Rsetattrclunk struct {
+}
+
+// decode implements encoder.decode.
+func (*Rsetattrclunk) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (*Rsetattrclunk) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rsetattrclunk) Type() MsgType {
+ return MsgRsetattrclunk
+}
+
+// String implements fmt.Stringer.
+func (r *Rsetattrclunk) String() string {
+ return "Rsetattrclunk{}"
+}
+
// Tremove is a remove request.
//
// This will eventually be replaced by Tunlinkat.
@@ -2657,6 +2715,8 @@ func init() {
msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} })
msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} })
msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} })
+ msgRegistry.register(MsgTsetattrclunk, func() message { return &Tsetattrclunk{} })
+ msgRegistry.register(MsgRsetattrclunk, func() message { return &Rsetattrclunk{} })
msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} })
msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} })
}
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
index 7facc9f5e..bfeb6c236 100644
--- a/pkg/p9/messages_test.go
+++ b/pkg/p9/messages_test.go
@@ -376,6 +376,30 @@ func TestEncodeDecode(t *testing.T) {
&Rumknod{
Rmknod{QID: QID{Type: 1}},
},
+ &Tsetattrclunk{
+ FID: 1,
+ Valid: SetAttrMask{
+ Permissions: true,
+ UID: true,
+ GID: true,
+ Size: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ ATimeNotSystemTime: true,
+ MTimeNotSystemTime: true,
+ },
+ SetAttr: SetAttr{
+ Permissions: 1,
+ UID: 2,
+ GID: 3,
+ Size: 4,
+ ATimeSeconds: 5,
+ ATimeNanoSeconds: 6,
+ MTimeSeconds: 7,
+ MTimeNanoSeconds: 8,
+ },
+ },
}
for _, enc := range objs {
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index 122c457d2..2235f8968 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -315,86 +315,88 @@ type MsgType uint8
// MsgType declarations.
const (
- MsgTlerror MsgType = 6
- MsgRlerror = 7
- MsgTstatfs = 8
- MsgRstatfs = 9
- MsgTlopen = 12
- MsgRlopen = 13
- MsgTlcreate = 14
- MsgRlcreate = 15
- MsgTsymlink = 16
- MsgRsymlink = 17
- MsgTmknod = 18
- MsgRmknod = 19
- MsgTrename = 20
- MsgRrename = 21
- MsgTreadlink = 22
- MsgRreadlink = 23
- MsgTgetattr = 24
- MsgRgetattr = 25
- MsgTsetattr = 26
- MsgRsetattr = 27
- MsgTlistxattr = 28
- MsgRlistxattr = 29
- MsgTxattrwalk = 30
- MsgRxattrwalk = 31
- MsgTxattrcreate = 32
- MsgRxattrcreate = 33
- MsgTgetxattr = 34
- MsgRgetxattr = 35
- MsgTsetxattr = 36
- MsgRsetxattr = 37
- MsgTremovexattr = 38
- MsgRremovexattr = 39
- MsgTreaddir = 40
- MsgRreaddir = 41
- MsgTfsync = 50
- MsgRfsync = 51
- MsgTlink = 70
- MsgRlink = 71
- MsgTmkdir = 72
- MsgRmkdir = 73
- MsgTrenameat = 74
- MsgRrenameat = 75
- MsgTunlinkat = 76
- MsgRunlinkat = 77
- MsgTversion = 100
- MsgRversion = 101
- MsgTauth = 102
- MsgRauth = 103
- MsgTattach = 104
- MsgRattach = 105
- MsgTflush = 108
- MsgRflush = 109
- MsgTwalk = 110
- MsgRwalk = 111
- MsgTread = 116
- MsgRread = 117
- MsgTwrite = 118
- MsgRwrite = 119
- MsgTclunk = 120
- MsgRclunk = 121
- MsgTremove = 122
- MsgRremove = 123
- MsgTflushf = 124
- MsgRflushf = 125
- MsgTwalkgetattr = 126
- MsgRwalkgetattr = 127
- MsgTucreate = 128
- MsgRucreate = 129
- MsgTumkdir = 130
- MsgRumkdir = 131
- MsgTumknod = 132
- MsgRumknod = 133
- MsgTusymlink = 134
- MsgRusymlink = 135
- MsgTlconnect = 136
- MsgRlconnect = 137
- MsgTallocate = 138
- MsgRallocate = 139
- MsgTchannel = 250
- MsgRchannel = 251
+ MsgTlerror MsgType = 6
+ MsgRlerror MsgType = 7
+ MsgTstatfs MsgType = 8
+ MsgRstatfs MsgType = 9
+ MsgTlopen MsgType = 12
+ MsgRlopen MsgType = 13
+ MsgTlcreate MsgType = 14
+ MsgRlcreate MsgType = 15
+ MsgTsymlink MsgType = 16
+ MsgRsymlink MsgType = 17
+ MsgTmknod MsgType = 18
+ MsgRmknod MsgType = 19
+ MsgTrename MsgType = 20
+ MsgRrename MsgType = 21
+ MsgTreadlink MsgType = 22
+ MsgRreadlink MsgType = 23
+ MsgTgetattr MsgType = 24
+ MsgRgetattr MsgType = 25
+ MsgTsetattr MsgType = 26
+ MsgRsetattr MsgType = 27
+ MsgTlistxattr MsgType = 28
+ MsgRlistxattr MsgType = 29
+ MsgTxattrwalk MsgType = 30
+ MsgRxattrwalk MsgType = 31
+ MsgTxattrcreate MsgType = 32
+ MsgRxattrcreate MsgType = 33
+ MsgTgetxattr MsgType = 34
+ MsgRgetxattr MsgType = 35
+ MsgTsetxattr MsgType = 36
+ MsgRsetxattr MsgType = 37
+ MsgTremovexattr MsgType = 38
+ MsgRremovexattr MsgType = 39
+ MsgTreaddir MsgType = 40
+ MsgRreaddir MsgType = 41
+ MsgTfsync MsgType = 50
+ MsgRfsync MsgType = 51
+ MsgTlink MsgType = 70
+ MsgRlink MsgType = 71
+ MsgTmkdir MsgType = 72
+ MsgRmkdir MsgType = 73
+ MsgTrenameat MsgType = 74
+ MsgRrenameat MsgType = 75
+ MsgTunlinkat MsgType = 76
+ MsgRunlinkat MsgType = 77
+ MsgTversion MsgType = 100
+ MsgRversion MsgType = 101
+ MsgTauth MsgType = 102
+ MsgRauth MsgType = 103
+ MsgTattach MsgType = 104
+ MsgRattach MsgType = 105
+ MsgTflush MsgType = 108
+ MsgRflush MsgType = 109
+ MsgTwalk MsgType = 110
+ MsgRwalk MsgType = 111
+ MsgTread MsgType = 116
+ MsgRread MsgType = 117
+ MsgTwrite MsgType = 118
+ MsgRwrite MsgType = 119
+ MsgTclunk MsgType = 120
+ MsgRclunk MsgType = 121
+ MsgTremove MsgType = 122
+ MsgRremove MsgType = 123
+ MsgTflushf MsgType = 124
+ MsgRflushf MsgType = 125
+ MsgTwalkgetattr MsgType = 126
+ MsgRwalkgetattr MsgType = 127
+ MsgTucreate MsgType = 128
+ MsgRucreate MsgType = 129
+ MsgTumkdir MsgType = 130
+ MsgRumkdir MsgType = 131
+ MsgTumknod MsgType = 132
+ MsgRumknod MsgType = 133
+ MsgTusymlink MsgType = 134
+ MsgRusymlink MsgType = 135
+ MsgTlconnect MsgType = 136
+ MsgRlconnect MsgType = 137
+ MsgTallocate MsgType = 138
+ MsgRallocate MsgType = 139
+ MsgTsetattrclunk MsgType = 140
+ MsgRsetattrclunk MsgType = 141
+ MsgTchannel MsgType = 250
+ MsgRchannel MsgType = 251
)
// QIDType represents the file type for QIDs.
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index 6e7bb3db2..6e605b14c 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -1225,22 +1225,31 @@ func TestOpen(t *testing.T) {
func TestClose(t *testing.T) {
type closeTest struct {
name string
- closeFn func(backend *Mock, f p9.File)
+ closeFn func(backend *Mock, f p9.File) error
}
cases := []closeTest{
{
name: "close",
- closeFn: func(_ *Mock, f p9.File) {
- f.Close()
+ closeFn: func(_ *Mock, f p9.File) error {
+ return f.Close()
},
},
{
name: "remove",
- closeFn: func(backend *Mock, f p9.File) {
+ closeFn: func(backend *Mock, f p9.File) error {
// Allow the rename call in the parent, automatically translated.
backend.parent.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Times(1)
- f.(deprecatedRemover).Remove()
+ return f.(deprecatedRemover).Remove()
+ },
+ },
+ {
+ name: "setAttrClose",
+ closeFn: func(backend *Mock, f p9.File) error {
+ valid := p9.SetAttrMask{ATime: true}
+ attr := p9.SetAttr{ATimeSeconds: 1, ATimeNanoSeconds: 2}
+ backend.EXPECT().SetAttr(valid, attr).Times(1)
+ return f.SetAttrClose(valid, attr)
},
},
}
@@ -1258,7 +1267,9 @@ func TestClose(t *testing.T) {
_, backend, f := walkHelper(h, name, root)
// Close via the prescribed method.
- tc.closeFn(backend, f)
+ if err := tc.closeFn(backend, f); err != nil {
+ t.Fatalf("closeFn failed: %v", err)
+ }
// Everything should fail with EBADF.
if _, _, err := f.Walk(nil); err != syscall.EBADF {
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index 09cde9f5a..8d7168ef5 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 11
+ highestSupportedVersion uint32 = 12
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -173,3 +173,9 @@ func versionSupportsGetSetXattr(v uint32) bool {
func versionSupportsListRemoveXattr(v uint32) bool {
return v >= 11
}
+
+// versionSupportsTsetattrclunk returns true if version v supports
+// the Tsetattrclunk message.
+func versionSupportsTsetattrclunk(v uint32) bool {
+ return v >= 12
+}
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index 57d8542b9..699ea8ac3 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -257,6 +257,8 @@ func (l *LeakMode) Get() interface{} {
// String implements flag.Value.
func (l *LeakMode) String() string {
switch *l {
+ case UninitializedLeakChecking:
+ return "uninitialized"
case NoLeakChecking:
return "disabled"
case LeaksLogWarning:
@@ -264,7 +266,7 @@ func (l *LeakMode) String() string {
case LeaksLogTraces:
return "log-traces"
}
- panic(fmt.Sprintf("invalid ref leak mode %q", *l))
+ panic(fmt.Sprintf("invalid ref leak mode %d", *l))
}
// leakMode stores the current mode for the reference leak checker.
diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD
index ce30382ab..68ed074f8 100644
--- a/pkg/safemem/BUILD
+++ b/pkg/safemem/BUILD
@@ -11,9 +11,7 @@ go_library(
"seq_unsafe.go",
],
visibility = ["//:sandbox"],
- deps = [
- "//pkg/safecopy",
- ],
+ deps = ["//pkg/safecopy"],
)
go_test(
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
index 55fd6967e..752e2dc32 100644
--- a/pkg/seccomp/seccomp.go
+++ b/pkg/seccomp/seccomp.go
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package seccomp provides basic seccomp filters for x86_64 (little endian).
+// Package seccomp provides generation of basic seccomp filters. Currently,
+// only little endian systems are supported.
package seccomp
import (
@@ -64,9 +65,9 @@ func Install(rules SyscallRules) error {
Rules: rules,
Action: linux.SECCOMP_RET_ALLOW,
},
- }, defaultAction)
+ }, defaultAction, defaultAction)
if log.IsLogging(log.Debug) {
- programStr, errDecode := bpf.DecodeProgram(instrs)
+ programStr, errDecode := bpf.DecodeInstructions(instrs)
if errDecode != nil {
programStr = fmt.Sprintf("Error: %v\n%s", errDecode, programStr)
}
@@ -117,7 +118,7 @@ var SyscallName = func(sysno uintptr) string {
// BuildProgram builds a BPF program from the given map of actions to matching
// SyscallRules. The single generated program covers all provided RuleSets.
-func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFInstruction, error) {
+func BuildProgram(rules []RuleSet, defaultAction, badArchAction linux.BPFAction) ([]linux.BPFInstruction, error) {
program := bpf.NewProgramBuilder()
// Be paranoid and check that syscall is done in the expected architecture.
@@ -128,7 +129,7 @@ func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFIn
// defaultLabel is at the bottom of the program. The size of program
// may exceeds 255 lines, which is the limit of a condition jump.
program.AddJump(bpf.Jmp|bpf.Jeq|bpf.K, LINUX_AUDIT_ARCH, skipOneInst, 0)
- program.AddDirectJumpLabel(defaultLabel)
+ program.AddStmt(bpf.Ret|bpf.K, uint32(badArchAction))
if err := buildIndex(rules, program); err != nil {
return nil, err
}
@@ -144,6 +145,11 @@ func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFIn
// buildIndex builds a BST to quickly search through all syscalls.
func buildIndex(rules []RuleSet, program *bpf.ProgramBuilder) error {
+ // Do nothing if rules is empty.
+ if len(rules) == 0 {
+ return nil
+ }
+
// Build a list of all application system calls, across all given rule
// sets. We have a simple BST, but may dispatch individual matchers
// with different actions. The matchers are evaluated linearly.
@@ -216,42 +222,163 @@ func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAc
labelled := false
for i, arg := range rule {
if arg != nil {
+ // Break out early if using MatchAny since no further
+ // instructions are required.
+ if _, ok := arg.(MatchAny); ok {
+ continue
+ }
+
+ // Determine the data offset for low and high bits of input.
+ dataOffsetLow := seccompDataOffsetArgLow(i)
+ dataOffsetHigh := seccompDataOffsetArgHigh(i)
+ if i == RuleIP {
+ dataOffsetLow = seccompDataOffsetIPLow
+ dataOffsetHigh = seccompDataOffsetIPHigh
+ }
+
+ // Add the conditional operation. Input values to the BPF
+ // program are 64bit values. However, comparisons in BPF can
+ // only be done on 32bit values. This means that we need to do
+ // multiple BPF comparisons in order to do one logical 64bit
+ // comparison.
switch a := arg.(type) {
- case AllowAny:
- case AllowValue:
- dataOffsetLow := seccompDataOffsetArgLow(i)
- dataOffsetHigh := seccompDataOffsetArgHigh(i)
- if i == RuleIP {
- dataOffsetLow = seccompDataOffsetIPLow
- dataOffsetHigh = seccompDataOffsetIPHigh
- }
+ case EqualTo:
+ // EqualTo checks that both the higher and lower 32bits are equal.
high, low := uint32(a>>32), uint32(a)
- // assert arg_low == low
+
+ // Assert that the lower 32bits are equal.
+ // arg_low == low ? continue : violation
p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
- // assert arg_high == high
+
+ // Assert that the lower 32bits are also equal.
+ // arg_high == high ? continue/success : violation
p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
labelled = true
+ case NotEqual:
+ // NotEqual checks that either the higher or lower 32bits
+ // are *not* equal.
+ high, low := uint32(a>>32), uint32(a)
+ labelGood := fmt.Sprintf("ne%v", i)
+
+ // Check if the higher 32bits are (not) equal.
+ // arg_low == low ? continue : success
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+
+ // Assert that the lower 32bits are not equal (assuming
+ // higher bits are equal).
+ // arg_high == high ? violation : continue/success
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ p.AddJumpTrueLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0)
+ p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ labelled = true
case GreaterThan:
- dataOffsetLow := seccompDataOffsetArgLow(i)
- dataOffsetHigh := seccompDataOffsetArgHigh(i)
- if i == RuleIP {
- dataOffsetLow = seccompDataOffsetIPLow
- dataOffsetHigh = seccompDataOffsetIPHigh
- }
- labelGood := fmt.Sprintf("gt%v", i)
+ // GreaterThan checks that the higher 32bits is greater
+ // *or* that the higher 32bits are equal and the lower
+ // 32bits are greater.
high, low := uint32(a>>32), uint32(a)
- // assert arg_high < high
+ labelGood := fmt.Sprintf("gt%v", i)
+
+ // Assert the higher 32bits are greater than or equal.
+ // arg_high >= high ? continue : violation (arg_high < high)
p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
- // arg_high > high
+
+ // Assert that the lower 32bits are greater.
+ // arg_high == high ? continue : success (arg_high > high)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
- // arg_low < low
+ // arg_low > low ? continue/success : violation (arg_high == high and arg_low <= low)
p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
labelled = true
+ case GreaterThanOrEqual:
+ // GreaterThanOrEqual checks that the higher 32bits is
+ // greater *or* that the higher 32bits are equal and the
+ // lower 32bits are greater than or equal.
+ high, low := uint32(a>>32), uint32(a)
+ labelGood := fmt.Sprintf("ge%v", i)
+
+ // Assert the higher 32bits are greater than or equal.
+ // arg_high >= high ? continue : violation (arg_high < high)
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ // arg_high == high ? continue : success (arg_high > high)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+
+ // Assert that the lower 32bits are greater (assuming the
+ // higher bits are equal).
+ // arg_low >= low ? continue/success : violation (arg_high == high and arg_low < low)
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ labelled = true
+ case LessThan:
+ // LessThan checks that the higher 32bits is less *or* that
+ // the higher 32bits are equal and the lower 32bits are
+ // less.
+ high, low := uint32(a>>32), uint32(a)
+ labelGood := fmt.Sprintf("lt%v", i)
+
+ // Assert the higher 32bits are less than or equal.
+ // arg_high > high ? violation : continue
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0)
+ // arg_high == high ? continue : success (arg_high < high)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+
+ // Assert that the lower 32bits are less (assuming the
+ // higher bits are equal).
+ // arg_low >= low ? violation : continue
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ p.AddJumpTrueLabel(bpf.Jmp|bpf.Jge|bpf.K, low, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0)
+ p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ labelled = true
+ case LessThanOrEqual:
+ // LessThan checks that the higher 32bits is less *or* that
+ // the higher 32bits are equal and the lower 32bits are
+ // less than or equal.
+ high, low := uint32(a>>32), uint32(a)
+ labelGood := fmt.Sprintf("le%v", i)
+
+ // Assert the higher 32bits are less than or equal.
+ // assert arg_high > high ? violation : continue
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0)
+ // arg_high == high ? continue : success
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+
+ // Assert the lower bits are less than or equal (assuming
+ // the higher bits are equal).
+ // arg_low > low ? violation : success
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0)
+ p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ labelled = true
+ case maskedEqual:
+ // MaskedEqual checks that the bitwise AND of the value and
+ // mask are equal for both the higher and lower 32bits.
+ high, low := uint32(a.value>>32), uint32(a.value)
+ maskHigh, maskLow := uint32(a.mask>>32), uint32(a.mask)
+
+ // Assert that the lower 32bits are equal when masked.
+ // A <- arg_low.
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ // A <- arg_low & maskLow
+ p.AddStmt(bpf.Alu|bpf.And|bpf.K, maskLow)
+ // Assert that arg_low & maskLow == low.
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+
+ // Assert that the higher 32bits are equal when masked.
+ // A <- arg_high
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ // A <- arg_high & maskHigh
+ p.AddStmt(bpf.Alu|bpf.And|bpf.K, maskHigh)
+ // Assert that arg_high & maskHigh == high.
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ labelled = true
default:
return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a))
}
diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go
index a52dc1b4e..daf165bbf 100644
--- a/pkg/seccomp/seccomp_rules.go
+++ b/pkg/seccomp/seccomp_rules.go
@@ -39,28 +39,79 @@ func seccompDataOffsetArgHigh(i int) uint32 {
return seccompDataOffsetArgLow(i) + 4
}
-// AllowAny is marker to indicate any value will be accepted.
-type AllowAny struct{}
+// MatchAny is marker to indicate any value will be accepted.
+type MatchAny struct{}
-func (a AllowAny) String() (s string) {
+func (a MatchAny) String() (s string) {
return "*"
}
-// AllowValue specifies a value that needs to be strictly matched.
-type AllowValue uintptr
+// EqualTo specifies a value that needs to be strictly matched.
+type EqualTo uintptr
+
+func (a EqualTo) String() (s string) {
+ return fmt.Sprintf("== %#x", uintptr(a))
+}
+
+// NotEqual specifies a value that is strictly not equal.
+type NotEqual uintptr
+
+func (a NotEqual) String() (s string) {
+ return fmt.Sprintf("!= %#x", uintptr(a))
+}
// GreaterThan specifies a value that needs to be strictly smaller.
type GreaterThan uintptr
-func (a AllowValue) String() (s string) {
- return fmt.Sprintf("%#x ", uintptr(a))
+func (a GreaterThan) String() (s string) {
+ return fmt.Sprintf("> %#x", uintptr(a))
+}
+
+// GreaterThanOrEqual specifies a value that needs to be smaller or equal.
+type GreaterThanOrEqual uintptr
+
+func (a GreaterThanOrEqual) String() (s string) {
+ return fmt.Sprintf(">= %#x", uintptr(a))
+}
+
+// LessThan specifies a value that needs to be strictly greater.
+type LessThan uintptr
+
+func (a LessThan) String() (s string) {
+ return fmt.Sprintf("< %#x", uintptr(a))
+}
+
+// LessThanOrEqual specifies a value that needs to be greater or equal.
+type LessThanOrEqual uintptr
+
+func (a LessThanOrEqual) String() (s string) {
+ return fmt.Sprintf("<= %#x", uintptr(a))
+}
+
+type maskedEqual struct {
+ mask uintptr
+ value uintptr
+}
+
+func (a maskedEqual) String() (s string) {
+ return fmt.Sprintf("& %#x == %#x", a.mask, a.value)
+}
+
+// MaskedEqual specifies a value that matches the input after the input is
+// masked (bitwise &) against the given mask. Can be used to verify that input
+// only includes certain approved flags.
+func MaskedEqual(mask, value uintptr) interface{} {
+ return maskedEqual{
+ mask: mask,
+ value: value,
+ }
}
// Rule stores the allowed syscall arguments.
//
// For example:
// rule := Rule {
-// AllowValue(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0
+// EqualTo(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0
// }
type Rule [7]interface{} // 6 arguments + RIP
@@ -89,12 +140,12 @@ func (r Rule) String() (s string) {
// rules := SyscallRules{
// syscall.SYS_FUTEX: []Rule{
// {
-// AllowAny{},
-// AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
+// MatchAny{},
+// EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
// }, // OR
// {
-// AllowAny{},
-// AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
+// MatchAny{},
+// EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
// },
// },
// syscall.SYS_GETPID: []Rule{},
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index 5238df8bd..23f30678d 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -76,11 +76,14 @@ func TestBasic(t *testing.T) {
}
for _, test := range []struct {
+ name string
ruleSets []RuleSet
defaultAction linux.BPFAction
+ badArchAction linux.BPFAction
specs []spec
}{
{
+ name: "Single syscall",
ruleSets: []RuleSet{
{
Rules: SyscallRules{1: {}},
@@ -88,26 +91,28 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Single syscall allowed",
+ desc: "syscall allowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Single syscall disallowed",
+ desc: "syscall disallowed",
data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
},
{
+ name: "Multiple rulesets",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
1: []Rule{
{
- AllowValue(0x1),
+ EqualTo(0x1),
},
},
},
@@ -122,30 +127,32 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_KILL_THREAD,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Multiple rulesets allowed (1a)",
+ desc: "allowed (1a)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1}},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Multiple rulesets allowed (1b)",
+ desc: "allowed (1b)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple rulesets allowed (2)",
+ desc: "syscall 1 matched 2nd rule",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple rulesets allowed (2)",
+ desc: "no match",
data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_KILL_THREAD,
},
},
},
{
+ name: "Multiple syscalls",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
@@ -157,50 +164,52 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Multiple syscalls allowed (1)",
+ desc: "allowed (1)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Multiple syscalls allowed (3)",
+ desc: "allowed (3)",
data: seccompData{nr: 3, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Multiple syscalls allowed (5)",
+ desc: "allowed (5)",
data: seccompData{nr: 5, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Multiple syscalls disallowed (0)",
+ desc: "disallowed (0)",
data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple syscalls disallowed (2)",
+ desc: "disallowed (2)",
data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple syscalls disallowed (4)",
+ desc: "disallowed (4)",
data: seccompData{nr: 4, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple syscalls disallowed (6)",
+ desc: "disallowed (6)",
data: seccompData{nr: 6, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple syscalls disallowed (100)",
+ desc: "disallowed (100)",
data: seccompData{nr: 100, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
},
{
+ name: "Wrong architecture",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
@@ -210,15 +219,17 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Wrong architecture",
+ desc: "arch (123)",
data: seccompData{nr: 1, arch: 123},
- want: linux.SECCOMP_RET_TRAP,
+ want: linux.SECCOMP_RET_KILL_THREAD,
},
},
},
{
+ name: "Syscall disallowed",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
@@ -228,22 +239,24 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Syscall disallowed, action trap",
+ desc: "action trap",
data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
},
{
+ name: "Syscall arguments",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
1: []Rule{
{
- AllowAny{},
- AllowValue(0xf),
+ MatchAny{},
+ EqualTo(0xf),
},
},
},
@@ -251,29 +264,31 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Syscall argument allowed",
+ desc: "allowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Syscall argument disallowed",
+ desc: "disallowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xe}},
want: linux.SECCOMP_RET_TRAP,
},
},
},
{
+ name: "Multiple arguments",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
1: []Rule{
{
- AllowValue(0xf),
+ EqualTo(0xf),
},
{
- AllowValue(0xe),
+ EqualTo(0xe),
},
},
},
@@ -281,28 +296,30 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Syscall argument allowed, two rules",
+ desc: "match first rule",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Syscall argument allowed, two rules",
+ desc: "match 2nd rule",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xe}},
want: linux.SECCOMP_RET_ALLOW,
},
},
},
{
+ name: "EqualTo",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
1: []Rule{
{
- AllowValue(0),
- AllowValue(math.MaxUint64 - 1),
- AllowValue(math.MaxUint32),
+ EqualTo(0),
+ EqualTo(math.MaxUint64 - 1),
+ EqualTo(math.MaxUint32),
},
},
},
@@ -310,9 +327,10 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "64bit syscall argument allowed",
+ desc: "argument allowed (all match)",
data: seccompData{
nr: 1,
arch: LINUX_AUDIT_ARCH,
@@ -321,7 +339,7 @@ func TestBasic(t *testing.T) {
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "64bit syscall argument disallowed",
+ desc: "argument disallowed (one mismatch)",
data: seccompData{
nr: 1,
arch: LINUX_AUDIT_ARCH,
@@ -330,7 +348,7 @@ func TestBasic(t *testing.T) {
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "64bit syscall argument disallowed",
+ desc: "argument disallowed (multiple mismatch)",
data: seccompData{
nr: 1,
arch: LINUX_AUDIT_ARCH,
@@ -341,6 +359,103 @@ func TestBasic(t *testing.T) {
},
},
{
+ name: "NotEqual",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ NotEqual(0x7aabbccdd),
+ NotEqual(math.MaxUint64 - 1),
+ NotEqual(math.MaxUint32),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "arg allowed",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
+ },
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (one equal)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (all equal)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "GreaterThan",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ // 4294967298
+ // Both upper 32 bits and lower 32 bits are non-zero.
+ // 00000000000000000000000000000010
+ // 00000000000000000000000000000010
+ GreaterThan(0x00000002_00000002),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "high 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits equal, low 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits equal, low 32bits equal",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000003}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "GreaterThan (multi)",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
@@ -355,46 +470,145 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "GreaterThan: Syscall argument allowed",
+ desc: "arg allowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "GreaterThan: Syscall argument disallowed (equal)",
+ desc: "arg disallowed (first arg equal)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Syscall argument disallowed (smaller)",
+ desc: "arg disallowed (first arg smaller)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "GreaterThan2: Syscall argument allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xfbcd000d}},
+ desc: "arg disallowed (second arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (second arg smaller)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "GreaterThanOrEqual",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ // 4294967298
+ // Both upper 32 bits and lower 32 bits are non-zero.
+ // 00000000000000000000000000000010
+ // 00000000000000000000000000000010
+ GreaterThanOrEqual(0x00000002_00000002),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "high 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "GreaterThan2: Syscall argument disallowed (equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ desc: "high 32bits equal, low 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits equal, low 32bits equal",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits equal, low 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "GreaterThanOrEqual (multi)",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ GreaterThanOrEqual(0xf),
+ GreaterThanOrEqual(0xabcd000d),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "arg allowed (both greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg allowed (first arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (first arg smaller)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "GreaterThan2: Syscall argument disallowed (smaller)",
+ desc: "arg allowed (second arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (second arg smaller)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}},
want: linux.SECCOMP_RET_TRAP,
},
+ {
+ desc: "arg disallowed (both arg smaller)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xa000ffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
},
},
{
+ name: "LessThan",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
1: []Rule{
{
- RuleIP: AllowValue(0x7aabbccdd),
+ // 4294967298
+ // Both upper 32 bits and lower 32 bits are non-zero.
+ // 00000000000000000000000000000010
+ // 00000000000000000000000000000010
+ LessThan(0x00000002_00000002),
},
},
},
@@ -402,40 +616,307 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "IP: Syscall instruction pointer allowed",
+ desc: "high 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits equal",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ },
+ {
+ name: "LessThan (multi)",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ LessThan(0x1),
+ LessThan(0xabcd000d),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "arg allowed",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (first arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (first arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (second arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (second arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (both arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "LessThanOrEqual",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ // 4294967298
+ // Both upper 32 bits and lower 32 bits are non-zero.
+ // 00000000000000000000000000000010
+ // 00000000000000000000000000000010
+ LessThanOrEqual(0x00000002_00000002),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "high 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits equal",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits equal, low 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ },
+
+ {
+ name: "LessThanOrEqual (multi)",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ LessThanOrEqual(0x1),
+ LessThanOrEqual(0xabcd000d),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "arg allowed",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg allowed (first arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (first arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg allowed (second arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (second arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (both arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "MaskedEqual",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ // x & 00000001 00000011 (0x103) == 00000000 00000001 (0x1)
+ // Input x must have lowest order bit set and
+ // must *not* have 8th or second lowest order bit set.
+ MaskedEqual(0x103, 0x1),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "arg allowed (low order mandatory bit)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ // 00000000 00000000 00000000 00000001
+ args: [6]uint64{0x1},
+ },
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg allowed (low order optional bit)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ // 00000000 00000000 00000000 00000101
+ args: [6]uint64{0x5},
+ },
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (lowest order bit not set)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ // 00000000 00000000 00000000 00000010
+ args: [6]uint64{0x2},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (second lowest order bit set)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ // 00000000 00000000 00000000 00000011
+ args: [6]uint64{0x3},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (8th bit set)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ // 00000000 00000000 00000001 00000000
+ args: [6]uint64{0x100},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "Instruction Pointer",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ RuleIP: EqualTo(0x7aabbccdd),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "allowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x7aabbccdd},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "IP: Syscall instruction pointer disallowed",
+ desc: "disallowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x711223344},
want: linux.SECCOMP_RET_TRAP,
},
},
},
} {
- instrs, err := BuildProgram(test.ruleSets, test.defaultAction)
- if err != nil {
- t.Errorf("%s: buildProgram() got error: %v", test.specs[0].desc, err)
- continue
- }
- p, err := bpf.Compile(instrs)
- if err != nil {
- t.Errorf("%s: bpf.Compile() got error: %v", test.specs[0].desc, err)
- continue
- }
- for _, spec := range test.specs {
- got, err := bpf.Exec(p, spec.data.asInput())
+ t.Run(test.name, func(t *testing.T) {
+ instrs, err := BuildProgram(test.ruleSets, test.defaultAction, test.badArchAction)
if err != nil {
- t.Errorf("%s: bpf.Exec() got error: %v", spec.desc, err)
- continue
+ t.Fatalf("BuildProgram() got error: %v", err)
+ }
+ p, err := bpf.Compile(instrs)
+ if err != nil {
+ t.Fatalf("bpf.Compile() got error: %v", err)
}
- if got != uint32(spec.want) {
- t.Errorf("%s: bpd.Exec() = %d, want: %d", spec.desc, got, spec.want)
+ for _, spec := range test.specs {
+ got, err := bpf.Exec(p, spec.data.asInput())
+ if err != nil {
+ t.Fatalf("%s: bpf.Exec() got error: %v", spec.desc, err)
+ }
+ if got != uint32(spec.want) {
+ // Include a decoded version of the program in output for debugging purposes.
+ decoded, _ := bpf.DecodeInstructions(instrs)
+ t.Fatalf("%s: got: %d, want: %d\nBPF Program\n%s", spec.desc, got, spec.want, decoded)
+ }
}
- }
+ })
}
}
@@ -457,7 +938,7 @@ func TestRandom(t *testing.T) {
Rules: syscallRules,
Action: linux.SECCOMP_RET_ALLOW,
},
- }, linux.SECCOMP_RET_TRAP)
+ }, linux.SECCOMP_RET_TRAP, linux.SECCOMP_RET_KILL_THREAD)
if err != nil {
t.Fatalf("buildProgram() got error: %v", err)
}
diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go
index fe157f539..7f33e0d9e 100644
--- a/pkg/seccomp/seccomp_test_victim.go
+++ b/pkg/seccomp/seccomp_test_victim.go
@@ -100,7 +100,7 @@ func main() {
if !die {
syscalls[syscall.SYS_OPENAT] = []seccomp.Rule{
{
- seccomp.AllowValue(10),
+ seccomp.EqualTo(10),
},
}
}
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 901e0f320..4af4d6e84 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -22,6 +22,7 @@ go_library(
"signal_info.go",
"signal_stack.go",
"stack.go",
+ "stack_unsafe.go",
"syscalls_amd64.go",
"syscalls_arm64.go",
],
@@ -33,11 +34,12 @@ go_library(
"//pkg/context",
"//pkg/cpuid",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/sentry/limits",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index a903d031c..d75d665ae 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -72,12 +73,12 @@ type Context interface {
// with return values of varying sizes (for example ARCH_GETFS). This
// is a simple utility function to convert to the native size in these
// cases, and then we can CopyOut.
- Native(val uintptr) interface{}
+ Native(val uintptr) marshal.Marshallable
// Value converts a native type back to a generic value.
// Once a value has been converted to native via the above call -- it
// can be converted back here.
- Value(val interface{}) uintptr
+ Value(val marshal.Marshallable) uintptr
// Width returns the number of bytes for a native value.
Width() uint
@@ -205,7 +206,7 @@ type Context interface {
// equivalent of arch_ptrace():
// PtracePeekUser implements ptrace(PTRACE_PEEKUSR).
- PtracePeekUser(addr uintptr) (interface{}, error)
+ PtracePeekUser(addr uintptr) (marshal.Marshallable, error)
// PtracePokeUser implements ptrace(PTRACE_POKEUSR).
PtracePokeUser(addr, data uintptr) error
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 1c3e3c14c..c7d3a206d 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -23,6 +23,8 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -179,14 +181,14 @@ func (c *context64) SetOldRSeqInterruptedIP(value uintptr) {
}
// Native returns the native type for the given val.
-func (c *context64) Native(val uintptr) interface{} {
- v := uint64(val)
+func (c *context64) Native(val uintptr) marshal.Marshallable {
+ v := primitive.Uint64(val)
return &v
}
// Value returns the generic val for the given native type.
-func (c *context64) Value(val interface{}) uintptr {
- return uintptr(*val.(*uint64))
+func (c *context64) Value(val marshal.Marshallable) uintptr {
+ return uintptr(*val.(*primitive.Uint64))
}
// Width returns the byte width of this architecture.
@@ -293,7 +295,7 @@ func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
const userStructSize = 928
// PtracePeekUser implements Context.PtracePeekUser.
-func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
+func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) {
if addr&7 != 0 || addr >= userStructSize {
return nil, syscall.EIO
}
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index 550741d8c..680d23a9f 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -22,6 +22,8 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -163,14 +165,14 @@ func (c *context64) SetOldRSeqInterruptedIP(value uintptr) {
}
// Native returns the native type for the given val.
-func (c *context64) Native(val uintptr) interface{} {
- v := uint64(val)
+func (c *context64) Native(val uintptr) marshal.Marshallable {
+ v := primitive.Uint64(val)
return &v
}
// Value returns the generic val for the given native type.
-func (c *context64) Value(val interface{}) uintptr {
- return uintptr(*val.(*uint64))
+func (c *context64) Value(val marshal.Marshallable) uintptr {
+ return uintptr(*val.(*primitive.Uint64))
}
// Width returns the byte width of this architecture.
@@ -274,7 +276,7 @@ func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
}
// PtracePeekUser implements Context.PtracePeekUser.
-func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
+func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) {
// TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64.
return c.Native(0), nil
}
diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go
index 32173aa20..d3e2324a8 100644
--- a/pkg/sentry/arch/signal_act.go
+++ b/pkg/sentry/arch/signal_act.go
@@ -14,7 +14,7 @@
package arch
-import "gvisor.dev/gvisor/tools/go_marshal/marshal"
+import "gvisor.dev/gvisor/pkg/marshal"
// Special values for SignalAct.Handler.
const (
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index 6fb756f0e..72e07a988 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -17,17 +17,19 @@
package arch
import (
- "encoding/binary"
"math"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/usermem"
)
// SignalContext64 is equivalent to struct sigcontext, the type passed as the
// second argument to signal handlers set by signal(2).
+//
+// +marshal
type SignalContext64 struct {
R8 uint64
R9 uint64
@@ -68,6 +70,8 @@ const (
)
// UContext64 is equivalent to ucontext_t on 64-bit x86.
+//
+// +marshal
type UContext64 struct {
Flags uint64
Link uint64
@@ -172,12 +176,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
// "... the value (%rsp+8) is always a multiple of 16 (...) when
// control is transferred to the function entry point." - AMD64 ABI
- ucSize := binary.Size(uc)
- if ucSize < 0 {
- // This can only happen if we've screwed up the definition of
- // UContext64.
- panic("can't get size of UContext64")
- }
+ ucSize := uc.SizeBytes()
// st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128.
frameSize := int(st.Arch.Width()) + ucSize + 128
frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8
@@ -195,18 +194,18 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
info.FixSignalCodeForUser()
// Set up the stack frame.
- infoAddr, err := st.Push(info)
- if err != nil {
+ if _, err := info.CopyOut(st, StackBottomMagic); err != nil {
return err
}
- ucAddr, err := st.Push(uc)
- if err != nil {
+ infoAddr := st.Bottom
+ if _, err := uc.CopyOut(st, StackBottomMagic); err != nil {
return err
}
+ ucAddr := st.Bottom
if act.HasRestorer() {
// Push the restorer return address.
// Note that this doesn't need to be popped.
- if _, err := st.Push(usermem.Addr(act.Restorer)); err != nil {
+ if _, err := primitive.CopyUint64Out(st, StackBottomMagic, act.Restorer); err != nil {
return err
}
} else {
@@ -240,11 +239,11 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
// Copy out the stack frame.
var uc UContext64
- if _, err := st.Pop(&uc); err != nil {
+ if _, err := uc.CopyIn(st, StackBottomMagic); err != nil {
return 0, SignalStack{}, err
}
var info SignalInfo
- if _, err := st.Pop(&info); err != nil {
+ if _, err := info.CopyIn(st, StackBottomMagic); err != nil {
return 0, SignalStack{}, err
}
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
index 642c79dda..7fde5d34e 100644
--- a/pkg/sentry/arch/signal_arm64.go
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package arch
import (
- "encoding/binary"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,6 +26,8 @@ import (
// SignalContext64 is equivalent to struct sigcontext, the type passed as the
// second argument to signal handlers set by signal(2).
+//
+// +marshal
type SignalContext64 struct {
FaultAddr uint64
Regs [31]uint64
@@ -36,6 +39,7 @@ type SignalContext64 struct {
Reserved [3568]uint8
}
+// +marshal
type aarch64Ctx struct {
Magic uint32
Size uint32
@@ -43,6 +47,8 @@ type aarch64Ctx struct {
// FpsimdContext is equivalent to struct fpsimd_context on arm64
// (arch/arm64/include/uapi/asm/sigcontext.h).
+//
+// +marshal
type FpsimdContext struct {
Head aarch64Ctx
Fpsr uint32
@@ -51,13 +57,15 @@ type FpsimdContext struct {
}
// UContext64 is equivalent to ucontext on arm64(arch/arm64/include/uapi/asm/ucontext.h).
+//
+// +marshal
type UContext64 struct {
Flags uint64
Link uint64
Stack SignalStack
Sigset linux.SignalSet
// glibc uses a 1024-bit sigset_t
- _pad [(1024 - 64) / 8]byte
+ _pad [120]byte // (1024 - 64) / 8 = 120
// sigcontext must be aligned to 16-byte
_pad2 [8]byte
// last for future expansion
@@ -94,11 +102,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
},
Sigset: sigset,
}
-
- ucSize := binary.Size(uc)
- if ucSize < 0 {
- panic("can't get size of UContext64")
- }
+ ucSize := uc.SizeBytes()
// frameSize = ucSize + sizeof(siginfo).
// sizeof(siginfo) == 128.
@@ -119,14 +123,14 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
info.FixSignalCodeForUser()
// Set up the stack frame.
- infoAddr, err := st.Push(info)
- if err != nil {
+ if _, err := info.CopyOut(st, StackBottomMagic); err != nil {
return err
}
- ucAddr, err := st.Push(uc)
- if err != nil {
+ infoAddr := st.Bottom
+ if _, err := uc.CopyOut(st, StackBottomMagic); err != nil {
return err
}
+ ucAddr := st.Bottom
// Set up registers.
c.Regs.Sp = uint64(st.Bottom)
@@ -147,11 +151,11 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
// Copy out the stack frame.
var uc UContext64
- if _, err := st.Pop(&uc); err != nil {
+ if _, err := uc.CopyIn(st, StackBottomMagic); err != nil {
return 0, SignalStack{}, err
}
var info SignalInfo
- if _, err := st.Pop(&info); err != nil {
+ if _, err := info.CopyIn(st, StackBottomMagic); err != nil {
return 0, SignalStack{}, err
}
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
index 0fa738a1d..a1eae98f9 100644
--- a/pkg/sentry/arch/signal_stack.go
+++ b/pkg/sentry/arch/signal_stack.go
@@ -17,8 +17,8 @@
package arch
import (
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
const (
diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go
index 1108fa0bd..5f06c751d 100644
--- a/pkg/sentry/arch/stack.go
+++ b/pkg/sentry/arch/stack.go
@@ -15,14 +15,16 @@
package arch
import (
- "encoding/binary"
- "fmt"
-
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/usermem"
)
-// Stack is a simple wrapper around a usermem.IO and an address.
+// Stack is a simple wrapper around a usermem.IO and an address. Stack
+// implements marshal.CopyContext, and marshallable values can be pushed or
+// popped from the stack through the marshal.Marshallable interface.
+//
+// Stack is not thread-safe.
type Stack struct {
// Our arch info.
// We use this for automatic Native conversion of usermem.Addrs during
@@ -34,105 +36,60 @@ type Stack struct {
// Our current stack bottom.
Bottom usermem.Addr
-}
-// Push pushes the given values on to the stack.
-//
-// (This method supports Addrs and treats them as native types.)
-func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) {
- for _, v := range vals {
-
- // We convert some types to well-known serializable quanities.
- var norm interface{}
-
- // For array types, we will automatically add an appropriate
- // terminal value. This is done simply to make the interface
- // easier to use.
- var term interface{}
-
- switch v.(type) {
- case string:
- norm = []byte(v.(string))
- term = byte(0)
- case []int8, []uint8:
- norm = v
- term = byte(0)
- case []int16, []uint16:
- norm = v
- term = uint16(0)
- case []int32, []uint32:
- norm = v
- term = uint32(0)
- case []int64, []uint64:
- norm = v
- term = uint64(0)
- case []usermem.Addr:
- // Special case: simply push recursively.
- _, err := s.Push(s.Arch.Native(uintptr(0)))
- if err != nil {
- return 0, err
- }
- varr := v.([]usermem.Addr)
- for i := len(varr) - 1; i >= 0; i-- {
- _, err := s.Push(varr[i])
- if err != nil {
- return 0, err
- }
- }
- continue
- case usermem.Addr:
- norm = s.Arch.Native(uintptr(v.(usermem.Addr)))
- default:
- norm = v
- }
+ // Scratch buffer used for marshalling to avoid having to repeatedly
+ // allocate scratch memory.
+ scratchBuf []byte
+}
- if term != nil {
- _, err := s.Push(term)
- if err != nil {
- return 0, err
- }
- }
+// scratchBufLen is the default length of Stack.scratchBuf. The
+// largest structs the stack regularly serializes are arch.SignalInfo
+// and arch.UContext64. We'll set the default size as the larger of
+// the two, arch.UContext64.
+var scratchBufLen = (*UContext64)(nil).SizeBytes()
- c := binary.Size(norm)
- if c < 0 {
- return 0, fmt.Errorf("bad binary.Size for %T", v)
- }
- n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{})
- if err != nil || c != n {
- return 0, err
- }
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (s *Stack) CopyScratchBuffer(size int) []byte {
+ if len(s.scratchBuf) < size {
+ s.scratchBuf = make([]byte, size)
+ }
+ return s.scratchBuf[:size]
+}
+// StackBottomMagic is the special address callers must past to all stack
+// marshalling operations to cause the src/dst address to be computed based on
+// the current end of the stack.
+const StackBottomMagic = ^usermem.Addr(0) // usermem.Addr(-1)
+
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. CopyOutBytes
+// computes an appropriate address based on the current end of the
+// stack. Callers use the sentinel address StackBottomMagic to marshal methods
+// to indicate this.
+func (s *Stack) CopyOutBytes(sentinel usermem.Addr, b []byte) (int, error) {
+ if sentinel != StackBottomMagic {
+ panic("Attempted to copy out to stack with absolute address")
+ }
+ c := len(b)
+ n, err := s.IO.CopyOut(context.Background(), s.Bottom-usermem.Addr(c), b, usermem.IOOpts{})
+ if err == nil && n == c {
s.Bottom -= usermem.Addr(n)
}
-
- return s.Bottom, nil
+ return n, err
}
-// Pop pops the given values off the stack.
-//
-// (This method supports Addrs and treats them as native types.)
-func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) {
- for _, v := range vals {
-
- vaddr, isVaddr := v.(*usermem.Addr)
-
- var n int
- var err error
- if isVaddr {
- value := s.Arch.Native(uintptr(0))
- n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{})
- *vaddr = usermem.Addr(s.Arch.Value(value))
- } else {
- n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{})
- }
- if err != nil {
- return 0, err
- }
-
+// CopyInBytes implements marshal.CopyContext.CopyInBytes. CopyInBytes computes
+// an appropriate address based on the current end of the stack. Callers must
+// use the sentinel address StackBottomMagic to marshal methods to indicate
+// this.
+func (s *Stack) CopyInBytes(sentinel usermem.Addr, b []byte) (int, error) {
+ if sentinel != StackBottomMagic {
+ panic("Attempted to copy in from stack with absolute address")
+ }
+ n, err := s.IO.CopyIn(context.Background(), s.Bottom, b, usermem.IOOpts{})
+ if err == nil {
s.Bottom += usermem.Addr(n)
}
-
- return s.Bottom, nil
+ return n, err
}
// Align aligns the stack to the given offset.
@@ -142,6 +99,22 @@ func (s *Stack) Align(offset int) {
}
}
+// PushNullTerminatedByteSlice writes bs to the stack, followed by an extra null
+// byte at the end. On error, the contents of the stack and the bottom cursor
+// are undefined.
+func (s *Stack) PushNullTerminatedByteSlice(bs []byte) (int, error) {
+ // Note: Stack grows up, so write the terminal null byte first.
+ nNull, err := primitive.CopyUint8Out(s, StackBottomMagic, 0)
+ if err != nil {
+ return 0, err
+ }
+ n, err := primitive.CopyByteSliceOut(s, StackBottomMagic, bs)
+ if err != nil {
+ return 0, err
+ }
+ return n + nNull, nil
+}
+
// StackLayout describes the location of the arguments and environment on the
// stack.
type StackLayout struct {
@@ -177,11 +150,10 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error)
l.EnvvEnd = s.Bottom
envAddrs := make([]usermem.Addr, len(env))
for i := len(env) - 1; i >= 0; i-- {
- addr, err := s.Push(env[i])
- if err != nil {
+ if _, err := s.PushNullTerminatedByteSlice([]byte(env[i])); err != nil {
return StackLayout{}, err
}
- envAddrs[i] = addr
+ envAddrs[i] = s.Bottom
}
l.EnvvStart = s.Bottom
@@ -189,11 +161,10 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error)
l.ArgvEnd = s.Bottom
argAddrs := make([]usermem.Addr, len(args))
for i := len(args) - 1; i >= 0; i-- {
- addr, err := s.Push(args[i])
- if err != nil {
+ if _, err := s.PushNullTerminatedByteSlice([]byte(args[i])); err != nil {
return StackLayout{}, err
}
- argAddrs[i] = addr
+ argAddrs[i] = s.Bottom
}
l.ArgvStart = s.Bottom
@@ -222,26 +193,26 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error)
auxv = append(auxv, usermem.Addr(a.Key), a.Value)
}
auxv = append(auxv, usermem.Addr(0))
- _, err := s.Push(auxv)
+ _, err := s.pushAddrSliceAndTerminator(auxv)
if err != nil {
return StackLayout{}, err
}
// Push environment.
- _, err = s.Push(envAddrs)
+ _, err = s.pushAddrSliceAndTerminator(envAddrs)
if err != nil {
return StackLayout{}, err
}
// Push args.
- _, err = s.Push(argAddrs)
+ _, err = s.pushAddrSliceAndTerminator(argAddrs)
if err != nil {
return StackLayout{}, err
}
// Push arg count.
- _, err = s.Push(usermem.Addr(len(args)))
- if err != nil {
+ lenP := s.Arch.Native(uintptr(len(args)))
+ if _, err = lenP.CopyOut(s, StackBottomMagic); err != nil {
return StackLayout{}, err
}
diff --git a/pkg/sentry/arch/stack_unsafe.go b/pkg/sentry/arch/stack_unsafe.go
new file mode 100644
index 000000000..a90d297ee
--- /dev/null
+++ b/pkg/sentry/arch/stack_unsafe.go
@@ -0,0 +1,69 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "reflect"
+ "runtime"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// pushAddrSliceAndTerminator copies a slices of addresses to the stack, and
+// also pushes an extra null address element at the end of the slice.
+//
+// Internally, we unsafely transmute the slice type from the arch-dependent
+// []usermem.Addr type, to a slice of fixed-sized ints so that we can pass it to
+// go-marshal.
+//
+// On error, the contents of the stack and the bottom cursor are undefined.
+func (s *Stack) pushAddrSliceAndTerminator(src []usermem.Addr) (int, error) {
+ // Note: Stack grows upwards, so push the terminator first.
+ srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src))
+ switch s.Arch.Width() {
+ case 8:
+ nNull, err := primitive.CopyUint64Out(s, StackBottomMagic, 0)
+ if err != nil {
+ return 0, err
+ }
+ var dst []uint64
+ dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst))
+ dstHdr.Data = srcHdr.Data
+ dstHdr.Len = srcHdr.Len
+ dstHdr.Cap = srcHdr.Cap
+ n, err := primitive.CopyUint64SliceOut(s, StackBottomMagic, dst)
+ // Ensures src doesn't get GCed until we're done using it through dst.
+ runtime.KeepAlive(src)
+ return n + nNull, err
+ case 4:
+ nNull, err := primitive.CopyUint32Out(s, StackBottomMagic, 0)
+ if err != nil {
+ return 0, err
+ }
+ var dst []uint32
+ dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst))
+ dstHdr.Data = srcHdr.Data
+ dstHdr.Len = srcHdr.Len
+ dstHdr.Cap = srcHdr.Cap
+ n, err := primitive.CopyUint32SliceOut(s, StackBottomMagic, dst)
+ // Ensure src doesn't get GCed until we're done using it through dst.
+ runtime.KeepAlive(src)
+ return n + nNull, err
+ default:
+ panic("Unsupported arch width")
+ }
+}
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index 2c5d14be5..deaf5fa23 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -35,7 +35,6 @@ go_library(
"//pkg/sync",
"//pkg/tcpip/link/sniffer",
"//pkg/urpc",
- "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index dfa936563..668f47802 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -23,8 +23,8 @@ import (
"text/tabwriter"
"time"
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/sentry/fdimport"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
@@ -203,27 +203,17 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
}
initArgs.Filename = resolved
- fds := make([]int, len(args.FilePayload.Files))
- for i, file := range args.FilePayload.Files {
- if kernel.VFS2Enabled {
- // Need to dup to remove ownership from os.File.
- dup, err := unix.Dup(int(file.Fd()))
- if err != nil {
- return nil, 0, nil, nil, fmt.Errorf("duplicating payload files: %w", err)
- }
- fds[i] = dup
- } else {
- // VFS1 dups the file on import.
- fds[i] = int(file.Fd())
- }
+ fds, err := fd.NewFromFiles(args.Files)
+ if err != nil {
+ return nil, 0, nil, nil, fmt.Errorf("duplicating payload files: %w", err)
}
+ defer func() {
+ for _, fd := range fds {
+ _ = fd.Close()
+ }
+ }()
ttyFile, ttyFileVFS2, err := fdimport.Import(ctx, fdTable, args.StdioIsPty, fds)
if err != nil {
- if kernel.VFS2Enabled {
- for _, fd := range fds {
- unix.Close(fd)
- }
- }
return nil, 0, nil, nil, err
}
diff --git a/pkg/sentry/devices/memdev/BUILD b/pkg/sentry/devices/memdev/BUILD
index abe58f818..4c8604d58 100644
--- a/pkg/sentry/devices/memdev/BUILD
+++ b/pkg/sentry/devices/memdev/BUILD
@@ -18,9 +18,10 @@ go_library(
"//pkg/rand",
"//pkg/safemem",
"//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
- "//pkg/sentry/mm",
- "//pkg/sentry/pgalloc",
"//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go
index 511179e31..fece3e762 100644
--- a/pkg/sentry/devices/memdev/full.go
+++ b/pkg/sentry/devices/memdev/full.go
@@ -24,6 +24,8 @@ import (
const fullDevMinor = 7
// fullDevice implements vfs.Device for /dev/full.
+//
+// +stateify savable
type fullDevice struct{}
// Open implements vfs.Device.Open.
@@ -38,6 +40,8 @@ func (fullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op
}
// fullFD implements vfs.FileDescriptionImpl for /dev/full.
+//
+// +stateify savable
type fullFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go
index 4918dbeeb..ff5837747 100644
--- a/pkg/sentry/devices/memdev/null.go
+++ b/pkg/sentry/devices/memdev/null.go
@@ -25,6 +25,8 @@ import (
const nullDevMinor = 3
// nullDevice implements vfs.Device for /dev/null.
+//
+// +stateify savable
type nullDevice struct{}
// Open implements vfs.Device.Open.
@@ -39,6 +41,8 @@ func (nullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op
}
// nullFD implements vfs.FileDescriptionImpl for /dev/null.
+//
+// +stateify savable
type nullFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go
index 5e7fe0280..ac943e3ba 100644
--- a/pkg/sentry/devices/memdev/random.go
+++ b/pkg/sentry/devices/memdev/random.go
@@ -30,6 +30,8 @@ const (
)
// randomDevice implements vfs.Device for /dev/random and /dev/urandom.
+//
+// +stateify savable
type randomDevice struct{}
// Open implements vfs.Device.Open.
@@ -44,6 +46,8 @@ func (randomDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry,
}
// randomFD implements vfs.FileDescriptionImpl for /dev/random.
+//
+// +stateify savable
type randomFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
index 2e631a252..1929e41cd 100644
--- a/pkg/sentry/devices/memdev/zero.go
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -16,9 +16,10 @@ package memdev
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -26,6 +27,8 @@ import (
const zeroDevMinor = 5
// zeroDevice implements vfs.Device for /dev/zero.
+//
+// +stateify savable
type zeroDevice struct{}
// Open implements vfs.Device.Open.
@@ -40,6 +43,8 @@ func (zeroDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op
}
// zeroFD implements vfs.FileDescriptionImpl for /dev/zero.
+//
+// +stateify savable
type zeroFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -79,11 +84,22 @@ func (fd *zeroFD) Seek(ctx context.Context, offset int64, whence int32) (int64,
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
- m, err := mm.NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx))
+ if opts.Private || !opts.MaxPerms.Write {
+ // This mapping will never permit writing to the "underlying file" (in
+ // Linux terms, it isn't VM_SHARED), so implement it as an anonymous
+ // mapping, but back it with fd; this is what Linux does, and is
+ // actually application-visible because the resulting VMA will show up
+ // in /proc/[pid]/maps with fd.vfsfd.VirtualDentry()'s path rather than
+ // "/dev/zero (deleted)".
+ opts.Offset = 0
+ opts.MappingIdentity = &fd.vfsfd
+ opts.MappingIdentity.IncRef()
+ return nil
+ }
+ tmpfsFD, err := tmpfs.NewZeroFile(ctx, auth.CredentialsFromContext(ctx), kernel.KernelFromContext(ctx).ShmMount(), opts.Length)
if err != nil {
return err
}
- opts.MappingIdentity = m
- opts.Mappable = m
- return nil
+ defer tmpfsFD.DecRef(ctx)
+ return tmpfsFD.ConfigureMMap(ctx, opts)
}
diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go
index 664e54498..a287c65ca 100644
--- a/pkg/sentry/devices/ttydev/ttydev.go
+++ b/pkg/sentry/devices/ttydev/ttydev.go
@@ -30,6 +30,8 @@ const (
)
// ttyDevice implements vfs.Device for /dev/tty.
+//
+// +stateify savable
type ttyDevice struct{}
// Open implements vfs.Device.Open.
diff --git a/pkg/sentry/fdimport/BUILD b/pkg/sentry/fdimport/BUILD
index 5e41ceb4e..6b4f8b0ed 100644
--- a/pkg/sentry/fdimport/BUILD
+++ b/pkg/sentry/fdimport/BUILD
@@ -10,6 +10,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/context",
+ "//pkg/fd",
"//pkg/sentry/fs",
"//pkg/sentry/fs/host",
"//pkg/sentry/fsimpl/host",
diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go
index 1b7cb94c0..314661475 100644
--- a/pkg/sentry/fdimport/fdimport.go
+++ b/pkg/sentry/fdimport/fdimport.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
@@ -27,8 +28,9 @@ import (
// Import imports a slice of FDs into the given FDTable. If console is true,
// sets up TTY for the first 3 FDs in the slice representing stdin, stdout,
-// stderr. Upon success, Import takes ownership of all FDs.
-func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
+// stderr. Used FDs are either closed or released. It's safe for the caller to
+// close any remaining files upon return.
+func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []*fd.FD) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
if kernel.VFS2Enabled {
ttyFile, err := importVFS2(ctx, fdTable, console, fds)
return nil, ttyFile, err
@@ -37,7 +39,7 @@ func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []in
return ttyFile, nil, err
}
-func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, error) {
+func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []*fd.FD) (*host.TTYFileOperations, error) {
var ttyFile *fs.File
for appFD, hostFD := range fds {
var appFile *fs.File
@@ -46,11 +48,12 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []
// Import the file as a host TTY file.
if ttyFile == nil {
var err error
- appFile, err = host.ImportFile(ctx, hostFD, true /* isTTY */)
+ appFile, err = host.ImportFile(ctx, hostFD.FD(), true /* isTTY */)
if err != nil {
return nil, err
}
defer appFile.DecRef(ctx)
+ _ = hostFD.Close() // FD is dup'd i ImportFile.
// Remember this in the TTY file, as we will
// use it for the other stdio FDs.
@@ -65,11 +68,12 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []
} else {
// Import the file as a regular host file.
var err error
- appFile, err = host.ImportFile(ctx, hostFD, false /* isTTY */)
+ appFile, err = host.ImportFile(ctx, hostFD.FD(), false /* isTTY */)
if err != nil {
return nil, err
}
defer appFile.DecRef(ctx)
+ _ = hostFD.Close() // FD is dup'd i ImportFile.
}
// Add the file to the FD map.
@@ -84,7 +88,7 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []
return ttyFile.FileOperations.(*host.TTYFileOperations), nil
}
-func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdioFDs []int) (*hostvfs2.TTYFileDescription, error) {
+func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdioFDs []*fd.FD) (*hostvfs2.TTYFileDescription, error) {
k := kernel.KernelFromContext(ctx)
if k == nil {
return nil, fmt.Errorf("cannot find kernel from context")
@@ -98,11 +102,12 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi
// Import the file as a host TTY file.
if ttyFile == nil {
var err error
- appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, true /* isTTY */)
+ appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD.FD(), true /* isTTY */)
if err != nil {
return nil, err
}
defer appFile.DecRef(ctx)
+ hostFD.Release() // FD is transfered to host FD.
// Remember this in the TTY file, as we will use it for the other stdio
// FDs.
@@ -115,11 +120,12 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi
}
} else {
var err error
- appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, false /* isTTY */)
+ appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD.FD(), false /* isTTY */)
if err != nil {
return nil, err
}
defer appFile.DecRef(ctx)
+ hostFD.Release() // FD is transfered to host FD.
}
if err := fdTable.NewFDAtVFS2(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil {
diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md
index 2ca84dd74..05e043583 100644
--- a/pkg/sentry/fs/g3doc/fuse.md
+++ b/pkg/sentry/fs/g3doc/fuse.md
@@ -79,7 +79,7 @@ ops can be implemented in parallel.
- Implement `/dev/fuse` - a character device used to establish an FD for
communication between the sentry and the server daemon.
-- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`.
+- Implement basic FUSE ops like `FUSE_INIT`.
#### Read-only mount with basic file operations
@@ -95,6 +95,103 @@ ops can be implemented in parallel.
- Implement the remaining FUSE ops and decide if we can omit rarely used
operations like ioctl.
+### Design Details
+
+#### Lifecycle for a FUSE Request
+
+- User invokes a syscall
+- Sentry prepares corresponding request
+ - If FUSE device is available
+ - Write the request in binary
+ - If FUSE device is full
+ - Kernel task blocked until available
+- Sentry notifies the readers of fuse device that it's ready for read
+- FUSE daemon reads the request and processes it
+- Sentry waits until a reply is written to the FUSE device
+ - but returns directly for async requests
+- FUSE daemon writes to the fuse device
+- Sentry processes the reply
+ - For sync requests, unblock blocked kernel task
+ - For async requests, execute pre-specified callback if any
+- Sentry returns the syscall to the user
+
+#### Channels and Queues for Requests in Different Stages
+
+`connection.initializedChan`
+
+- a channel that the requests issued before connection initialization blocks
+ on.
+
+`fd.queue`
+
+- a queue of requests that haven’t been read by the FUSE daemon yet.
+
+`fd.completions`
+
+- a map of the requests that have been prepared but not yet received a
+ response, including the ones on the `fd.queue`.
+
+`fd.waitQueue`
+
+- a queue of waiters that is waiting for the fuse device fd to be available,
+ such as the FUSE daemon.
+
+`fd.fullQueueCh`
+
+- a channel that the kernel task will be blocked on when the fd is not
+ available.
+
+#### Basic I/O Implementation
+
+Currently we have implemented basic functionalities of read and write for our
+FUSE. We describe the design and ways to improve it here:
+
+##### Basic FUSE Read
+
+The vfs2 expects implementations of `vfs.FileDescriptionImpl.Read()` and
+`vfs.FileDescriptionImpl.PRead()`. When a syscall is made, it will eventually
+reach our implementation of those interface functions located at
+`pkg/sentry/fsimpl/fuse/regular_file.go` for regular files.
+
+After validation checks of the input, sentry sends `FUSE_READ` requests to the
+FUSE daemon. The FUSE daemon returns data after the `fuse_out_header` as the
+responses. For the first version, we create a copy in kernel memory of those
+data. They are represented as a byte slice in the marshalled struct. This
+happens as a common process for all the FUSE responses at this moment at
+`pkg/sentry/fsimpl/fuse/dev.go:writeLocked()`. We then directly copy from this
+intermediate buffer to the input buffer provided by the read syscall.
+
+There is an extra requirement for FUSE: When mounting the FUSE fs, the mounter
+or the FUSE daemon can specify a `max_read` or a `max_pages` parameter. They are
+the upperbound of the bytes to read in each `FUSE_READ` request. We implemented
+the code to handle the fragmented reads.
+
+To improve the performance: ideally we should have buffer cache to copy those
+data from the responses of FUSE daemon into, as is also the design of several
+other existing file system implementations for sentry, instead of a single-use
+temporary buffer. Directly mapping the memory of one process to another could
+also boost the performance, but to keep them isolated, we did not choose to do
+so.
+
+##### Basic FUSE Write
+
+The vfs2 invokes implementations of `vfs.FileDescriptionImpl.Write()` and
+`vfs.FileDescriptionImpl.PWrite()` on the regular file descriptor of FUSE when a
+user makes write(2) and pwrite(2) syscall.
+
+For valid writes, sentry sends the bytes to write after a `FUSE_WRITE` header
+(can be regarded as a request with 2 payloads) to the FUSE daemon. For the first
+version, we allocate a buffer inside kernel memory to store the bytes from the
+user, and copy directly from that buffer to the memory of FUSE daemon. This
+happens at `pkg/sentry/fsimpl/fuse/dev.go:readLocked()`
+
+The parameters `max_write` and `max_pages` restrict the number of bytes in one
+`FUSE_WRITE`. There are code handling fragmented writes in current
+implementation.
+
+To have better performance: the extra copy created to store the bytes to write
+can be replaced by the buffer cache as well.
+
# Appendix
## FUSE Protocol
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 42a6c41c2..1368014c4 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/fdnotifier",
"//pkg/iovec",
"//pkg/log",
+ "//pkg/marshal/primitive",
"//pkg/refs",
"//pkg/safemem",
"//pkg/secio",
@@ -55,7 +56,6 @@ go_library(
"//pkg/unet",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index 5d4f312cf..c8231e0aa 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -65,10 +65,10 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (
controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC
if n > length {
- return length, n, msg.Controllen, controlTrunc, err
+ return length, n, msg.Controllen, controlTrunc, nil
}
- return n, n, msg.Controllen, controlTrunc, err
+ return n, n, msg.Controllen, controlTrunc, nil
}
// fdWriteVec sends from bufs to fd.
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 67a807f9d..1183727ab 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -17,6 +17,7 @@ package host
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -24,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
@@ -54,7 +54,7 @@ type TTYFileOperations struct {
func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File {
return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{
fileOperations: fileOperations{iops: iops},
- termios: linux.DefaultSlaveTermios,
+ termios: linux.DefaultReplicaTermios,
})
}
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index b79cd9877..004910453 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -270,7 +270,7 @@ func (i *Inode) GetXattr(ctx context.Context, name string, size uint64) (string,
// SetXattr calls i.InodeOperations.SetXattr with i as the Inode.
func (i *Inode) SetXattr(ctx context.Context, d *Dirent, name, value string, flags uint32) error {
if i.overlay != nil {
- return overlaySetxattr(ctx, i.overlay, d, name, value, flags)
+ return overlaySetXattr(ctx, i.overlay, d, name, value, flags)
}
return i.InodeOperations.SetXattr(ctx, i, name, value, flags)
}
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index dc2e353d9..b16ab08ba 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -16,7 +16,6 @@ package fs
import (
"fmt"
- "strings"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -539,7 +538,7 @@ func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uin
// Don't forward the value of the extended attribute if it would
// unexpectedly change the behavior of a wrapping overlay layer.
- if strings.HasPrefix(XattrOverlayPrefix, name) {
+ if isXattrOverlay(name) {
return "", syserror.ENODATA
}
@@ -553,9 +552,9 @@ func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uin
return s, err
}
-func overlaySetxattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error {
+func overlaySetXattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error {
// Don't allow changes to overlay xattrs through a setxattr syscall.
- if strings.HasPrefix(XattrOverlayPrefix, name) {
+ if isXattrOverlay(name) {
return syserror.EPERM
}
@@ -578,7 +577,7 @@ func overlayListXattr(ctx context.Context, o *overlayEntry, size uint64) (map[st
for name := range names {
// Same as overlayGetXattr, we shouldn't forward along
// overlay attributes.
- if strings.HasPrefix(XattrOverlayPrefix, name) {
+ if isXattrOverlay(name) {
delete(names, name)
}
}
@@ -587,7 +586,7 @@ func overlayListXattr(ctx context.Context, o *overlayEntry, size uint64) (map[st
func overlayRemoveXattr(ctx context.Context, o *overlayEntry, d *Dirent, name string) error {
// Don't allow changes to overlay xattrs through a removexattr syscall.
- if strings.HasPrefix(XattrOverlayPrefix, name) {
+ if isXattrOverlay(name) {
return syserror.EPERM
}
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index 77c2c5c0e..b8b2281a8 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -50,6 +50,7 @@ go_library(
"//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip/header",
+ "//pkg/tcpip/network/ipv4",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 8615b60f0..e555672ad 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -54,7 +55,7 @@ type tcpMemInode struct {
// size stores the tcp buffer size during save, and sets the buffer
// size in netstack in restore. We must save/restore this here, since
- // netstack itself is stateless.
+ // a netstack instance is created on restore.
size inet.TCPBufferSize
// mu protects against concurrent reads/writes to files based on this
@@ -258,6 +259,9 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque
if src.NumBytes() == 0 {
return 0, nil
}
+
+ // Only consider size of one memory page for input for performance reasons.
+ // We are only reading if it's zero or not anyway.
src = src.TakeFirst(usermem.PageSize - 1)
var v int32
@@ -383,11 +387,125 @@ func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.S
return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
}
+// ipForwarding implements fs.InodeOperations.
+//
+// ipForwarding is used to enable/disable packet forwarding of netstack.
+//
+// +stateify savable
+type ipForwarding struct {
+ fsutil.SimpleFileInode
+
+ stack inet.Stack `state:"wait"`
+
+ // enabled stores the IPv4 forwarding state on save.
+ // We must save/restore this here, since a netstack instance
+ // is created on restore.
+ enabled *bool
+}
+
+func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
+ ipf := &ipForwarding{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ stack: s,
+ }
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(ctx, ipf, msrc, sattr)
+}
+
+// Truncate implements fs.InodeOperations.Truncate. Truncate is called when
+// O_TRUNC is specified for any kind of existing Dirent but is not called via
+// (f)truncate for proc files.
+func (*ipForwarding) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// +stateify savable
+type ipForwardingFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ ipf *ipForwarding
+
+ stack inet.Stack `state:"wait"`
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (ipf *ipForwarding) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+ return fs.NewFile(ctx, dirent, flags, &ipForwardingFile{
+ stack: ipf.stack,
+ ipf: ipf,
+ }), nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *ipForwardingFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ return 0, io.EOF
+ }
+
+ if f.ipf.enabled == nil {
+ enabled := f.stack.Forwarding(ipv4.ProtocolNumber)
+ f.ipf.enabled = &enabled
+ }
+
+ val := "0\n"
+ if *f.ipf.enabled {
+ // Technically, this is not quite compatible with Linux. Linux
+ // stores these as an integer, so if you write "2" into
+ // ip_forward, you should get 2 back.
+ val = "1\n"
+ }
+ n, err := dst.CopyOut(ctx, []byte(val))
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+//
+// Offset is ignored, multiple writes are not supported.
+func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Only consider size of one memory page for input for performance reasons.
+ // We are only reading if it's zero or not anyway.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return n, err
+ }
+ if f.ipf.enabled == nil {
+ f.ipf.enabled = new(bool)
+ }
+ *f.ipf.enabled = v != 0
+ return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled)
+}
+
func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
contents := map[string]*fs.Inode{
// Add tcp_sack.
"tcp_sack": newTCPSackInode(ctx, msrc, s),
+ // Add ip_forward.
+ "ip_forward": newIPForwardingInode(ctx, msrc, s),
+
// The following files are simple stubs until they are
// implemented in netstack, most of these files are
// configuration related. We use the value closest to the
diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go
index 6eba709c6..4cb4741af 100644
--- a/pkg/sentry/fs/proc/sys_net_state.go
+++ b/pkg/sentry/fs/proc/sys_net_state.go
@@ -14,7 +14,11 @@
package proc
-import "fmt"
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+)
// beforeSave is invoked by stateify.
func (t *tcpMemInode) beforeSave() {
@@ -40,3 +44,12 @@ func (s *tcpSack) afterLoad() {
}
}
}
+
+// afterLoad is invoked by stateify.
+func (ipf *ipForwarding) afterLoad() {
+ if ipf.enabled != nil {
+ if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil {
+ panic(fmt.Sprintf("failed to set IPv4 forwarding [%v]: %v", *ipf.enabled, err))
+ }
+ }
+}
diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go
index 355e83d47..6ef5738e7 100644
--- a/pkg/sentry/fs/proc/sys_net_test.go
+++ b/pkg/sentry/fs/proc/sys_net_test.go
@@ -123,3 +123,76 @@ func TestConfigureRecvBufferSize(t *testing.T) {
}
}
}
+
+// TestIPForwarding tests the implementation of
+// /proc/sys/net/ipv4/ip_forwarding
+func TestIPForwarding(t *testing.T) {
+ ctx := context.Background()
+ s := inet.NewTestStack()
+
+ var cases = []struct {
+ comment string
+ initial bool
+ str string
+ final bool
+ }{
+ {
+ comment: `Forwarding is disabled; write 1 and enable forwarding`,
+ initial: false,
+ str: "1",
+ final: true,
+ },
+ {
+ comment: `Forwarding is disabled; write 0 and disable forwarding`,
+ initial: false,
+ str: "0",
+ final: false,
+ },
+ {
+ comment: `Forwarding is enabled; write 1 and enable forwarding`,
+ initial: true,
+ str: "1",
+ final: true,
+ },
+ {
+ comment: `Forwarding is enabled; write 0 and disable forwarding`,
+ initial: true,
+ str: "0",
+ final: false,
+ },
+ {
+ comment: `Forwarding is disabled; write 2404 and enable forwarding`,
+ initial: false,
+ str: "2404",
+ final: true,
+ },
+ {
+ comment: `Forwarding is enabled; write 2404 and enable forwarding`,
+ initial: true,
+ str: "2404",
+ final: true,
+ },
+ }
+ for _, c := range cases {
+ t.Run(c.comment, func(t *testing.T) {
+ s.IPForwarding = c.initial
+ ipf := &ipForwarding{stack: s}
+ file := &ipForwardingFile{
+ stack: s,
+ ipf: ipf,
+ }
+
+ // Write the values.
+ src := usermem.BytesIOSequence([]byte(c.str))
+ if n, err := file.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil {
+ t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str))
+ }
+
+ // Read the values from the stack and check them.
+ if got, want := s.IPForwarding, c.final; got != want {
+ t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want)
+ }
+
+ })
+ }
+}
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 9cf7f2a62..22d658acf 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -84,6 +84,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo
"auxv": newAuxvec(t, msrc),
"cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
"comm": newComm(t, msrc),
+ "cwd": newCwd(t, msrc),
"environ": newExecArgInode(t, msrc, environExecArg),
"exe": newExe(t, msrc),
"fd": newFdDir(t, msrc),
@@ -300,6 +301,49 @@ func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
return exec.PathnameWithDeleted(ctx), nil
}
+// cwd is an fs.InodeOperations symlink for the /proc/PID/cwd file.
+//
+// +stateify savable
+type cwd struct {
+ ramfs.Symlink
+
+ t *kernel.Task
+}
+
+func newCwd(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ cwdSymlink := &cwd{
+ Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""),
+ t: t,
+ }
+ return newProcInode(t, cwdSymlink, msrc, fs.Symlink, t)
+}
+
+// Readlink implements fs.InodeOperations.
+func (e *cwd) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if !kernel.ContextCanTrace(ctx, e.t, false) {
+ return "", syserror.EACCES
+ }
+ if err := checkTaskState(e.t); err != nil {
+ return "", err
+ }
+ cwd := e.t.FSContext().WorkingDirectory()
+ if cwd == nil {
+ // It could have raced with process deletion.
+ return "", syserror.ESRCH
+ }
+ defer cwd.DecRef(ctx)
+
+ root := fs.RootFromContext(ctx)
+ if root == nil {
+ // It could have raced with process deletion.
+ return "", syserror.ESRCH
+ }
+ defer root.DecRef(ctx)
+
+ name, _ := cwd.FullName(root)
+ return name, nil
+}
+
// namespaceSymlink represents a symlink in the namespacefs, such as the files
// in /proc/<pid>/ns.
//
@@ -604,7 +648,7 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) (
var vss, rss, data uint64
s.t.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
- fds = fdTable.Size()
+ fds = fdTable.CurrentMaxFDs()
}
if mm := t.MemoryManager(); mm != nil {
vss = mm.VirtualMemorySize()
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 5cb0e0417..e6d0eb359 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -10,13 +10,14 @@ go_library(
"line_discipline.go",
"master.go",
"queue.go",
- "slave.go",
+ "replica.go",
"terminal.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/marshal/primitive",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
index 463f6189e..c2da80bc2 100644
--- a/pkg/sentry/fs/tty/dir.go
+++ b/pkg/sentry/fs/tty/dir.go
@@ -37,14 +37,14 @@ import (
// This indirectly manages all terminals within the mount.
//
// New Terminals are created by masterInodeOperations.GetFile, which registers
-// the slave Inode in the this directory for discovery via Lookup/Readdir. The
-// slave inode is unregistered when the master file is Released, as the slave
+// the replica Inode in the this directory for discovery via Lookup/Readdir. The
+// replica inode is unregistered when the master file is Released, as the replica
// is no longer discoverable at that point.
//
// References on the underlying Terminal are held by masterFileOperations and
-// slaveInodeOperations.
+// replicaInodeOperations.
//
-// masterInodeOperations and slaveInodeOperations hold a pointer to
+// masterInodeOperations and replicaInodeOperations hold a pointer to
// dirInodeOperations, which is reference counted by the refcount their
// corresponding Dirents hold on their parent (this directory).
//
@@ -76,16 +76,16 @@ type dirInodeOperations struct {
// master is the master PTY inode.
master *fs.Inode
- // slaves contains the slave inodes reachable from the directory.
+ // replicas contains the replica inodes reachable from the directory.
//
- // A new slave is added by allocateTerminal and is removed by
+ // A new replica is added by allocateTerminal and is removed by
// masterFileOperations.Release.
//
- // A reference is held on every slave in the map.
- slaves map[uint32]*fs.Inode
+ // A reference is held on every replica in the map.
+ replicas map[uint32]*fs.Inode
// dentryMap is a SortedDentryMap used to implement Readdir containing
- // the master and all entries in slaves.
+ // the master and all entries in replicas.
dentryMap *fs.SortedDentryMap
// next is the next pty index to use.
@@ -101,7 +101,7 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode {
d := &dirInodeOperations{
InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0555), linux.DEVPTS_SUPER_MAGIC),
msrc: m,
- slaves: make(map[uint32]*fs.Inode),
+ replicas: make(map[uint32]*fs.Inode),
dentryMap: fs.NewSortedDentryMap(nil),
}
// Linux devpts uses a default mode of 0000 for ptmx which can be
@@ -133,7 +133,7 @@ func (d *dirInodeOperations) Release(ctx context.Context) {
defer d.mu.Unlock()
d.master.DecRef(ctx)
- if len(d.slaves) != 0 {
+ if len(d.replicas) != 0 {
panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d))
}
}
@@ -149,14 +149,14 @@ func (d *dirInodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name str
return fs.NewDirent(ctx, d.master, name), nil
}
- // Slave number?
+ // Replica number?
n, err := strconv.ParseUint(name, 10, 32)
if err != nil {
// Not found.
return nil, syserror.ENOENT
}
- s, ok := d.slaves[uint32(n)]
+ s, ok := d.replicas[uint32(n)]
if !ok {
return nil, syserror.ENOENT
}
@@ -236,7 +236,7 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e
return nil, syserror.ENOMEM
}
- if _, ok := d.slaves[n]; ok {
+ if _, ok := d.replicas[n]; ok {
panic(fmt.Sprintf("pty index collision; index %d already exists", n))
}
@@ -244,19 +244,19 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e
d.next++
// The reference returned by newTerminal is returned to the caller.
- // Take another for the slave inode.
+ // Take another for the replica inode.
t.IncRef()
// Create a pts node. The owner is based on the context that opens
// ptmx.
creds := auth.CredentialsFromContext(ctx)
uid, gid := creds.EffectiveKUID, creds.EffectiveKGID
- slave := newSlaveInode(ctx, d, t, fs.FileOwner{uid, gid}, fs.FilePermsFromMode(0666))
+ replica := newReplicaInode(ctx, d, t, fs.FileOwner{uid, gid}, fs.FilePermsFromMode(0666))
- d.slaves[n] = slave
+ d.replicas[n] = replica
d.dentryMap.Add(strconv.FormatUint(uint64(n), 10), fs.DentAttr{
- Type: slave.StableAttr.Type,
- InodeID: slave.StableAttr.InodeID,
+ Type: replica.StableAttr.Type,
+ InodeID: replica.StableAttr.InodeID,
})
return t, nil
@@ -267,18 +267,18 @@ func (d *dirInodeOperations) masterClose(ctx context.Context, t *Terminal) {
d.mu.Lock()
defer d.mu.Unlock()
- // The slave end disappears from the directory when the master end is
- // closed, even if the slave end is open elsewhere.
+ // The replica end disappears from the directory when the master end is
+ // closed, even if the replica end is open elsewhere.
//
// N.B. since we're using a backdoor method to remove a directory entry
// we won't properly fire inotify events like Linux would.
- s, ok := d.slaves[t.n]
+ s, ok := d.replicas[t.n]
if !ok {
panic(fmt.Sprintf("Terminal %+v doesn't exist in %+v?", t, d))
}
s.DecRef(ctx)
- delete(d.slaves, t.n)
+ delete(d.replicas, t.n)
d.dentryMap.Remove(strconv.FormatUint(uint64(t.n), 10))
}
diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go
index 2d4d44bf3..13f4901db 100644
--- a/pkg/sentry/fs/tty/fs.go
+++ b/pkg/sentry/fs/tty/fs.go
@@ -79,8 +79,8 @@ type superOperations struct{}
//
// It always returns true, forcing a Lookup for all entries.
//
-// Slave entries are dropped from dir when their master is closed, so an
-// existing slave Dirent in the tree is not sufficient to guarantee that it
+// Replica entries are dropped from dir when their master is closed, so an
+// existing replica Dirent in the tree is not sufficient to guarantee that it
// still exists on the filesystem.
func (superOperations) Revalidate(context.Context, string, *fs.Inode, *fs.Inode) bool {
return true
diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go
index 2e9dd2d55..b34f4a0eb 100644
--- a/pkg/sentry/fs/tty/line_discipline.go
+++ b/pkg/sentry/fs/tty/line_discipline.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -43,7 +44,7 @@ const (
)
// lineDiscipline dictates how input and output are handled between the
-// pseudoterminal (pty) master and slave. It can be configured to alter I/O,
+// pseudoterminal (pty) master and replica. It can be configured to alter I/O,
// modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man
// pages are good resources for how to affect the line discipline:
//
@@ -54,8 +55,8 @@ const (
//
// lineDiscipline has a simple structure but supports a multitude of options
// (see the above man pages). It consists of two queues of bytes: one from the
-// terminal master to slave (the input queue) and one from slave to master (the
-// output queue). When bytes are written to one end of the pty, the line
+// terminal master to replica (the input queue) and one from replica to master
+// (the output queue). When bytes are written to one end of the pty, the line
// discipline reads the bytes, modifies them or takes special action if
// required, and enqueues them to be read by the other end of the pty:
//
@@ -64,7 +65,7 @@ const (
// | (inputQueueWrite) +-------------+ (inputQueueRead) |
// | |
// | v
-// masterFD slaveFD
+// masterFD replicaFD
// ^ |
// | |
// | output to terminal +--------------+ output from process |
@@ -103,8 +104,8 @@ type lineDiscipline struct {
// masterWaiter is used to wait on the master end of the TTY.
masterWaiter waiter.Queue `state:"zerovalue"`
- // slaveWaiter is used to wait on the slave end of the TTY.
- slaveWaiter waiter.Queue `state:"zerovalue"`
+ // replicaWaiter is used to wait on the replica end of the TTY.
+ replicaWaiter waiter.Queue `state:"zerovalue"`
}
func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
@@ -115,27 +116,23 @@ func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
}
// getTermios gets the linux.Termios for the tty.
-func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (l *lineDiscipline) getTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) {
l.termiosMu.RLock()
defer l.termiosMu.RUnlock()
// We must copy a Termios struct, not KernelTermios.
t := l.termios.ToTermios()
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := t.CopyOut(task, args[2].Pointer())
return 0, err
}
// setTermios sets a linux.Termios for the tty.
-func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (l *lineDiscipline) setTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) {
l.termiosMu.Lock()
defer l.termiosMu.Unlock()
oldCanonEnabled := l.termios.LEnabled(linux.ICANON)
// We must copy a Termios struct, not KernelTermios.
var t linux.Termios
- _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := t.CopyIn(task, args[2].Pointer())
l.termios.FromTermios(t)
// If canonical mode is turned off, move bytes from inQueue's wait
@@ -146,27 +143,23 @@ func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arc
l.inQueue.pushWaitBufLocked(l)
l.inQueue.readable = true
l.inQueue.mu.Unlock()
- l.slaveWaiter.Notify(waiter.EventIn)
+ l.replicaWaiter.Notify(waiter.EventIn)
}
return 0, err
}
-func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+func (l *lineDiscipline) windowSize(t *kernel.Task, args arch.SyscallArguments) error {
l.sizeMu.Lock()
defer l.sizeMu.Unlock()
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := l.size.CopyOut(t, args[2].Pointer())
return err
}
-func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+func (l *lineDiscipline) setWindowSize(t *kernel.Task, args arch.SyscallArguments) error {
l.sizeMu.Lock()
defer l.sizeMu.Unlock()
- _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := l.size.CopyIn(t, args[2].Pointer())
return err
}
@@ -176,14 +169,14 @@ func (l *lineDiscipline) masterReadiness() waiter.EventMask {
return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios)
}
-func (l *lineDiscipline) slaveReadiness() waiter.EventMask {
+func (l *lineDiscipline) replicaReadiness() waiter.EventMask {
l.termiosMu.RLock()
defer l.termiosMu.RUnlock()
return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios)
}
-func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
- return l.inQueue.readableSize(ctx, io, args)
+func (l *lineDiscipline) inputQueueReadSize(t *kernel.Task, args arch.SyscallArguments) error {
+ return l.inQueue.readableSize(t, args)
}
func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
@@ -196,7 +189,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque
if n > 0 {
l.masterWaiter.Notify(waiter.EventOut)
if pushed {
- l.slaveWaiter.Notify(waiter.EventIn)
+ l.replicaWaiter.Notify(waiter.EventIn)
}
return n, nil
}
@@ -211,14 +204,14 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ
return 0, err
}
if n > 0 {
- l.slaveWaiter.Notify(waiter.EventIn)
+ l.replicaWaiter.Notify(waiter.EventIn)
return n, nil
}
return 0, syserror.ErrWouldBlock
}
-func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
- return l.outQueue.readableSize(ctx, io, args)
+func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, args arch.SyscallArguments) error {
+ return l.outQueue.readableSize(t, args)
}
func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
@@ -229,7 +222,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ
return 0, err
}
if n > 0 {
- l.slaveWaiter.Notify(waiter.EventOut)
+ l.replicaWaiter.Notify(waiter.EventOut)
if pushed {
l.masterWaiter.Notify(waiter.EventIn)
}
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
index e00746017..b91184b1b 100644
--- a/pkg/sentry/fs/tty/master.go
+++ b/pkg/sentry/fs/tty/master.go
@@ -17,9 +17,11 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -152,46 +154,51 @@ func (mf *masterFileOperations) Write(ctx context.Context, _ *fs.File, src userm
// Ioctl implements fs.FileOperations.Ioctl.
func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // ioctl(2) may only be called from a task goroutine.
+ return 0, syserror.ENOTTY
+ }
+
switch cmd := args[1].Uint(); cmd {
case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
// Get the number of bytes in the output queue read buffer.
- return 0, mf.t.ld.outputQueueReadSize(ctx, io, args)
+ return 0, mf.t.ld.outputQueueReadSize(t, args)
case linux.TCGETS:
// N.B. TCGETS on the master actually returns the configuration
- // of the slave end.
- return mf.t.ld.getTermios(ctx, io, args)
+ // of the replica end.
+ return mf.t.ld.getTermios(t, args)
case linux.TCSETS:
// N.B. TCSETS on the master actually affects the configuration
- // of the slave end.
- return mf.t.ld.setTermios(ctx, io, args)
+ // of the replica end.
+ return mf.t.ld.setTermios(t, args)
case linux.TCSETSW:
// TODO(b/29356795): This should drain the output queue first.
- return mf.t.ld.setTermios(ctx, io, args)
+ return mf.t.ld.setTermios(t, args)
case linux.TIOCGPTN:
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mf.t.n), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ nP := primitive.Uint32(mf.t.n)
+ _, err := nP.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCSPTLCK:
// TODO(b/29356795): Implement pty locking. For now just pretend we do.
return 0, nil
case linux.TIOCGWINSZ:
- return 0, mf.t.ld.windowSize(ctx, io, args)
+ return 0, mf.t.ld.windowSize(t, args)
case linux.TIOCSWINSZ:
- return 0, mf.t.ld.setWindowSize(ctx, io, args)
+ return 0, mf.t.ld.setWindowSize(t, args)
case linux.TIOCSCTTY:
// Make the given terminal the controlling terminal of the
// calling process.
- return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */)
+ return 0, mf.t.setControllingTTY(ctx, args, true /* isMaster */)
case linux.TIOCNOTTY:
// Release this process's controlling terminal.
- return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */)
+ return 0, mf.t.releaseControllingTTY(ctx, args, true /* isMaster */)
case linux.TIOCGPGRP:
// Get the foreground process group.
- return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */)
+ return mf.t.foregroundProcessGroup(ctx, args, true /* isMaster */)
case linux.TIOCSPGRP:
// Set the foreground process group.
- return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */)
+ return mf.t.setForegroundProcessGroup(ctx, args, true /* isMaster */)
default:
maybeEmitUnimplementedEvent(ctx, cmd)
return 0, syserror.ENOTTY
diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go
index c5d7ec717..79975d812 100644
--- a/pkg/sentry/fs/tty/queue.go
+++ b/pkg/sentry/fs/tty/queue.go
@@ -17,8 +17,10 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -32,7 +34,7 @@ import (
const waitBufMaxBytes = 131072
// queue represents one of the input or output queues between a pty master and
-// slave. Bytes written to a queue are added to the read buffer until it is
+// replica. Bytes written to a queue are added to the read buffer until it is
// full, at which point they are written to the wait buffer. Bytes are
// processed (i.e. undergo termios transformations) as they are added to the
// read buffer. The read buffer is readable when its length is nonzero and
@@ -85,17 +87,15 @@ func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask {
}
// readableSize writes the number of readable bytes to userspace.
-func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+func (q *queue) readableSize(t *kernel.Task, args arch.SyscallArguments) error {
q.mu.Lock()
defer q.mu.Unlock()
- var size int32
+ size := primitive.Int32(0)
if q.readable {
- size = int32(len(q.readBuf))
+ size = primitive.Int32(len(q.readBuf))
}
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := size.CopyOut(t, args[2].Pointer())
return err
}
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/replica.go
index 7c7292687..385d230fb 100644
--- a/pkg/sentry/fs/tty/slave.go
+++ b/pkg/sentry/fs/tty/replica.go
@@ -17,9 +17,11 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -27,11 +29,11 @@ import (
// LINT.IfChange
-// slaveInodeOperations are the fs.InodeOperations for the slave end of the
+// replicaInodeOperations are the fs.InodeOperations for the replica end of the
// Terminal (pts file).
//
// +stateify savable
-type slaveInodeOperations struct {
+type replicaInodeOperations struct {
fsutil.SimpleFileInode
// d is the containing dir.
@@ -41,13 +43,13 @@ type slaveInodeOperations struct {
t *Terminal
}
-var _ fs.InodeOperations = (*slaveInodeOperations)(nil)
+var _ fs.InodeOperations = (*replicaInodeOperations)(nil)
-// newSlaveInode creates an fs.Inode for the slave end of a terminal.
+// newReplicaInode creates an fs.Inode for the replica end of a terminal.
//
-// newSlaveInode takes ownership of t.
-func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode {
- iops := &slaveInodeOperations{
+// newReplicaInode takes ownership of t.
+func newReplicaInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode {
+ iops := &replicaInodeOperations{
SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, owner, p, linux.DEVPTS_SUPER_MAGIC),
d: d,
t: t,
@@ -64,18 +66,18 @@ func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owne
Type: fs.CharacterDevice,
// See fs/devpts/inode.c:devpts_fill_super.
BlockSize: 1024,
- DeviceFileMajor: linux.UNIX98_PTY_SLAVE_MAJOR,
+ DeviceFileMajor: linux.UNIX98_PTY_REPLICA_MAJOR,
DeviceFileMinor: t.n,
})
}
// Release implements fs.InodeOperations.Release.
-func (si *slaveInodeOperations) Release(ctx context.Context) {
+func (si *replicaInodeOperations) Release(ctx context.Context) {
si.t.DecRef(ctx)
}
// Truncate implements fs.InodeOperations.Truncate.
-func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
+func (*replicaInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
return nil
}
@@ -83,14 +85,15 @@ func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
//
// This may race with destruction of the terminal. If the terminal is gone, it
// returns ENOENT.
-func (si *slaveInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
- return fs.NewFile(ctx, d, flags, &slaveFileOperations{si: si}), nil
+func (si *replicaInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &replicaFileOperations{si: si}), nil
}
-// slaveFileOperations are the fs.FileOperations for the slave end of a terminal.
+// replicaFileOperations are the fs.FileOperations for the replica end of a
+// terminal.
//
// +stateify savable
-type slaveFileOperations struct {
+type replicaFileOperations struct {
fsutil.FilePipeSeek `state:"nosave"`
fsutil.FileNotDirReaddir `state:"nosave"`
fsutil.FileNoFsync `state:"nosave"`
@@ -100,79 +103,84 @@ type slaveFileOperations struct {
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
// si is the inode operations.
- si *slaveInodeOperations
+ si *replicaInodeOperations
}
-var _ fs.FileOperations = (*slaveFileOperations)(nil)
+var _ fs.FileOperations = (*replicaFileOperations)(nil)
// Release implements fs.FileOperations.Release.
-func (sf *slaveFileOperations) Release(context.Context) {
+func (sf *replicaFileOperations) Release(context.Context) {
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (sf *slaveFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
- sf.si.t.ld.slaveWaiter.EventRegister(e, mask)
+func (sf *replicaFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ sf.si.t.ld.replicaWaiter.EventRegister(e, mask)
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (sf *slaveFileOperations) EventUnregister(e *waiter.Entry) {
- sf.si.t.ld.slaveWaiter.EventUnregister(e)
+func (sf *replicaFileOperations) EventUnregister(e *waiter.Entry) {
+ sf.si.t.ld.replicaWaiter.EventUnregister(e)
}
// Readiness implements waiter.Waitable.Readiness.
-func (sf *slaveFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
- return sf.si.t.ld.slaveReadiness()
+func (sf *replicaFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return sf.si.t.ld.replicaReadiness()
}
// Read implements fs.FileOperations.Read.
-func (sf *slaveFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+func (sf *replicaFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
return sf.si.t.ld.inputQueueRead(ctx, dst)
}
// Write implements fs.FileOperations.Write.
-func (sf *slaveFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+func (sf *replicaFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
return sf.si.t.ld.outputQueueWrite(ctx, src)
}
// Ioctl implements fs.FileOperations.Ioctl.
-func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (sf *replicaFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // ioctl(2) may only be called from a task goroutine.
+ return 0, syserror.ENOTTY
+ }
+
switch cmd := args[1].Uint(); cmd {
case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
// Get the number of bytes in the input queue read buffer.
- return 0, sf.si.t.ld.inputQueueReadSize(ctx, io, args)
+ return 0, sf.si.t.ld.inputQueueReadSize(t, args)
case linux.TCGETS:
- return sf.si.t.ld.getTermios(ctx, io, args)
+ return sf.si.t.ld.getTermios(t, args)
case linux.TCSETS:
- return sf.si.t.ld.setTermios(ctx, io, args)
+ return sf.si.t.ld.setTermios(t, args)
case linux.TCSETSW:
// TODO(b/29356795): This should drain the output queue first.
- return sf.si.t.ld.setTermios(ctx, io, args)
+ return sf.si.t.ld.setTermios(t, args)
case linux.TIOCGPTN:
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sf.si.t.n), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ nP := primitive.Uint32(sf.si.t.n)
+ _, err := nP.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCGWINSZ:
- return 0, sf.si.t.ld.windowSize(ctx, io, args)
+ return 0, sf.si.t.ld.windowSize(t, args)
case linux.TIOCSWINSZ:
- return 0, sf.si.t.ld.setWindowSize(ctx, io, args)
+ return 0, sf.si.t.ld.setWindowSize(t, args)
case linux.TIOCSCTTY:
// Make the given terminal the controlling terminal of the
// calling process.
- return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */)
+ return 0, sf.si.t.setControllingTTY(ctx, args, false /* isMaster */)
case linux.TIOCNOTTY:
// Release this process's controlling terminal.
- return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */)
+ return 0, sf.si.t.releaseControllingTTY(ctx, args, false /* isMaster */)
case linux.TIOCGPGRP:
// Get the foreground process group.
- return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */)
+ return sf.si.t.foregroundProcessGroup(ctx, args, false /* isMaster */)
case linux.TIOCSPGRP:
// Set the foreground process group.
- return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */)
+ return sf.si.t.setForegroundProcessGroup(ctx, args, false /* isMaster */)
default:
maybeEmitUnimplementedEvent(ctx, cmd)
return 0, syserror.ENOTTY
}
}
-// LINT.ThenChange(../../fsimpl/devpts/slave.go)
+// LINT.ThenChange(../../fsimpl/devpts/replica.go)
diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go
index ddcccf4da..4f431d74d 100644
--- a/pkg/sentry/fs/tty/terminal.go
+++ b/pkg/sentry/fs/tty/terminal.go
@@ -17,10 +17,10 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/usermem"
)
// LINT.IfChange
@@ -44,19 +44,19 @@ type Terminal struct {
// this terminal. This field is immutable.
masterKTTY *kernel.TTY
- // slaveKTTY contains the controlling process of the slave end of this
+ // replicaKTTY contains the controlling process of the replica end of this
// terminal. This field is immutable.
- slaveKTTY *kernel.TTY
+ replicaKTTY *kernel.TTY
}
func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal {
- termios := linux.DefaultSlaveTermios
+ termios := linux.DefaultReplicaTermios
t := Terminal{
- d: d,
- n: n,
- ld: newLineDiscipline(termios),
- masterKTTY: &kernel.TTY{Index: n},
- slaveKTTY: &kernel.TTY{Index: n},
+ d: d,
+ n: n,
+ ld: newLineDiscipline(termios),
+ masterKTTY: &kernel.TTY{Index: n},
+ replicaKTTY: &kernel.TTY{Index: n},
}
t.EnableLeakCheck("tty.Terminal")
return &t
@@ -64,7 +64,7 @@ func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal
// setControllingTTY makes tm the controlling terminal of the calling thread
// group.
-func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+func (tm *Terminal) setControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error {
task := kernel.TaskFromContext(ctx)
if task == nil {
panic("setControllingTTY must be called from a task context")
@@ -75,7 +75,7 @@ func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args a
// releaseControllingTTY removes tm as the controlling terminal of the calling
// thread group.
-func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+func (tm *Terminal) releaseControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error {
task := kernel.TaskFromContext(ctx)
if task == nil {
panic("releaseControllingTTY must be called from a task context")
@@ -85,7 +85,7 @@ func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, ar
}
// foregroundProcessGroup gets the process group ID of tm's foreground process.
-func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+func (tm *Terminal) foregroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
task := kernel.TaskFromContext(ctx)
if task == nil {
panic("foregroundProcessGroup must be called from a task context")
@@ -97,24 +97,21 @@ func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, a
}
// Write it out to *arg.
- _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ retP := primitive.Int32(ret)
+ _, err = retP.CopyOut(task, args[2].Pointer())
return 0, err
}
// foregroundProcessGroup sets tm's foreground process.
-func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
task := kernel.TaskFromContext(ctx)
if task == nil {
panic("setForegroundProcessGroup must be called from a task context")
}
// Read in the process group ID.
- var pgid int32
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ var pgid primitive.Int32
+ if _, err := pgid.CopyIn(task, args[2].Pointer()); err != nil {
return 0, err
}
@@ -126,7 +123,7 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
if isMaster {
return tm.masterKTTY
}
- return tm.slaveKTTY
+ return tm.replicaKTTY
}
// LINT.ThenChange(../../fsimpl/devpts/terminal.go)
diff --git a/pkg/sentry/fs/tty/tty_test.go b/pkg/sentry/fs/tty/tty_test.go
index 2cbc05678..49edee83d 100644
--- a/pkg/sentry/fs/tty/tty_test.go
+++ b/pkg/sentry/fs/tty/tty_test.go
@@ -22,8 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-func TestSimpleMasterToSlave(t *testing.T) {
- ld := newLineDiscipline(linux.DefaultSlaveTermios)
+func TestSimpleMasterToReplica(t *testing.T) {
+ ld := newLineDiscipline(linux.DefaultReplicaTermios)
ctx := contexttest.Context(t)
inBytes := []byte("hello, tty\n")
src := usermem.BytesIOSequence(inBytes)
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
index 3f64fab3a..48e13613a 100644
--- a/pkg/sentry/fsimpl/devpts/BUILD
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -21,8 +21,8 @@ go_library(
"line_discipline.go",
"master.go",
"queue.go",
+ "replica.go",
"root_inode_refs.go",
- "slave.go",
"terminal.go",
],
visibility = ["//pkg/sentry:internal"],
@@ -30,6 +30,8 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index 57580f4d4..903135fae 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -35,6 +35,8 @@ import (
const Name = "devpts"
// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
type FilesystemType struct{}
// Name implements vfs.FilesystemType.Name.
@@ -58,6 +60,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return fs.Filesystem.VFSFilesystem(), root.VFSDentry(), nil
}
+// +stateify savable
type filesystem struct {
kernfs.Filesystem
@@ -79,7 +82,7 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds
// Construct the root directory. This is always inode id 1.
root := &rootInode{
- slaves: make(map[uint32]*slaveInode),
+ replicas: make(map[uint32]*replicaInode),
}
root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
@@ -110,6 +113,8 @@ func (fs *filesystem) Release(ctx context.Context) {
}
// rootInode is the root directory inode for the devpts mounts.
+//
+// +stateify savable
type rootInode struct {
implStatFS
kernfs.AlwaysValid
@@ -131,10 +136,10 @@ type rootInode struct {
root *rootInode
// mu protects the fields below.
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
- // slaves maps pty ids to slave inodes.
- slaves map[uint32]*slaveInode
+ // replicas maps pty ids to replica inodes.
+ replicas map[uint32]*replicaInode
// nextIdx is the next pty index to use. Must be accessed atomically.
//
@@ -154,22 +159,22 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error)
idx := i.nextIdx
i.nextIdx++
- // Sanity check that slave with idx does not exist.
- if _, ok := i.slaves[idx]; ok {
+ // Sanity check that replica with idx does not exist.
+ if _, ok := i.replicas[idx]; ok {
panic(fmt.Sprintf("pty index collision; index %d already exists", idx))
}
- // Create the new terminal and slave.
+ // Create the new terminal and replica.
t := newTerminal(idx)
- slave := &slaveInode{
+ replica := &replicaInode{
root: i,
t: t,
}
// Linux always uses pty index + 3 as the inode id. See
// fs/devpts/inode.c:devpts_pty_new().
- slave.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
- slave.dentry.Init(slave)
- i.slaves[idx] = slave
+ replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
+ replica.dentry.Init(replica)
+ i.replicas[idx] = replica
return t, nil
}
@@ -179,16 +184,16 @@ func (i *rootInode) masterClose(t *Terminal) {
i.mu.Lock()
defer i.mu.Unlock()
- // Sanity check that slave with idx exists.
- if _, ok := i.slaves[t.n]; !ok {
+ // Sanity check that replica with idx exists.
+ if _, ok := i.replicas[t.n]; !ok {
panic(fmt.Sprintf("pty with index %d does not exist", t.n))
}
- delete(i.slaves, t.n)
+ delete(i.replicas, t.n)
}
// Open implements kernfs.Inode.Open.
-func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
+func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndStaticEntries,
})
if err != nil {
@@ -198,16 +203,16 @@ func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.D
}
// Lookup implements kernfs.Inode.Lookup.
-func (i *rootInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+func (i *rootInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
idx, err := strconv.ParseUint(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
}
i.mu.Lock()
defer i.mu.Unlock()
- if si, ok := i.slaves[uint32(idx)]; ok {
+ if si, ok := i.replicas[uint32(idx)]; ok {
si.dentry.IncRef()
- return si.dentry.VFSDentry(), nil
+ return &si.dentry, nil
}
return nil, syserror.ENOENT
@@ -217,8 +222,8 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error
func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
i.mu.Lock()
defer i.mu.Unlock()
- ids := make([]int, 0, len(i.slaves))
- for id := range i.slaves {
+ ids := make([]int, 0, len(i.replicas))
+ for id := range i.replicas {
ids = append(ids, int(id))
}
sort.Ints(ids)
@@ -226,7 +231,7 @@ func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback,
dirent := vfs.Dirent{
Name: strconv.FormatUint(uint64(id), 10),
Type: linux.DT_CHR,
- Ino: i.slaves[uint32(id)].InodeAttrs.Ino(),
+ Ino: i.replicas[uint32(id)].InodeAttrs.Ino(),
NextOff: offset + 1,
}
if err := cb.Handle(dirent); err != nil {
@@ -237,11 +242,12 @@ func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback,
return offset, nil
}
-// DecRef implements kernfs.Inode.
+// DecRef implements kernfs.Inode.DecRef.
func (i *rootInode) DecRef(context.Context) {
i.rootInodeRefs.DecRef(i.Destroy)
}
+// +stateify savable
type implStatFS struct{}
// StatFS implements kernfs.Inode.StatFS.
diff --git a/pkg/sentry/fsimpl/devpts/devpts_test.go b/pkg/sentry/fsimpl/devpts/devpts_test.go
index b7c149047..448390cfe 100644
--- a/pkg/sentry/fsimpl/devpts/devpts_test.go
+++ b/pkg/sentry/fsimpl/devpts/devpts_test.go
@@ -22,8 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-func TestSimpleMasterToSlave(t *testing.T) {
- ld := newLineDiscipline(linux.DefaultSlaveTermios)
+func TestSimpleMasterToReplica(t *testing.T) {
+ ld := newLineDiscipline(linux.DefaultReplicaTermios)
ctx := contexttest.Context(t)
inBytes := []byte("hello, tty\n")
src := usermem.BytesIOSequence(inBytes)
diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go
index f7bc325d1..e6b0e81cf 100644
--- a/pkg/sentry/fsimpl/devpts/line_discipline.go
+++ b/pkg/sentry/fsimpl/devpts/line_discipline.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -41,7 +42,7 @@ const (
)
// lineDiscipline dictates how input and output are handled between the
-// pseudoterminal (pty) master and slave. It can be configured to alter I/O,
+// pseudoterminal (pty) master and replica. It can be configured to alter I/O,
// modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man
// pages are good resources for how to affect the line discipline:
//
@@ -52,8 +53,8 @@ const (
//
// lineDiscipline has a simple structure but supports a multitude of options
// (see the above man pages). It consists of two queues of bytes: one from the
-// terminal master to slave (the input queue) and one from slave to master (the
-// output queue). When bytes are written to one end of the pty, the line
+// terminal master to replica (the input queue) and one from replica to master
+// (the output queue). When bytes are written to one end of the pty, the line
// discipline reads the bytes, modifies them or takes special action if
// required, and enqueues them to be read by the other end of the pty:
//
@@ -62,7 +63,7 @@ const (
// | (inputQueueWrite) +-------------+ (inputQueueRead) |
// | |
// | v
-// masterFD slaveFD
+// masterFD replicaFD
// ^ |
// | |
// | output to terminal +--------------+ output from process |
@@ -101,8 +102,8 @@ type lineDiscipline struct {
// masterWaiter is used to wait on the master end of the TTY.
masterWaiter waiter.Queue `state:"zerovalue"`
- // slaveWaiter is used to wait on the slave end of the TTY.
- slaveWaiter waiter.Queue `state:"zerovalue"`
+ // replicaWaiter is used to wait on the replica end of the TTY.
+ replicaWaiter waiter.Queue `state:"zerovalue"`
}
func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
@@ -113,27 +114,23 @@ func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
}
// getTermios gets the linux.Termios for the tty.
-func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (l *lineDiscipline) getTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) {
l.termiosMu.RLock()
defer l.termiosMu.RUnlock()
// We must copy a Termios struct, not KernelTermios.
t := l.termios.ToTermios()
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := t.CopyOut(task, args[2].Pointer())
return 0, err
}
// setTermios sets a linux.Termios for the tty.
-func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (l *lineDiscipline) setTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) {
l.termiosMu.Lock()
defer l.termiosMu.Unlock()
oldCanonEnabled := l.termios.LEnabled(linux.ICANON)
// We must copy a Termios struct, not KernelTermios.
var t linux.Termios
- _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := t.CopyIn(task, args[2].Pointer())
l.termios.FromTermios(t)
// If canonical mode is turned off, move bytes from inQueue's wait
@@ -144,27 +141,23 @@ func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arc
l.inQueue.pushWaitBufLocked(l)
l.inQueue.readable = true
l.inQueue.mu.Unlock()
- l.slaveWaiter.Notify(waiter.EventIn)
+ l.replicaWaiter.Notify(waiter.EventIn)
}
return 0, err
}
-func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+func (l *lineDiscipline) windowSize(t *kernel.Task, args arch.SyscallArguments) error {
l.sizeMu.Lock()
defer l.sizeMu.Unlock()
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := l.size.CopyOut(t, args[2].Pointer())
return err
}
-func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+func (l *lineDiscipline) setWindowSize(t *kernel.Task, args arch.SyscallArguments) error {
l.sizeMu.Lock()
defer l.sizeMu.Unlock()
- _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := l.size.CopyIn(t, args[2].Pointer())
return err
}
@@ -174,14 +167,14 @@ func (l *lineDiscipline) masterReadiness() waiter.EventMask {
return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios)
}
-func (l *lineDiscipline) slaveReadiness() waiter.EventMask {
+func (l *lineDiscipline) replicaReadiness() waiter.EventMask {
l.termiosMu.RLock()
defer l.termiosMu.RUnlock()
return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios)
}
-func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
- return l.inQueue.readableSize(ctx, io, args)
+func (l *lineDiscipline) inputQueueReadSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error {
+ return l.inQueue.readableSize(t, io, args)
}
func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
@@ -194,7 +187,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque
if n > 0 {
l.masterWaiter.Notify(waiter.EventOut)
if pushed {
- l.slaveWaiter.Notify(waiter.EventIn)
+ l.replicaWaiter.Notify(waiter.EventIn)
}
return n, nil
}
@@ -209,14 +202,14 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ
return 0, err
}
if n > 0 {
- l.slaveWaiter.Notify(waiter.EventIn)
+ l.replicaWaiter.Notify(waiter.EventIn)
return n, nil
}
return 0, syserror.ErrWouldBlock
}
-func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
- return l.outQueue.readableSize(ctx, io, args)
+func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error {
+ return l.outQueue.readableSize(t, io, args)
}
func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
@@ -227,7 +220,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ
return 0, err
}
if n > 0 {
- l.slaveWaiter.Notify(waiter.EventOut)
+ l.replicaWaiter.Notify(waiter.EventOut)
if pushed {
l.masterWaiter.Notify(waiter.EventIn)
}
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index 60feb1993..69c2fe951 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -17,9 +17,11 @@ package devpts
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -29,6 +31,8 @@ import (
)
// masterInode is the inode for the master end of the Terminal.
+//
+// +stateify savable
type masterInode struct {
implStatFS
kernfs.InodeAttrs
@@ -48,20 +52,18 @@ type masterInode struct {
var _ kernfs.Inode = (*masterInode)(nil)
// Open implements kernfs.Inode.Open.
-func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
t, err := mi.root.allocateTerminal(rp.Credentials())
if err != nil {
return nil, err
}
- mi.IncRef()
fd := &masterFileDescription{
inode: mi,
t: t,
}
fd.LockFD.Init(&mi.locks)
- if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
- mi.DecRef(ctx)
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return &fd.vfsfd, nil
@@ -87,6 +89,7 @@ func (mi *masterInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds
return mi.InodeAttrs.SetStat(ctx, vfsfs, creds, opts)
}
+// +stateify savable
type masterFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -101,7 +104,6 @@ var _ vfs.FileDescriptionImpl = (*masterFileDescription)(nil)
// Release implements vfs.FileDescriptionImpl.Release.
func (mfd *masterFileDescription) Release(ctx context.Context) {
mfd.inode.root.masterClose(mfd.t)
- mfd.inode.DecRef(ctx)
}
// EventRegister implements waiter.Waitable.EventRegister.
@@ -131,46 +133,51 @@ func (mfd *masterFileDescription) Write(ctx context.Context, src usermem.IOSeque
// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
func (mfd *masterFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // ioctl(2) may only be called from a task goroutine.
+ return 0, syserror.ENOTTY
+ }
+
switch cmd := args[1].Uint(); cmd {
case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
// Get the number of bytes in the output queue read buffer.
- return 0, mfd.t.ld.outputQueueReadSize(ctx, io, args)
+ return 0, mfd.t.ld.outputQueueReadSize(t, io, args)
case linux.TCGETS:
// N.B. TCGETS on the master actually returns the configuration
- // of the slave end.
- return mfd.t.ld.getTermios(ctx, io, args)
+ // of the replica end.
+ return mfd.t.ld.getTermios(t, args)
case linux.TCSETS:
// N.B. TCSETS on the master actually affects the configuration
- // of the slave end.
- return mfd.t.ld.setTermios(ctx, io, args)
+ // of the replica end.
+ return mfd.t.ld.setTermios(t, args)
case linux.TCSETSW:
// TODO(b/29356795): This should drain the output queue first.
- return mfd.t.ld.setTermios(ctx, io, args)
+ return mfd.t.ld.setTermios(t, args)
case linux.TIOCGPTN:
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mfd.t.n), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ nP := primitive.Uint32(mfd.t.n)
+ _, err := nP.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCSPTLCK:
// TODO(b/29356795): Implement pty locking. For now just pretend we do.
return 0, nil
case linux.TIOCGWINSZ:
- return 0, mfd.t.ld.windowSize(ctx, io, args)
+ return 0, mfd.t.ld.windowSize(t, args)
case linux.TIOCSWINSZ:
- return 0, mfd.t.ld.setWindowSize(ctx, io, args)
+ return 0, mfd.t.ld.setWindowSize(t, args)
case linux.TIOCSCTTY:
// Make the given terminal the controlling terminal of the
// calling process.
- return 0, mfd.t.setControllingTTY(ctx, io, args, true /* isMaster */)
+ return 0, mfd.t.setControllingTTY(ctx, args, true /* isMaster */)
case linux.TIOCNOTTY:
// Release this process's controlling terminal.
- return 0, mfd.t.releaseControllingTTY(ctx, io, args, true /* isMaster */)
+ return 0, mfd.t.releaseControllingTTY(ctx, args, true /* isMaster */)
case linux.TIOCGPGRP:
// Get the foreground process group.
- return mfd.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */)
+ return mfd.t.foregroundProcessGroup(ctx, args, true /* isMaster */)
case linux.TIOCSPGRP:
// Set the foreground process group.
- return mfd.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */)
+ return mfd.t.setForegroundProcessGroup(ctx, args, true /* isMaster */)
default:
maybeEmitUnimplementedEvent(ctx, cmd)
return 0, syserror.ENOTTY
diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go
index 331c13997..55bff3e60 100644
--- a/pkg/sentry/fsimpl/devpts/queue.go
+++ b/pkg/sentry/fsimpl/devpts/queue.go
@@ -17,8 +17,10 @@ package devpts
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -30,7 +32,7 @@ import (
const waitBufMaxBytes = 131072
// queue represents one of the input or output queues between a pty master and
-// slave. Bytes written to a queue are added to the read buffer until it is
+// replica. Bytes written to a queue are added to the read buffer until it is
// full, at which point they are written to the wait buffer. Bytes are
// processed (i.e. undergo termios transformations) as they are added to the
// read buffer. The read buffer is readable when its length is nonzero and
@@ -83,17 +85,15 @@ func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask {
}
// readableSize writes the number of readable bytes to userspace.
-func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+func (q *queue) readableSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error {
q.mu.Lock()
defer q.mu.Unlock()
- var size int32
+ size := primitive.Int32(0)
if q.readable {
- size = int32(len(q.readBuf))
+ size = primitive.Int32(len(q.readBuf))
}
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := size.CopyOut(t, args[2].Pointer())
return err
}
diff --git a/pkg/sentry/fsimpl/devpts/replica.go b/pkg/sentry/fsimpl/devpts/replica.go
new file mode 100644
index 000000000..6515c5536
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/replica.go
@@ -0,0 +1,204 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// replicaInode is the inode for the replica end of the Terminal.
+//
+// +stateify savable
+type replicaInode struct {
+ implStatFS
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ locks vfs.FileLocks
+
+ // Keep a reference to this inode's dentry.
+ dentry kernfs.Dentry
+
+ // root is the devpts root inode.
+ root *rootInode
+
+ // t is the connected Terminal.
+ t *Terminal
+}
+
+var _ kernfs.Inode = (*replicaInode)(nil)
+
+// Open implements kernfs.Inode.Open.
+func (ri *replicaInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &replicaFileDescription{
+ inode: ri,
+ }
+ fd.LockFD.Init(&ri.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+
+}
+
+// Valid implements kernfs.Inode.Valid.
+func (ri *replicaInode) Valid(context.Context) bool {
+ // Return valid if the replica still exists.
+ ri.root.mu.Lock()
+ defer ri.root.mu.Unlock()
+ _, ok := ri.root.replicas[ri.t.n]
+ return ok
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (ri *replicaInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ statx, err := ri.InodeAttrs.Stat(ctx, vfsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ statx.Blksize = 1024
+ statx.RdevMajor = linux.UNIX98_PTY_REPLICA_MAJOR
+ statx.RdevMinor = ri.t.n
+ return statx, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat
+func (ri *replicaInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask&linux.STATX_SIZE != 0 {
+ return syserror.EINVAL
+ }
+ return ri.InodeAttrs.SetStat(ctx, vfsfs, creds, opts)
+}
+
+// +stateify savable
+type replicaFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ inode *replicaInode
+}
+
+var _ vfs.FileDescriptionImpl = (*replicaFileDescription)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (rfd *replicaFileDescription) Release(ctx context.Context) {}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (rfd *replicaFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ rfd.inode.t.ld.replicaWaiter.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (rfd *replicaFileDescription) EventUnregister(e *waiter.Entry) {
+ rfd.inode.t.ld.replicaWaiter.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (rfd *replicaFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return rfd.inode.t.ld.replicaReadiness()
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (rfd *replicaFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ return rfd.inode.t.ld.inputQueueRead(ctx, dst)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (rfd *replicaFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ return rfd.inode.t.ld.outputQueueWrite(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (rfd *replicaFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // ioctl(2) may only be called from a task goroutine.
+ return 0, syserror.ENOTTY
+ }
+
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
+ // Get the number of bytes in the input queue read buffer.
+ return 0, rfd.inode.t.ld.inputQueueReadSize(t, io, args)
+ case linux.TCGETS:
+ return rfd.inode.t.ld.getTermios(t, args)
+ case linux.TCSETS:
+ return rfd.inode.t.ld.setTermios(t, args)
+ case linux.TCSETSW:
+ // TODO(b/29356795): This should drain the output queue first.
+ return rfd.inode.t.ld.setTermios(t, args)
+ case linux.TIOCGPTN:
+ nP := primitive.Uint32(rfd.inode.t.n)
+ _, err := nP.CopyOut(t, args[2].Pointer())
+ return 0, err
+ case linux.TIOCGWINSZ:
+ return 0, rfd.inode.t.ld.windowSize(t, args)
+ case linux.TIOCSWINSZ:
+ return 0, rfd.inode.t.ld.setWindowSize(t, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, rfd.inode.t.setControllingTTY(ctx, args, false /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, rfd.inode.t.releaseControllingTTY(ctx, args, false /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return rfd.inode.t.foregroundProcessGroup(ctx, args, false /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return rfd.inode.t.setForegroundProcessGroup(ctx, args, false /* isMaster */)
+ default:
+ maybeEmitUnimplementedEvent(ctx, cmd)
+ return 0, syserror.ENOTTY
+ }
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (rfd *replicaFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ fs := rfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return rfd.inode.SetStat(ctx, fs, creds, opts)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (rfd *replicaFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := rfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return rfd.inode.Stat(ctx, fs, opts)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (rfd *replicaFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return rfd.Locks().LockPOSIX(ctx, &rfd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (rfd *replicaFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return rfd.Locks().UnlockPOSIX(ctx, &rfd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go
deleted file mode 100644
index a9da7af64..000000000
--- a/pkg/sentry/fsimpl/devpts/slave.go
+++ /dev/null
@@ -1,198 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package devpts
-
-import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/arch"
- fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// slaveInode is the inode for the slave end of the Terminal.
-type slaveInode struct {
- implStatFS
- kernfs.InodeAttrs
- kernfs.InodeNoopRefCount
- kernfs.InodeNotDirectory
- kernfs.InodeNotSymlink
-
- locks vfs.FileLocks
-
- // Keep a reference to this inode's dentry.
- dentry kernfs.Dentry
-
- // root is the devpts root inode.
- root *rootInode
-
- // t is the connected Terminal.
- t *Terminal
-}
-
-var _ kernfs.Inode = (*slaveInode)(nil)
-
-// Open implements kernfs.Inode.Open.
-func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- si.IncRef()
- fd := &slaveFileDescription{
- inode: si,
- }
- fd.LockFD.Init(&si.locks)
- if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
- si.DecRef(ctx)
- return nil, err
- }
- return &fd.vfsfd, nil
-
-}
-
-// Valid implements kernfs.Inode.Valid.
-func (si *slaveInode) Valid(context.Context) bool {
- // Return valid if the slave still exists.
- si.root.mu.Lock()
- defer si.root.mu.Unlock()
- _, ok := si.root.slaves[si.t.n]
- return ok
-}
-
-// Stat implements kernfs.Inode.Stat.
-func (si *slaveInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- statx, err := si.InodeAttrs.Stat(ctx, vfsfs, opts)
- if err != nil {
- return linux.Statx{}, err
- }
- statx.Blksize = 1024
- statx.RdevMajor = linux.UNIX98_PTY_SLAVE_MAJOR
- statx.RdevMinor = si.t.n
- return statx, nil
-}
-
-// SetStat implements kernfs.Inode.SetStat
-func (si *slaveInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
- if opts.Stat.Mask&linux.STATX_SIZE != 0 {
- return syserror.EINVAL
- }
- return si.InodeAttrs.SetStat(ctx, vfsfs, creds, opts)
-}
-
-type slaveFileDescription struct {
- vfsfd vfs.FileDescription
- vfs.FileDescriptionDefaultImpl
- vfs.LockFD
-
- inode *slaveInode
-}
-
-var _ vfs.FileDescriptionImpl = (*slaveFileDescription)(nil)
-
-// Release implements fs.FileOperations.Release.
-func (sfd *slaveFileDescription) Release(ctx context.Context) {
- sfd.inode.DecRef(ctx)
-}
-
-// EventRegister implements waiter.Waitable.EventRegister.
-func (sfd *slaveFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
- sfd.inode.t.ld.slaveWaiter.EventRegister(e, mask)
-}
-
-// EventUnregister implements waiter.Waitable.EventUnregister.
-func (sfd *slaveFileDescription) EventUnregister(e *waiter.Entry) {
- sfd.inode.t.ld.slaveWaiter.EventUnregister(e)
-}
-
-// Readiness implements waiter.Waitable.Readiness.
-func (sfd *slaveFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
- return sfd.inode.t.ld.slaveReadiness()
-}
-
-// Read implements vfs.FileDescriptionImpl.Read.
-func (sfd *slaveFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
- return sfd.inode.t.ld.inputQueueRead(ctx, dst)
-}
-
-// Write implements vfs.FileDescriptionImpl.Write.
-func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
- return sfd.inode.t.ld.outputQueueWrite(ctx, src)
-}
-
-// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
-func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
- switch cmd := args[1].Uint(); cmd {
- case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
- // Get the number of bytes in the input queue read buffer.
- return 0, sfd.inode.t.ld.inputQueueReadSize(ctx, io, args)
- case linux.TCGETS:
- return sfd.inode.t.ld.getTermios(ctx, io, args)
- case linux.TCSETS:
- return sfd.inode.t.ld.setTermios(ctx, io, args)
- case linux.TCSETSW:
- // TODO(b/29356795): This should drain the output queue first.
- return sfd.inode.t.ld.setTermios(ctx, io, args)
- case linux.TIOCGPTN:
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sfd.inode.t.n), usermem.IOOpts{
- AddressSpaceActive: true,
- })
- return 0, err
- case linux.TIOCGWINSZ:
- return 0, sfd.inode.t.ld.windowSize(ctx, io, args)
- case linux.TIOCSWINSZ:
- return 0, sfd.inode.t.ld.setWindowSize(ctx, io, args)
- case linux.TIOCSCTTY:
- // Make the given terminal the controlling terminal of the
- // calling process.
- return 0, sfd.inode.t.setControllingTTY(ctx, io, args, false /* isMaster */)
- case linux.TIOCNOTTY:
- // Release this process's controlling terminal.
- return 0, sfd.inode.t.releaseControllingTTY(ctx, io, args, false /* isMaster */)
- case linux.TIOCGPGRP:
- // Get the foreground process group.
- return sfd.inode.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */)
- case linux.TIOCSPGRP:
- // Set the foreground process group.
- return sfd.inode.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */)
- default:
- maybeEmitUnimplementedEvent(ctx, cmd)
- return 0, syserror.ENOTTY
- }
-}
-
-// SetStat implements vfs.FileDescriptionImpl.SetStat.
-func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- creds := auth.CredentialsFromContext(ctx)
- fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem()
- return sfd.inode.SetStat(ctx, fs, creds, opts)
-}
-
-// Stat implements vfs.FileDescriptionImpl.Stat.
-func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
- fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem()
- return sfd.inode.Stat(ctx, fs, opts)
-}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (sfd *slaveFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return sfd.Locks().LockPOSIX(ctx, &sfd.vfsfd, uid, t, start, length, whence, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (sfd *slaveFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return sfd.Locks().UnlockPOSIX(ctx, &sfd.vfsfd, uid, start, length, whence)
-}
diff --git a/pkg/sentry/fsimpl/devpts/terminal.go b/pkg/sentry/fsimpl/devpts/terminal.go
index 7d2781c54..510bd6d89 100644
--- a/pkg/sentry/fsimpl/devpts/terminal.go
+++ b/pkg/sentry/fsimpl/devpts/terminal.go
@@ -17,9 +17,9 @@ package devpts
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Terminal is a pseudoterminal.
@@ -36,25 +36,25 @@ type Terminal struct {
// this terminal. This field is immutable.
masterKTTY *kernel.TTY
- // slaveKTTY contains the controlling process of the slave end of this
+ // replicaKTTY contains the controlling process of the replica end of this
// terminal. This field is immutable.
- slaveKTTY *kernel.TTY
+ replicaKTTY *kernel.TTY
}
func newTerminal(n uint32) *Terminal {
- termios := linux.DefaultSlaveTermios
+ termios := linux.DefaultReplicaTermios
t := Terminal{
- n: n,
- ld: newLineDiscipline(termios),
- masterKTTY: &kernel.TTY{Index: n},
- slaveKTTY: &kernel.TTY{Index: n},
+ n: n,
+ ld: newLineDiscipline(termios),
+ masterKTTY: &kernel.TTY{Index: n},
+ replicaKTTY: &kernel.TTY{Index: n},
}
return &t
}
// setControllingTTY makes tm the controlling terminal of the calling thread
// group.
-func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+func (tm *Terminal) setControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error {
task := kernel.TaskFromContext(ctx)
if task == nil {
panic("setControllingTTY must be called from a task context")
@@ -65,7 +65,7 @@ func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args a
// releaseControllingTTY removes tm as the controlling terminal of the calling
// thread group.
-func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+func (tm *Terminal) releaseControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error {
task := kernel.TaskFromContext(ctx)
if task == nil {
panic("releaseControllingTTY must be called from a task context")
@@ -75,7 +75,7 @@ func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, ar
}
// foregroundProcessGroup gets the process group ID of tm's foreground process.
-func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+func (tm *Terminal) foregroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
task := kernel.TaskFromContext(ctx)
if task == nil {
panic("foregroundProcessGroup must be called from a task context")
@@ -87,24 +87,21 @@ func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, a
}
// Write it out to *arg.
- _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ retP := primitive.Int32(ret)
+ _, err = retP.CopyOut(task, args[2].Pointer())
return 0, err
}
// foregroundProcessGroup sets tm's foreground process.
-func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
task := kernel.TaskFromContext(ctx)
if task == nil {
panic("setForegroundProcessGroup must be called from a task context")
}
// Read in the process group ID.
- var pgid int32
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ var pgid primitive.Int32
+ if _, err := pgid.CopyIn(task, args[2].Pointer()); err != nil {
return 0, err
}
@@ -116,5 +113,5 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
if isMaster {
return tm.masterKTTY
}
- return tm.slaveKTTY
+ return tm.replicaKTTY
}
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
index 52f44f66d..6d1753080 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -33,8 +33,10 @@ import (
const Name = "devtmpfs"
// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
type FilesystemType struct {
- initOnce sync.Once
+ initOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1664): not yet supported.
initErr error
// fs is the tmpfs filesystem that backs all mounts of this FilesystemType.
@@ -80,7 +82,7 @@ type Accessor struct {
// NewAccessor returns an Accessor that supports creation of device special
// files in the devtmpfs instance registered with name fsTypeName in vfsObj.
func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, fsTypeName string) (*Accessor, error) {
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "devtmpfs" /* source */, fsTypeName, &vfs.GetFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "devtmpfs" /* source */, fsTypeName, &vfs.MountOptions{})
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
index 827a608cb..3a38b8bb4 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
@@ -48,7 +48,7 @@ func setupDevtmpfs(t *testing.T) (context.Context, *auth.Credentials, *vfs.Virtu
})
// Create a test mount namespace with devtmpfs mounted at "/dev".
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.GetFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.MountOptions{})
if err != nil {
t.Fatalf("failed to create tmpfs root mount: %v", err)
}
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go
index 812171fa3..1c27ad700 100644
--- a/pkg/sentry/fsimpl/eventfd/eventfd.go
+++ b/pkg/sentry/fsimpl/eventfd/eventfd.go
@@ -30,9 +30,11 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// EventFileDescription implements FileDescriptionImpl for file-based event
+// EventFileDescription implements vfs.FileDescriptionImpl for file-based event
// notification (eventfd). Eventfds are usually internal to the Sentry but in
// certain situations they may be converted into a host-backed eventfd.
+//
+// +stateify savable
type EventFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -106,7 +108,7 @@ func (efd *EventFileDescription) HostFD() (int, error) {
return efd.hostfd, nil
}
-// Release implements FileDescriptionImpl.Release()
+// Release implements vfs.FileDescriptionImpl.Release.
func (efd *EventFileDescription) Release(context.Context) {
efd.mu.Lock()
defer efd.mu.Unlock()
@@ -119,7 +121,7 @@ func (efd *EventFileDescription) Release(context.Context) {
}
}
-// Read implements FileDescriptionImpl.Read.
+// Read implements vfs.FileDescriptionImpl.Read.
func (efd *EventFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
if dst.NumBytes() < 8 {
return 0, syscall.EINVAL
@@ -130,7 +132,7 @@ func (efd *EventFileDescription) Read(ctx context.Context, dst usermem.IOSequenc
return 8, nil
}
-// Write implements FileDescriptionImpl.Write.
+// Write implements vfs.FileDescriptionImpl.Write.
func (efd *EventFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
if src.NumBytes() < 8 {
return 0, syscall.EINVAL
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index abc610ef3..7b1eec3da 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -51,6 +51,8 @@ go_library(
"//pkg/fd",
"//pkg/fspath",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
@@ -86,9 +88,9 @@ go_test(
library = ":ext",
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/marshal/primitive",
"//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/ext/disklayout",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
index a2cc9b59f..c349b886e 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -59,7 +59,11 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: int(f.Fd()),
+ },
+ })
if err != nil {
f.Close()
return nil, nil, nil, nil, err
diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go
index 8bb104ff0..1165234f9 100644
--- a/pkg/sentry/fsimpl/ext/block_map_file.go
+++ b/pkg/sentry/fsimpl/ext/block_map_file.go
@@ -18,7 +18,7 @@ import (
"io"
"math"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -34,19 +34,19 @@ type blockMapFile struct {
// directBlks are the direct blocks numbers. The physical blocks pointed by
// these holds file data. Contains file blocks 0 to 11.
- directBlks [numDirectBlks]uint32
+ directBlks [numDirectBlks]primitive.Uint32
// indirectBlk is the physical block which contains (blkSize/4) direct block
// numbers (as uint32 integers).
- indirectBlk uint32
+ indirectBlk primitive.Uint32
// doubleIndirectBlk is the physical block which contains (blkSize/4) indirect
// block numbers (as uint32 integers).
- doubleIndirectBlk uint32
+ doubleIndirectBlk primitive.Uint32
// tripleIndirectBlk is the physical block which contains (blkSize/4) doubly
// indirect block numbers (as uint32 integers).
- tripleIndirectBlk uint32
+ tripleIndirectBlk primitive.Uint32
// coverage at (i)th index indicates the amount of file data a node at
// height (i) covers. Height 0 is the direct block.
@@ -68,10 +68,12 @@ func newBlockMapFile(args inodeArgs) (*blockMapFile, error) {
}
blkMap := file.regFile.inode.diskInode.Data()
- binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks)
- binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk)
- binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk)
- binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk)
+ for i := 0; i < numDirectBlks; i++ {
+ file.directBlks[i].UnmarshalBytes(blkMap[i*4 : (i+1)*4])
+ }
+ file.indirectBlk.UnmarshalBytes(blkMap[numDirectBlks*4 : (numDirectBlks+1)*4])
+ file.doubleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+1)*4 : (numDirectBlks+2)*4])
+ file.tripleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+2)*4 : (numDirectBlks+3)*4])
return file, nil
}
@@ -117,16 +119,16 @@ func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) {
switch {
case offset < dirBlksEnd:
// Direct block.
- curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, dst[read:])
+ curR, err = f.read(uint32(f.directBlks[offset/f.regFile.inode.blkSize]), offset%f.regFile.inode.blkSize, 0, dst[read:])
case offset < indirBlkEnd:
// Indirect block.
- curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, dst[read:])
+ curR, err = f.read(uint32(f.indirectBlk), offset-dirBlksEnd, 1, dst[read:])
case offset < doubIndirBlkEnd:
// Doubly indirect block.
- curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, dst[read:])
+ curR, err = f.read(uint32(f.doubleIndirectBlk), offset-indirBlkEnd, 2, dst[read:])
default:
// Triply indirect block.
- curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, dst[read:])
+ curR, err = f.read(uint32(f.tripleIndirectBlk), offset-doubIndirBlkEnd, 3, dst[read:])
}
read += curR
@@ -174,13 +176,13 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds
read := 0
curChildOff := relFileOff % childCov
for i := startIdx; i < endIdx; i++ {
- var childPhyBlk uint32
+ var childPhyBlk primitive.Uint32
err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk)
if err != nil {
return read, err
}
- n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:])
+ n, err := f.read(uint32(childPhyBlk), curChildOff, height-1, dst[read:])
read += n
if err != nil {
return read, err
diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go
index 6fa84e7aa..ed98b482e 100644
--- a/pkg/sentry/fsimpl/ext/block_map_test.go
+++ b/pkg/sentry/fsimpl/ext/block_map_test.go
@@ -20,7 +20,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
)
@@ -87,29 +87,33 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
mockDisk := make([]byte, mockBMDiskSize)
var fileData []byte
blkNums := newBlkNumGen()
- var data []byte
+ off := 0
+ data := make([]byte, (numDirectBlks+3)*(*primitive.Uint32)(nil).SizeBytes())
// Write the direct blocks.
for i := 0; i < numDirectBlks; i++ {
- curBlkNum := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, curBlkNum)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...)
+ curBlkNum := primitive.Uint32(blkNums.next())
+ curBlkNum.MarshalBytes(data[off:])
+ off += curBlkNum.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(curBlkNum), 0, blkNums)...)
}
// Write to indirect block.
- indirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, indirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...)
-
- // Write to indirect block.
- doublyIndirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...)
-
- // Write to indirect block.
- triplyIndirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...)
+ indirectBlk := primitive.Uint32(blkNums.next())
+ indirectBlk.MarshalBytes(data[off:])
+ off += indirectBlk.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(indirectBlk), 1, blkNums)...)
+
+ // Write to double indirect block.
+ doublyIndirectBlk := primitive.Uint32(blkNums.next())
+ doublyIndirectBlk.MarshalBytes(data[off:])
+ off += doublyIndirectBlk.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(doublyIndirectBlk), 2, blkNums)...)
+
+ // Write to triple indirect block.
+ triplyIndirectBlk := primitive.Uint32(blkNums.next())
+ triplyIndirectBlk.MarshalBytes(data[off:])
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(triplyIndirectBlk), 3, blkNums)...)
args := inodeArgs{
fs: &filesystem{
@@ -142,9 +146,9 @@ func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkN
var fileData []byte
for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 {
- curBlkNum := blkNums.next()
- copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum))
- fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...)
+ curBlkNum := primitive.Uint32(blkNums.next())
+ curBlkNum.MarshalBytes(disk[off : off+4])
+ fileData = append(fileData, writeFileDataToBlock(disk, uint32(curBlkNum), height-1, blkNums)...)
}
return fileData
}
diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go
index 7a1b4219f..9bfed883a 100644
--- a/pkg/sentry/fsimpl/ext/dentry.go
+++ b/pkg/sentry/fsimpl/ext/dentry.go
@@ -20,6 +20,8 @@ import (
)
// dentry implements vfs.DentryImpl.
+//
+// +stateify savable
type dentry struct {
vfsd vfs.Dentry
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 0fc01668d..0ad79b381 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -16,7 +16,6 @@ package ext
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -28,6 +27,8 @@ import (
)
// directory represents a directory inode. It holds the childList in memory.
+//
+// +stateify savable
type directory struct {
inode inode
@@ -39,7 +40,7 @@ type directory struct {
// Lock Order (outermost locks must be taken first):
// directory.mu
// filesystem.mu
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
// childList is a list containing (1) child dirents and (2) fake dirents
// (with diskDirent == nil) that represent the iteration position of
@@ -98,7 +99,7 @@ func newDirectory(args inodeArgs, newDirent bool) (*directory, error) {
} else {
curDirent.diskDirent = &disklayout.DirentOld{}
}
- binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent)
+ curDirent.diskDirent.UnmarshalBytes(buf)
if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 {
// Inode number and name length fields being set to 0 is used to indicate
@@ -120,6 +121,8 @@ func (i *inode) isDir() bool {
}
// dirent is the directory.childList node.
+//
+// +stateify savable
type dirent struct {
diskDirent disklayout.Dirent
@@ -129,6 +132,8 @@ type dirent struct {
// directoryFD represents a directory file description. It implements
// vfs.FileDescriptionImpl.
+//
+// +stateify savable
type directoryFD struct {
fileDescription
vfs.DirectoryFileDescriptionDefaultImpl
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index 9bd9c76c0..d98a05dd8 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -22,10 +22,11 @@ go_library(
"superblock_old.go",
"test_utils.go",
],
+ marshal = True,
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/marshal",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
index ad6f4fef8..0d56ae9da 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
// BlockGroup represents a Linux ext block group descriptor. An ext file system
// is split into a series of block groups. This provides an access layer to
// information needed to access and use a block group.
@@ -30,6 +34,8 @@ package disklayout
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#block-group-descriptors.
type BlockGroup interface {
+ marshal.Marshallable
+
// InodeTable returns the absolute block number of the block containing the
// inode table. This points to an array of Inode structs. Inode tables are
// statically allocated at mkfs time. The superblock records the number of
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
index 3e16c76db..a35fa22a0 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
@@ -17,6 +17,8 @@ package disklayout
// BlockGroup32Bit emulates the first half of struct ext4_group_desc in
// fs/ext4/ext4.h. It is the block group descriptor struct for ext2, ext3 and
// 32-bit ext4 filesystems. It implements BlockGroup interface.
+//
+// +marshal
type BlockGroup32Bit struct {
BlockBitmapLo uint32
InodeBitmapLo uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
index 9a809197a..d54d1d345 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
@@ -18,6 +18,8 @@ package disklayout
// It is the block group descriptor struct for 64-bit ext4 filesystems.
// It implements BlockGroup interface. It is an extension of the 32-bit
// version of BlockGroup.
+//
+// +marshal
type BlockGroup64Bit struct {
// We embed the 32-bit struct here because 64-bit version is just an extension
// of the 32-bit version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
index 0ef4294c0..e4ce484e4 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
@@ -21,6 +21,8 @@ import (
// TestBlockGroupSize tests that the block group descriptor structs are of the
// correct size.
func TestBlockGroupSize(t *testing.T) {
- assertSize(t, BlockGroup32Bit{}, 32)
- assertSize(t, BlockGroup64Bit{}, 64)
+ var bgSmall BlockGroup32Bit
+ assertSize(t, &bgSmall, 32)
+ var bgBig BlockGroup64Bit
+ assertSize(t, &bgBig, 64)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
index 417b6cf65..568c8cb4c 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
@@ -15,6 +15,7 @@
package disklayout
import (
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
@@ -51,6 +52,8 @@ var (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#linear-classic-directories.
type Dirent interface {
+ marshal.Marshallable
+
// Inode returns the absolute inode number of the underlying inode.
// Inode number 0 signifies an unused dirent.
Inode() uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
index 29ae4a5c2..51f9c2946 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
@@ -29,12 +29,14 @@ import (
// Note: This struct can be of variable size on disk. The one described below
// is of maximum size and the FileName beyond NameLength bytes might contain
// garbage.
+//
+// +marshal
type DirentNew struct {
InodeNumber uint32
RecordLength uint16
NameLength uint8
FileTypeRaw uint8
- FileNameRaw [MaxFileName]byte
+ FileNameRaw [MaxFileName]byte `marshal:"unaligned"`
}
// Compiles only if DirentNew implements Dirent.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
index 6fff12a6e..d4b19e086 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
@@ -22,11 +22,13 @@ import "gvisor.dev/gvisor/pkg/sentry/fs"
// Note: This struct can be of variable size on disk. The one described below
// is of maximum size and the FileName beyond NameLength bytes might contain
// garbage.
+//
+// +marshal
type DirentOld struct {
InodeNumber uint32
RecordLength uint16
NameLength uint16
- FileNameRaw [MaxFileName]byte
+ FileNameRaw [MaxFileName]byte `marshal:"unaligned"`
}
// Compiles only if DirentOld implements Dirent.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
index 934919f8a..3486864dc 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
@@ -21,6 +21,8 @@ import (
// TestDirentSize tests that the dirent structs are of the correct
// size.
func TestDirentSize(t *testing.T) {
- assertSize(t, DirentOld{}, uintptr(DirentSize))
- assertSize(t, DirentNew{}, uintptr(DirentSize))
+ var dOld DirentOld
+ assertSize(t, &dOld, DirentSize)
+ var dNew DirentNew
+ assertSize(t, &dNew, DirentSize)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
index bdf4e2132..0834e9ba8 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
@@ -36,8 +36,6 @@
// escape analysis on an unknown implementation at compile time.
//
// Notes:
-// - All fields in these structs are exported because binary.Read would
-// panic otherwise.
// - All structures on disk are in little-endian order. Only jbd2 (journal)
// structures are in big-endian order.
// - All OS dependent fields in these structures will be interpretted using
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go
index 4110649ab..b13999bfc 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
// Extents were introduced in ext4 and provide huge performance gains in terms
// data locality and reduced metadata block usage. Extents are organized in
// extent trees. The root node is contained in inode.BlocksRaw.
@@ -64,6 +68,8 @@ type ExtentNode struct {
// ExtentEntry represents an extent tree node entry. The entry can either be
// an ExtentIdx or Extent itself. This exists to simplify navigation logic.
type ExtentEntry interface {
+ marshal.Marshallable
+
// FileBlock returns the first file block number covered by this entry.
FileBlock() uint32
@@ -75,6 +81,8 @@ type ExtentEntry interface {
// tree node begins with this and is followed by `NumEntries` number of:
// - Extent if `Depth` == 0
// - ExtentIdx otherwise
+//
+// +marshal
type ExtentHeader struct {
// Magic in the extent magic number, must be 0xf30a.
Magic uint16
@@ -96,6 +104,8 @@ type ExtentHeader struct {
// internal nodes. Sorted in ascending order based on FirstFileBlock since
// Linux does a binary search on this. This points to a block containing the
// child node.
+//
+// +marshal
type ExtentIdx struct {
FirstFileBlock uint32
ChildBlockLo uint32
@@ -121,6 +131,8 @@ func (ei *ExtentIdx) PhysicalBlock() uint64 {
// nodes. Sorted in ascending order based on FirstFileBlock since Linux does a
// binary search on this. This points to an array of data blocks containing the
// file data. It covers `Length` data blocks starting from `StartBlock`.
+//
+// +marshal
type Extent struct {
FirstFileBlock uint32
Length uint16
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
index 8762b90db..c96002e19 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
@@ -21,7 +21,10 @@ import (
// TestExtentSize tests that the extent structs are of the correct
// size.
func TestExtentSize(t *testing.T) {
- assertSize(t, ExtentHeader{}, ExtentHeaderSize)
- assertSize(t, ExtentIdx{}, ExtentEntrySize)
- assertSize(t, Extent{}, ExtentEntrySize)
+ var h ExtentHeader
+ assertSize(t, &h, ExtentHeaderSize)
+ var i ExtentIdx
+ assertSize(t, &i, ExtentEntrySize)
+ var e Extent
+ assertSize(t, &e, ExtentEntrySize)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go
index 88ae913f5..ef25040a9 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode.go
@@ -16,6 +16,7 @@ package disklayout
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
)
@@ -38,6 +39,8 @@ const (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#index-nodes.
type Inode interface {
+ marshal.Marshallable
+
// Mode returns the linux file mode which is majorly used to extract
// information like:
// - File permissions (read/write/execute by user/group/others).
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
index 8f9f574ce..a4503f5cf 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
@@ -27,6 +27,8 @@ import "gvisor.dev/gvisor/pkg/sentry/kernel/time"
// are used to provide nanoscond precision. Hence, these timestamps will now
// overflow in May 2446.
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps.
+//
+// +marshal
type InodeNew struct {
InodeOld
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
index db25b11b6..e6b28babf 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
@@ -30,6 +30,8 @@ const (
//
// All fields representing time are in seconds since the epoch. Which means that
// they will overflow in January 2038.
+//
+// +marshal
type InodeOld struct {
ModeRaw uint16
UIDLo uint16
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
index dd03ee50e..90744e956 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
@@ -24,10 +24,12 @@ import (
// TestInodeSize tests that the inode structs are of the correct size.
func TestInodeSize(t *testing.T) {
- assertSize(t, InodeOld{}, OldInodeSize)
+ var iOld InodeOld
+ assertSize(t, &iOld, OldInodeSize)
// This was updated from 156 bytes to 160 bytes in Oct 2015.
- assertSize(t, InodeNew{}, 160)
+ var iNew InodeNew
+ assertSize(t, &iNew, 160)
}
// TestTimestampSeconds tests that the seconds part of [a/c/m] timestamps in
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
index 8bb327006..70948ebe9 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
const (
// SbOffset is the absolute offset at which the superblock is placed.
SbOffset = 1024
@@ -38,6 +42,8 @@ const (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#super-block.
type SuperBlock interface {
+ marshal.Marshallable
+
// InodesCount returns the total number of inodes in this filesystem.
InodesCount() uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
index 53e515fd3..4dc6080fb 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
@@ -17,6 +17,8 @@ package disklayout
// SuperBlock32Bit implements SuperBlock and represents the 32-bit version of
// the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if
// RevLevel = DynamicRev and 64-bit feature is disabled.
+//
+// +marshal
type SuperBlock32Bit struct {
// We embed the old superblock struct here because the 32-bit version is just
// an extension of the old version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
index 7c1053fb4..2c9039327 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
@@ -19,6 +19,8 @@ package disklayout
// 1024 bytes (smallest possible block size) and hence the superblock always
// fits in no more than one data block. Should only be used when the 64-bit
// feature is set.
+//
+// +marshal
type SuperBlock64Bit struct {
// We embed the 32-bit struct here because 64-bit version is just an extension
// of the 32-bit version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
index 9221e0251..e4709f23c 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
@@ -16,6 +16,8 @@ package disklayout
// SuperBlockOld implements SuperBlock and represents the old version of the
// superblock struct. Should be used only if RevLevel = OldRev.
+//
+// +marshal
type SuperBlockOld struct {
InodesCountRaw uint32
BlocksCountLo uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
index 463b5ba21..b734b6987 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
@@ -21,7 +21,10 @@ import (
// TestSuperBlockSize tests that the superblock structs are of the correct
// size.
func TestSuperBlockSize(t *testing.T) {
- assertSize(t, SuperBlockOld{}, 84)
- assertSize(t, SuperBlock32Bit{}, 336)
- assertSize(t, SuperBlock64Bit{}, 1024)
+ var sbOld SuperBlockOld
+ assertSize(t, &sbOld, 84)
+ var sb32 SuperBlock32Bit
+ assertSize(t, &sb32, 336)
+ var sb64 SuperBlock64Bit
+ assertSize(t, &sb64, 1024)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
index 9c63f04c0..a4bc08411 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
@@ -18,13 +18,13 @@ import (
"reflect"
"testing"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
)
-func assertSize(t *testing.T, v interface{}, want uintptr) {
+func assertSize(t *testing.T, v marshal.Marshallable, want int) {
t.Helper()
- if got := binary.Size(v); got != want {
+ if got := v.SizeBytes(); got != want {
t.Errorf("struct %s should be exactly %d bytes but is %d bytes", reflect.TypeOf(v).Name(), want, got)
}
}
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
index 08ffc2834..aca258d40 100644
--- a/pkg/sentry/fsimpl/ext/ext.go
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -34,6 +34,8 @@ import (
const Name = "ext"
// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
type FilesystemType struct{}
// Compiles only if FilesystemType implements vfs.FilesystemType.
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 2dbaee287..0989558cd 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -71,7 +71,11 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: int(f.Fd()),
+ },
+ })
if err != nil {
f.Close()
return nil, nil, nil, nil, err
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index c36225a7c..778460107 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -18,12 +18,13 @@ import (
"io"
"sort"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/syserror"
)
// extentFile is a type of regular file which uses extents to store file data.
+//
+// +stateify savable
type extentFile struct {
regFile regularFile
@@ -58,7 +59,7 @@ func newExtentFile(args inodeArgs) (*extentFile, error) {
func (f *extentFile) buildExtTree() error {
rootNodeData := f.regFile.inode.diskInode.Data()
- binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header)
+ f.root.Header.UnmarshalBytes(rootNodeData[:disklayout.ExtentHeaderSize])
// Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries.
if f.root.Header.NumEntries > 4 {
@@ -77,7 +78,7 @@ func (f *extentFile) buildExtTree() error {
// Internal node.
curEntry = &disklayout.ExtentIdx{}
}
- binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry)
+ curEntry.UnmarshalBytes(rootNodeData[off : off+disklayout.ExtentEntrySize])
f.root.Entries[i].Entry = curEntry
}
diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go
index cd10d46ee..985f76ac0 100644
--- a/pkg/sentry/fsimpl/ext/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/extent_test.go
@@ -21,7 +21,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
)
@@ -202,13 +201,14 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []
// writeTree writes the tree represented by `root` to the inode and disk. It
// also writes random file data on disk.
func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte {
- rootData := binary.Marshal(nil, binary.LittleEndian, root.Header)
+ rootData := in.diskInode.Data()
+ root.Header.MarshalBytes(rootData)
+ off := root.Header.SizeBytes()
for _, ep := range root.Entries {
- rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry)
+ ep.Entry.MarshalBytes(rootData[off:])
+ off += ep.Entry.SizeBytes()
}
- copy(in.diskInode.Data(), rootData)
-
var fileData []byte
for _, ep := range root.Entries {
if root.Header.Height == 0 {
@@ -223,13 +223,14 @@ func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBl
// writeTreeToDisk is the recursive step for writeTree which writes the tree
// on the disk only. Also writes random file data on disk.
func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte {
- nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header)
+ nodeData := disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:]
+ curNode.Node.Header.MarshalBytes(nodeData)
+ off := curNode.Node.Header.SizeBytes()
for _, ep := range curNode.Node.Entries {
- nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry)
+ ep.Entry.MarshalBytes(nodeData[off:])
+ off += ep.Entry.SizeBytes()
}
- copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData)
-
var fileData []byte
for _, ep := range curNode.Node.Entries {
if curNode.Node.Header.Height == 0 {
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 8565d1a66..917f1873d 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -38,11 +38,13 @@ var (
)
// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
type filesystem struct {
vfsfs vfs.Filesystem
// mu serializes changes to the Dentry tree.
- mu sync.RWMutex
+ mu sync.RWMutex `state:"nosave"`
// dev represents the underlying fs device. It does not require protection
// because io.ReaderAt permits concurrent read calls to it. It translates to
@@ -490,7 +492,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return syserror.EROFS
}
-// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
_, inode, err := fs.walk(ctx, rp, false)
if err != nil {
@@ -504,8 +506,8 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
return nil, syserror.ECONNREFUSED
}
-// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt.
+func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
_, _, err := fs.walk(ctx, rp, false)
if err != nil {
return nil, err
@@ -513,8 +515,8 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
return nil, syserror.ENOTSUP
}
-// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
+func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) {
_, _, err := fs.walk(ctx, rp, false)
if err != nil {
return "", err
@@ -522,8 +524,8 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
return "", syserror.ENOTSUP
}
-// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
-func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt.
+func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error {
_, _, err := fs.walk(ctx, rp, false)
if err != nil {
return err
@@ -531,8 +533,8 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
return syserror.ENOTSUP
}
-// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
-func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt.
+func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
_, _, err := fs.walk(ctx, rp, false)
if err != nil {
return err
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 30636cf66..9009ba3c7 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -37,6 +37,8 @@ import (
// |-- regular--
// |-- extent file
// |-- block map file
+//
+// +stateify savable
type inode struct {
// refs is a reference count. refs is accessed using atomic memory operations.
refs int64
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index e73e740d6..4a5539b37 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -31,6 +31,8 @@ import (
// regularFile represents a regular file's inode. This too follows the
// inheritance pattern prevelant in the vfs layer described in
// pkg/sentry/vfs/README.md.
+//
+// +stateify savable
type regularFile struct {
inode inode
@@ -67,6 +69,8 @@ func (in *inode) isRegular() bool {
// directoryFD represents a directory file description. It implements
// vfs.FileDescriptionImpl.
+//
+// +stateify savable
type regularFileFD struct {
fileDescription
vfs.LockFD
@@ -75,7 +79,7 @@ type regularFileFD struct {
off int64
// offMu serializes operations that may mutate off.
- offMu sync.Mutex
+ offMu sync.Mutex `state:"nosave"`
}
// Release implements vfs.FileDescriptionImpl.Release.
diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go
index 2fd0d1fa8..5e2bcc837 100644
--- a/pkg/sentry/fsimpl/ext/symlink.go
+++ b/pkg/sentry/fsimpl/ext/symlink.go
@@ -23,6 +23,8 @@ import (
)
// symlink represents a symlink inode.
+//
+// +stateify savable
type symlink struct {
inode inode
target string // immutable
@@ -61,9 +63,11 @@ func (in *inode) isSymlink() bool {
return ok
}
-// symlinkFD represents a symlink file description and implements implements
+// symlinkFD represents a symlink file description and implements
// vfs.FileDescriptionImpl. which may only be used if open options contains
// O_PATH. For this reason most of the functions return EBADF.
+//
+// +stateify savable
type symlinkFD struct {
fileDescription
vfs.NoLockFD
diff --git a/pkg/sentry/fsimpl/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go
index d8b728f8c..58ef7b9b8 100644
--- a/pkg/sentry/fsimpl/ext/utils.go
+++ b/pkg/sentry/fsimpl/ext/utils.go
@@ -17,21 +17,21 @@ package ext
import (
"io"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/syserror"
)
// readFromDisk performs a binary read from disk into the given struct from
// the absolute offset provided.
-func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error {
- n := binary.Size(v)
+func readFromDisk(dev io.ReaderAt, abOff int64, v marshal.Marshallable) error {
+ n := v.SizeBytes()
buf := make([]byte, n)
if read, _ := dev.ReadAt(buf, abOff); read < int(n) {
return syserror.EIO
}
- binary.Unmarshal(buf, binary.LittleEndian, v)
+ v.UnmarshalBytes(buf)
return nil
}
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 53a4f3012..045d7ab08 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -30,19 +30,26 @@ go_library(
name = "fuse",
srcs = [
"connection.go",
+ "connection_control.go",
"dev.go",
+ "directory.go",
+ "file.go",
"fusefs.go",
- "init.go",
"inode_refs.go",
+ "read_write.go",
"register.go",
+ "regular_file.go",
"request_list.go",
+ "request_response.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
"//pkg/log",
+ "//pkg/marshal",
"//pkg/refs",
+ "//pkg/safemem",
"//pkg/sentry/fsimpl/devtmpfs",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/kernel",
@@ -52,7 +59,6 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
"@org_golang_x_sys//unix:go_default_library",
],
)
@@ -60,10 +66,15 @@ go_library(
go_test(
name = "fuse_test",
size = "small",
- srcs = ["dev_test.go"],
+ srcs = [
+ "connection_test.go",
+ "dev_test.go",
+ "utils_test.go",
+ ],
library = ":fuse",
deps = [
"//pkg/abi/linux",
+ "//pkg/marshal",
"//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
@@ -71,6 +82,5 @@ go_test(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go
index 6df2728ab..8ccda1264 100644
--- a/pkg/sentry/fsimpl/fuse/connection.go
+++ b/pkg/sentry/fsimpl/fuse/connection.go
@@ -15,31 +15,17 @@
package fuse
import (
- "errors"
- "fmt"
"sync"
- "sync/atomic"
- "syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
-// maxActiveRequestsDefault is the default setting controlling the upper bound
-// on the number of active requests at any given time.
-const maxActiveRequestsDefault = 10000
-
-// Ordinary requests have even IDs, while interrupts IDs are odd.
-// Used to increment the unique ID for each FUSE request.
-var reqIDStep uint64 = 2
-
const (
// fuseDefaultMaxBackground is the default value for MaxBackground.
fuseDefaultMaxBackground = 12
@@ -52,43 +38,39 @@ const (
fuseDefaultMaxPagesPerReq = 32
)
-// Request represents a FUSE operation request that hasn't been sent to the
-// server yet.
+// connection is the struct by which the sentry communicates with the FUSE server daemon.
//
-// +stateify savable
-type Request struct {
- requestEntry
-
- id linux.FUSEOpID
- hdr *linux.FUSEHeaderIn
- data []byte
-}
-
-// Response represents an actual response from the server, including the
-// response payload.
+// Lock order:
+// - conn.fd.mu
+// - conn.mu
+// - conn.asyncMu
//
// +stateify savable
-type Response struct {
- opcode linux.FUSEOpcode
- hdr linux.FUSEHeaderOut
- data []byte
-}
-
-// connection is the struct by which the sentry communicates with the FUSE server daemon.
type connection struct {
fd *DeviceFD
+ // mu protects access to struct memebers.
+ mu sync.Mutex `state:"nosave"`
+
+ // attributeVersion is the version of connection's attributes.
+ attributeVersion uint64
+
+ // We target FUSE 7.23.
// The following FUSE_INIT flags are currently unsupported by this implementation:
- // - FUSE_ATOMIC_O_TRUNC: requires open(..., O_TRUNC)
// - FUSE_EXPORT_SUPPORT
- // - FUSE_HANDLE_KILLPRIV
// - FUSE_POSIX_LOCKS: requires POSIX locks
// - FUSE_FLOCK_LOCKS: requires POSIX locks
// - FUSE_AUTO_INVAL_DATA: requires page caching eviction
- // - FUSE_EXPLICIT_INVAL_DATA: requires page caching eviction
// - FUSE_DO_READDIRPLUS/FUSE_READDIRPLUS_AUTO: requires FUSE_READDIRPLUS implementation
// - FUSE_ASYNC_DIO
- // - FUSE_POSIX_ACL: affects defaultPermissions, posixACL, xattr handler
+ // - FUSE_PARALLEL_DIROPS (7.25)
+ // - FUSE_HANDLE_KILLPRIV (7.26)
+ // - FUSE_POSIX_ACL: affects defaultPermissions, posixACL, xattr handler (7.26)
+ // - FUSE_ABORT_ERROR (7.27)
+ // - FUSE_CACHE_SYMLINKS (7.28)
+ // - FUSE_NO_OPENDIR_SUPPORT (7.29)
+ // - FUSE_EXPLICIT_INVAL_DATA: requires page caching eviction (7.30)
+ // - FUSE_MAP_ALIGNMENT (7.31)
// initialized after receiving FUSE_INIT reply.
// Until it's set, suspend sending FUSE requests.
@@ -96,11 +78,7 @@ type connection struct {
initialized int32
// initializedChan is used to block requests before initialization.
- initializedChan chan struct{}
-
- // blocked when there are too many outstading backgrounds requests (NumBackground == MaxBackground).
- // TODO(gvisor.dev/issue/3185): update the numBackground accordingly; use a channel to block.
- blocked bool
+ initializedChan chan struct{} `state:".(bool)"`
// connected (connection established) when a new FUSE file system is created.
// Set to false when:
@@ -109,48 +87,55 @@ type connection struct {
// device release.
connected bool
- // aborted via sysfs.
- // TODO(gvisor.dev/issue/3185): abort all queued requests.
- aborted bool
-
// connInitError if FUSE_INIT encountered error (major version mismatch).
// Only set in INIT.
connInitError bool
// connInitSuccess if FUSE_INIT is successful.
// Only set in INIT.
- // Used for destory.
+ // Used for destory (not yet implemented).
connInitSuccess bool
- // TODO(gvisor.dev/issue/3185): All the queue logic are working in progress.
-
- // NumberBackground is the number of requests in the background.
- numBackground uint16
+ // aborted via sysfs, and will send ECONNABORTED to read after disconnection (instead of ENODEV).
+ // Set only if abortErr is true and via fuse control fs (not yet implemented).
+ // TODO(gvisor.dev/issue/3525): set this to true when user aborts.
+ aborted bool
- // congestionThreshold for NumBackground.
- // Negotiated in FUSE_INIT.
- congestionThreshold uint16
+ // numWating is the number of requests waiting to be
+ // sent to FUSE device or being processed by FUSE daemon.
+ numWaiting uint32
- // maxBackground is the maximum number of NumBackground.
- // Block connection when it is reached.
- // Negotiated in FUSE_INIT.
- maxBackground uint16
+ // Terminology note:
+ //
+ // - `asyncNumMax` is the `MaxBackground` in the FUSE_INIT_IN struct.
+ //
+ // - `asyncCongestionThreshold` is the `CongestionThreshold` in the FUSE_INIT_IN struct.
+ //
+ // We call the "background" requests in unix term as async requests.
+ // The "async requests" in unix term is our async requests that expect a reply,
+ // i.e. `!request.noReply`
- // numActiveBackground is the number of requests in background and has being marked as active.
- numActiveBackground uint16
+ // asyncMu protects the async request fields.
+ asyncMu sync.Mutex `state:"nosave"`
- // numWating is the number of requests waiting for completion.
- numWaiting uint32
+ // asyncNum is the number of async requests.
+ // Protected by asyncMu.
+ asyncNum uint16
- // TODO(gvisor.dev/issue/3185): BgQueue
- // some queue for background queued requests.
+ // asyncCongestionThreshold the number of async requests.
+ // Negotiated in FUSE_INIT as "CongestionThreshold".
+ // TODO(gvisor.dev/issue/3529): add congestion control.
+ // Protected by asyncMu.
+ asyncCongestionThreshold uint16
- // bgLock protects:
- // MaxBackground, CongestionThreshold, NumBackground,
- // NumActiveBackground, BgQueue, Blocked.
- bgLock sync.Mutex
+ // asyncNumMax is the maximum number of asyncNum.
+ // Connection blocks the async requests when it is reached.
+ // Negotiated in FUSE_INIT as "MaxBackground".
+ // Protected by asyncMu.
+ asyncNumMax uint16
// maxRead is the maximum size of a read buffer in in bytes.
+ // Initialized from a fuse fs parameter.
maxRead uint32
// maxWrite is the maximum size of a write buffer in bytes.
@@ -165,23 +150,20 @@ type connection struct {
// Negotiated and only set in INIT.
minor uint32
- // asyncRead if read pages asynchronously.
+ // atomicOTrunc is true when FUSE does not send a separate SETATTR request
+ // before open with O_TRUNC flag.
// Negotiated and only set in INIT.
- asyncRead bool
+ atomicOTrunc bool
- // abortErr is true if kernel need to return an unique read error after abort.
+ // asyncRead if read pages asynchronously.
// Negotiated and only set in INIT.
- abortErr bool
+ asyncRead bool
// writebackCache is true for write-back cache policy,
// false for write-through policy.
// Negotiated and only set in INIT.
writebackCache bool
- // cacheSymlinks if filesystem needs to cache READLINK responses in page cache.
- // Negotiated and only set in INIT.
- cacheSymlinks bool
-
// bigWrites if doing multi-page cached writes.
// Negotiated and only set in INIT.
bigWrites bool
@@ -189,116 +171,86 @@ type connection struct {
// dontMask if filestestem does not apply umask to creation modes.
// Negotiated in INIT.
dontMask bool
+
+ // noOpen if FUSE server doesn't support open operation.
+ // This flag only influence performance, not correctness of the program.
+ noOpen bool
+}
+
+func (conn *connection) saveInitializedChan() bool {
+ select {
+ case <-conn.initializedChan:
+ return true // Closed.
+ default:
+ return false // Not closed.
+ }
+}
+
+func (conn *connection) loadInitializedChan(closed bool) {
+ conn.initializedChan = make(chan struct{}, 1)
+ if closed {
+ close(conn.initializedChan)
+ }
}
// newFUSEConnection creates a FUSE connection to fd.
-func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, maxInFlightRequests uint64) (*connection, error) {
+func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, opts *filesystemOptions) (*connection, error) {
// Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to
// mount a FUSE filesystem.
fuseFD := fd.Impl().(*DeviceFD)
- fuseFD.mounted = true
// Create the writeBuf for the header to be stored in.
hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
fuseFD.writeBuf = make([]byte, hdrLen)
fuseFD.completions = make(map[linux.FUSEOpID]*futureResponse)
- fuseFD.fullQueueCh = make(chan struct{}, maxInFlightRequests)
+ fuseFD.fullQueueCh = make(chan struct{}, opts.maxActiveRequests)
fuseFD.writeCursor = 0
return &connection{
- fd: fuseFD,
- maxBackground: fuseDefaultMaxBackground,
- congestionThreshold: fuseDefaultCongestionThreshold,
- maxPages: fuseDefaultMaxPagesPerReq,
- initializedChan: make(chan struct{}),
- connected: true,
- }, nil
-}
-
-// SetInitialized atomically sets the connection as initialized.
-func (conn *connection) SetInitialized() {
- // Unblock the requests sent before INIT.
- close(conn.initializedChan)
-
- // Close the channel first to avoid the non-atomic situation
- // where conn.initialized is true but there are
- // tasks being blocked on the channel.
- // And it prevents the newer tasks from gaining
- // unnecessary higher chance to be issued before the blocked one.
-
- atomic.StoreInt32(&(conn.initialized), int32(1))
-}
-
-// IsInitialized atomically check if the connection is initialized.
-// pairs with SetInitialized().
-func (conn *connection) Initialized() bool {
- return atomic.LoadInt32(&(conn.initialized)) != 0
-}
-
-// NewRequest creates a new request that can be sent to the FUSE server.
-func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) {
- conn.fd.mu.Lock()
- defer conn.fd.mu.Unlock()
- conn.fd.nextOpID += linux.FUSEOpID(reqIDStep)
-
- hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes()
- hdr := linux.FUSEHeaderIn{
- Len: uint32(hdrLen + payload.SizeBytes()),
- Opcode: opcode,
- Unique: conn.fd.nextOpID,
- NodeID: ino,
- UID: uint32(creds.EffectiveKUID),
- GID: uint32(creds.EffectiveKGID),
- PID: pid,
- }
-
- buf := make([]byte, hdr.Len)
- hdr.MarshalUnsafe(buf[:hdrLen])
- payload.MarshalUnsafe(buf[hdrLen:])
-
- return &Request{
- id: hdr.Unique,
- hdr: &hdr,
- data: buf,
+ fd: fuseFD,
+ asyncNumMax: fuseDefaultMaxBackground,
+ asyncCongestionThreshold: fuseDefaultCongestionThreshold,
+ maxRead: opts.maxRead,
+ maxPages: fuseDefaultMaxPagesPerReq,
+ initializedChan: make(chan struct{}),
+ connected: true,
}, nil
}
-// Call makes a request to the server and blocks the invoking task until a
-// server responds with a response. Task should never be nil.
-// Requests will not be sent before the connection is initialized.
-// For async tasks, use CallAsync().
-func (conn *connection) Call(t *kernel.Task, r *Request) (*Response, error) {
- // Block requests sent before connection is initalized.
- if !conn.Initialized() {
- if err := t.Block(conn.initializedChan); err != nil {
- return nil, err
- }
- }
-
- return conn.call(t, r)
+// CallAsync makes an async (aka background) request.
+// It's a simple wrapper around Call().
+func (conn *connection) CallAsync(t *kernel.Task, r *Request) error {
+ r.async = true
+ _, err := conn.Call(t, r)
+ return err
}
-// CallAsync makes an async (aka background) request.
-// Those requests either do not expect a response (e.g. release) or
-// the response should be handled by others (e.g. init).
-// Return immediately unless the connection is blocked (before initialization).
-// Async call example: init, release, forget, aio, interrupt.
+// Call makes a request to the server.
+// Block before the connection is initialized.
// When the Request is FUSE_INIT, it will not be blocked before initialization.
-func (conn *connection) CallAsync(t *kernel.Task, r *Request) error {
+// Task should never be nil.
+//
+// For a sync request, it blocks the invoking task until
+// a server responds with a response.
+//
+// For an async request (that do not expect a response immediately),
+// it returns directly unless being blocked either before initialization
+// or when there are too many async requests ongoing.
+//
+// Example for async request:
+// init, readahead, write, async read/write, fuse_notify_reply,
+// non-sync release, interrupt, forget.
+//
+// The forget request does not have a reply,
+// as documented in include/uapi/linux/fuse.h:FUSE_FORGET.
+func (conn *connection) Call(t *kernel.Task, r *Request) (*Response, error) {
// Block requests sent before connection is initalized.
if !conn.Initialized() && r.hdr.Opcode != linux.FUSE_INIT {
if err := t.Block(conn.initializedChan); err != nil {
- return err
+ return nil, err
}
}
- // This should be the only place that invokes call() with a nil task.
- _, err := conn.call(nil, r)
- return err
-}
-
-// call makes a call without blocking checks.
-func (conn *connection) call(t *kernel.Task, r *Request) (*Response, error) {
if !conn.connected {
return nil, syserror.ENOTCONN
}
@@ -315,31 +267,6 @@ func (conn *connection) call(t *kernel.Task, r *Request) (*Response, error) {
return fut.resolve(t)
}
-// Error returns the error of the FUSE call.
-func (r *Response) Error() error {
- errno := r.hdr.Error
- if errno >= 0 {
- return nil
- }
-
- sysErrNo := syscall.Errno(-errno)
- return error(sysErrNo)
-}
-
-// UnmarshalPayload unmarshals the response data into m.
-func (r *Response) UnmarshalPayload(m marshal.Marshallable) error {
- hdrLen := r.hdr.SizeBytes()
- haveDataLen := r.hdr.Len - uint32(hdrLen)
- wantDataLen := uint32(m.SizeBytes())
-
- if haveDataLen < wantDataLen {
- return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen)
- }
-
- m.UnmarshalUnsafe(r.data[hdrLen:])
- return nil
-}
-
// callFuture makes a request to the server and returns a future response.
// Call resolve() when the response needs to be fulfilled.
func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, error) {
@@ -358,11 +285,6 @@ func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse,
// if there are always too many ongoing requests all the time. The
// supported maxActiveRequests setting should be really high to avoid this.
for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
- if t == nil {
- // Since there is no task that is waiting. We must error out.
- return nil, errors.New("FUSE request queue full")
- }
-
log.Infof("Blocking request %v from being queued. Too many active requests: %v",
r.id, conn.fd.numActiveRequests)
conn.fd.mu.Unlock()
@@ -378,9 +300,19 @@ func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse,
// callFutureLocked makes a request to the server and returns a future response.
func (conn *connection) callFutureLocked(t *kernel.Task, r *Request) (*futureResponse, error) {
+ // Check connected again holding conn.mu.
+ conn.mu.Lock()
+ if !conn.connected {
+ conn.mu.Unlock()
+ // we checked connected before,
+ // this must be due to aborted connection.
+ return nil, syserror.ECONNABORTED
+ }
+ conn.mu.Unlock()
+
conn.fd.queue.PushBack(r)
- conn.fd.numActiveRequests += 1
- fut := newFutureResponse(r.hdr.Opcode)
+ conn.fd.numActiveRequests++
+ fut := newFutureResponse(r)
conn.fd.completions[r.id] = fut
// Signal the readers that there is something to read.
@@ -388,50 +320,3 @@ func (conn *connection) callFutureLocked(t *kernel.Task, r *Request) (*futureRes
return fut, nil
}
-
-// futureResponse represents an in-flight request, that may or may not have
-// completed yet. Convert it to a resolved Response by calling Resolve, but note
-// that this may block.
-//
-// +stateify savable
-type futureResponse struct {
- opcode linux.FUSEOpcode
- ch chan struct{}
- hdr *linux.FUSEHeaderOut
- data []byte
-}
-
-// newFutureResponse creates a future response to a FUSE request.
-func newFutureResponse(opcode linux.FUSEOpcode) *futureResponse {
- return &futureResponse{
- opcode: opcode,
- ch: make(chan struct{}),
- }
-}
-
-// resolve blocks the task until the server responds to its corresponding request,
-// then returns a resolved response.
-func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) {
- // If there is no Task associated with this request - then we don't try to resolve
- // the response. Instead, the task writing the response (proxy to the server) will
- // process the response on our behalf.
- if t == nil {
- log.Infof("fuse.Response.resolve: Not waiting on a response from server.")
- return nil, nil
- }
-
- if err := t.Block(f.ch); err != nil {
- return nil, err
- }
-
- return f.getResponse(), nil
-}
-
-// getResponse creates a Response from the data the futureResponse has.
-func (f *futureResponse) getResponse() *Response {
- return &Response{
- opcode: f.opcode,
- hdr: *f.hdr,
- data: f.data,
- }
-}
diff --git a/pkg/sentry/fsimpl/fuse/init.go b/pkg/sentry/fsimpl/fuse/connection_control.go
index 779c2bd3f..bfde78559 100644
--- a/pkg/sentry/fsimpl/fuse/init.go
+++ b/pkg/sentry/fsimpl/fuse/connection_control.go
@@ -15,7 +15,11 @@
package fuse
import (
+ "sync/atomic"
+ "syscall"
+
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -29,9 +33,10 @@ const (
// Follow the same behavior as unix fuse implementation.
fuseMaxTimeGranNs = 1000000000
- // Minimum value for MaxWrite.
+ // Minimum value for MaxWrite and MaxRead.
// Follow the same behavior as unix fuse implementation.
fuseMinMaxWrite = 4096
+ fuseMinMaxRead = 4096
// Temporary default value for max readahead, 128kb.
fuseDefaultMaxReadahead = 131072
@@ -49,6 +54,26 @@ var (
MaxUserCongestionThreshold uint16 = fuseDefaultCongestionThreshold
)
+// SetInitialized atomically sets the connection as initialized.
+func (conn *connection) SetInitialized() {
+ // Unblock the requests sent before INIT.
+ close(conn.initializedChan)
+
+ // Close the channel first to avoid the non-atomic situation
+ // where conn.initialized is true but there are
+ // tasks being blocked on the channel.
+ // And it prevents the newer tasks from gaining
+ // unnecessary higher chance to be issued before the blocked one.
+
+ atomic.StoreInt32(&(conn.initialized), int32(1))
+}
+
+// IsInitialized atomically check if the connection is initialized.
+// pairs with SetInitialized().
+func (conn *connection) Initialized() bool {
+ return atomic.LoadInt32(&(conn.initialized)) != 0
+}
+
// InitSend sends a FUSE_INIT request.
func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error {
in := linux.FUSEInitIn{
@@ -70,29 +95,31 @@ func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error {
}
// InitRecv receives a FUSE_INIT reply and process it.
+//
+// Preconditions: conn.asyncMu must not be held if minor verion is newer than 13.
func (conn *connection) InitRecv(res *Response, hasSysAdminCap bool) error {
if err := res.Error(); err != nil {
return err
}
- var out linux.FUSEInitOut
- if err := res.UnmarshalPayload(&out); err != nil {
+ initRes := fuseInitRes{initLen: res.DataLen()}
+ if err := res.UnmarshalPayload(&initRes); err != nil {
return err
}
- return conn.initProcessReply(&out, hasSysAdminCap)
+ return conn.initProcessReply(&initRes.initOut, hasSysAdminCap)
}
// Process the FUSE_INIT reply from the FUSE server.
+// It tries to acquire the conn.asyncMu lock if minor version is newer than 13.
func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap bool) error {
+ // No matter error or not, always set initialzied.
+ // to unblock the blocked requests.
+ defer conn.SetInitialized()
+
// No support for old major fuse versions.
if out.Major != linux.FUSE_KERNEL_VERSION {
conn.connInitError = true
-
- // Set the connection as initialized and unblock the blocked requests
- // (i.e. return error for them).
- conn.SetInitialized()
-
return nil
}
@@ -100,29 +127,14 @@ func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap
conn.connInitSuccess = true
conn.minor = out.Minor
- // No support for limits before minor version 13.
- if out.Minor >= 13 {
- conn.bgLock.Lock()
-
- if out.MaxBackground > 0 {
- conn.maxBackground = out.MaxBackground
-
- if !hasSysAdminCap &&
- conn.maxBackground > MaxUserBackgroundRequest {
- conn.maxBackground = MaxUserBackgroundRequest
- }
- }
-
- if out.CongestionThreshold > 0 {
- conn.congestionThreshold = out.CongestionThreshold
-
- if !hasSysAdminCap &&
- conn.congestionThreshold > MaxUserCongestionThreshold {
- conn.congestionThreshold = MaxUserCongestionThreshold
- }
- }
-
- conn.bgLock.Unlock()
+ // No support for negotiating MaxWrite before minor version 5.
+ if out.Minor >= 5 {
+ conn.maxWrite = out.MaxWrite
+ } else {
+ conn.maxWrite = fuseMinMaxWrite
+ }
+ if conn.maxWrite < fuseMinMaxWrite {
+ conn.maxWrite = fuseMinMaxWrite
}
// No support for the following flags before minor version 6.
@@ -131,8 +143,6 @@ func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap
conn.bigWrites = out.Flags&linux.FUSE_BIG_WRITES != 0
conn.dontMask = out.Flags&linux.FUSE_DONT_MASK != 0
conn.writebackCache = out.Flags&linux.FUSE_WRITEBACK_CACHE != 0
- conn.cacheSymlinks = out.Flags&linux.FUSE_CACHE_SYMLINKS != 0
- conn.abortErr = out.Flags&linux.FUSE_ABORT_ERROR != 0
// TODO(gvisor.dev/issue/3195): figure out how to use TimeGran (0 < TimeGran <= fuseMaxTimeGranNs).
@@ -148,19 +158,90 @@ func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap
}
}
- // No support for negotiating MaxWrite before minor version 5.
- if out.Minor >= 5 {
- conn.maxWrite = out.MaxWrite
- } else {
- conn.maxWrite = fuseMinMaxWrite
+ // No support for limits before minor version 13.
+ if out.Minor >= 13 {
+ conn.asyncMu.Lock()
+
+ if out.MaxBackground > 0 {
+ conn.asyncNumMax = out.MaxBackground
+
+ if !hasSysAdminCap &&
+ conn.asyncNumMax > MaxUserBackgroundRequest {
+ conn.asyncNumMax = MaxUserBackgroundRequest
+ }
+ }
+
+ if out.CongestionThreshold > 0 {
+ conn.asyncCongestionThreshold = out.CongestionThreshold
+
+ if !hasSysAdminCap &&
+ conn.asyncCongestionThreshold > MaxUserCongestionThreshold {
+ conn.asyncCongestionThreshold = MaxUserCongestionThreshold
+ }
+ }
+
+ conn.asyncMu.Unlock()
}
- if conn.maxWrite < fuseMinMaxWrite {
- conn.maxWrite = fuseMinMaxWrite
+
+ return nil
+}
+
+// Abort this FUSE connection.
+// It tries to acquire conn.fd.mu, conn.lock, conn.bgLock in order.
+// All possible requests waiting or blocking will be aborted.
+//
+// Preconditions: conn.fd.mu is locked.
+func (conn *connection) Abort(ctx context.Context) {
+ conn.mu.Lock()
+ conn.asyncMu.Lock()
+
+ if !conn.connected {
+ conn.asyncMu.Unlock()
+ conn.mu.Unlock()
+ conn.fd.mu.Unlock()
+ return
}
- // Set connection as initialized and unblock the requests
- // issued before init.
- conn.SetInitialized()
+ conn.connected = false
- return nil
+ // Empty the `fd.queue` that holds the requests
+ // not yet read by the FUSE daemon yet.
+ // These are a subset of the requests in `fuse.completion` map.
+ for !conn.fd.queue.Empty() {
+ req := conn.fd.queue.Front()
+ conn.fd.queue.Remove(req)
+ }
+
+ var terminate []linux.FUSEOpID
+
+ // 2. Collect the requests have not been sent to FUSE daemon,
+ // or have not received a reply.
+ for unique := range conn.fd.completions {
+ terminate = append(terminate, unique)
+ }
+
+ // Release locks to avoid deadlock.
+ conn.asyncMu.Unlock()
+ conn.mu.Unlock()
+
+ // 1. The requets blocked before initialization.
+ // Will reach call() `connected` check and return.
+ if !conn.Initialized() {
+ conn.SetInitialized()
+ }
+
+ // 2. Terminate the requests collected above.
+ // Set ECONNABORTED error.
+ // sendError() will remove them from `fd.completion` map.
+ // Will enter the path of a normally received error.
+ for _, toTerminate := range terminate {
+ conn.fd.sendError(ctx, -int32(syscall.ECONNABORTED), toTerminate)
+ }
+
+ // 3. The requests not yet written to FUSE device.
+ // Early terminate.
+ // Will reach callFutureLocked() `connected` check and return.
+ close(conn.fd.fullQueueCh)
+
+ // TODO(gvisor.dev/issue/3528): Forget all pending forget reqs.
}
diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go
new file mode 100644
index 000000000..91d16c1cf
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/connection_test.go
@@ -0,0 +1,117 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "math/rand"
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// TestConnectionInitBlock tests if initialization
+// correctly blocks and unblocks the connection.
+// Since it's unfeasible to test kernelTask.Block() in unit test,
+// the code in Call() are not tested here.
+func TestConnectionInitBlock(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+
+ conn, _, err := newTestConnection(s, k, maxActiveRequestsDefault)
+ if err != nil {
+ t.Fatalf("newTestConnection: %v", err)
+ }
+
+ select {
+ case <-conn.initializedChan:
+ t.Fatalf("initializedChan should be blocking before SetInitialized")
+ default:
+ }
+
+ conn.SetInitialized()
+
+ select {
+ case <-conn.initializedChan:
+ default:
+ t.Fatalf("initializedChan should not be blocking after SetInitialized")
+ }
+}
+
+func TestConnectionAbort(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ creds := auth.CredentialsFromContext(s.Ctx)
+ task := kernel.TaskFromContext(s.Ctx)
+
+ const numRequests uint64 = 256
+
+ conn, _, err := newTestConnection(s, k, numRequests)
+ if err != nil {
+ t.Fatalf("newTestConnection: %v", err)
+ }
+
+ testObj := &testPayload{
+ data: rand.Uint32(),
+ }
+
+ var futNormal []*futureResponse
+
+ for i := 0; i < int(numRequests); i++ {
+ req, err := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj)
+ if err != nil {
+ t.Fatalf("NewRequest creation failed: %v", err)
+ }
+ fut, err := conn.callFutureLocked(task, req)
+ if err != nil {
+ t.Fatalf("callFutureLocked failed: %v", err)
+ }
+ futNormal = append(futNormal, fut)
+ }
+
+ conn.Abort(s.Ctx)
+
+ // Abort should unblock the initialization channel.
+ // Note: no test requests are actually blocked on `conn.initializedChan`.
+ select {
+ case <-conn.initializedChan:
+ default:
+ t.Fatalf("initializedChan should not be blocking after SetInitialized")
+ }
+
+ // Abort will return ECONNABORTED error to unblocked requests.
+ for _, fut := range futNormal {
+ if fut.getResponse().hdr.Error != -int32(syscall.ECONNABORTED) {
+ t.Fatalf("Incorrect error code received for aborted connection: %v", fut.getResponse().hdr.Error)
+ }
+ }
+
+ // After abort, Call() should return directly with ENOTCONN.
+ req, err := conn.NewRequest(creds, 0, 0, 0, testObj)
+ if err != nil {
+ t.Fatalf("NewRequest creation failed: %v", err)
+ }
+ _, err = conn.Call(task, req)
+ if err != syserror.ENOTCONN {
+ t.Fatalf("Incorrect error code received for Call() after connection aborted")
+ }
+
+}
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
index e522ff9a0..1b86a4b4c 100644
--- a/pkg/sentry/fsimpl/fuse/dev.go
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -32,6 +31,8 @@ import (
const fuseDevMinor = 229
// fuseDevice implements vfs.Device for /dev/fuse.
+//
+// +stateify savable
type fuseDevice struct{}
// Open implements vfs.Device.Open.
@@ -50,15 +51,14 @@ func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op
}
// DeviceFD implements vfs.FileDescriptionImpl for /dev/fuse.
+//
+// +stateify savable
type DeviceFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
vfs.NoLockFD
- // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD.
- mounted bool
-
// nextOpID is used to create new requests.
nextOpID linux.FUSEOpID
@@ -83,7 +83,7 @@ type DeviceFD struct {
writeCursorFR *futureResponse
// mu protects all the queues, maps, buffers and cursors and nextOpID.
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
// waitQueue is used to notify interested parties when the device becomes
// readable or writable.
@@ -92,21 +92,36 @@ type DeviceFD struct {
// fullQueueCh is a channel used to synchronize the readers with the writers.
// Writers (inbound requests to the filesystem) block if there are too many
// unprocessed in-flight requests.
- fullQueueCh chan struct{}
+ fullQueueCh chan struct{} `state:".(int)"`
// fs is the FUSE filesystem that this FD is being used for.
fs *filesystem
}
+func (fd *DeviceFD) saveFullQueueCh() int {
+ return cap(fd.fullQueueCh)
+}
+
+func (fd *DeviceFD) loadFullQueueCh(capacity int) {
+ fd.fullQueueCh = make(chan struct{}, capacity)
+}
+
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *DeviceFD) Release(context.Context) {
- fd.fs.conn.connected = false
+func (fd *DeviceFD) Release(ctx context.Context) {
+ if fd.fs != nil {
+ fd.fs.conn.mu.Lock()
+ fd.fs.conn.connected = false
+ fd.fs.conn.mu.Unlock()
+
+ fd.fs.VFSFilesystem().DecRef(ctx)
+ fd.fs = nil
+ }
}
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
// Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
- if !fd.mounted {
+ if fd.fs == nil {
return 0, syserror.EPERM
}
@@ -116,10 +131,16 @@ func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset in
// Read implements vfs.FileDescriptionImpl.Read.
func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
// Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
- if !fd.mounted {
+ if fd.fs == nil {
return 0, syserror.EPERM
}
+ // Return ENODEV if the filesystem is umounted.
+ if fd.fs.umounted {
+ // TODO(gvisor.dev/issue/3525): return ECONNABORTED if aborted via fuse control fs.
+ return 0, syserror.ENODEV
+ }
+
// We require that any Read done on this filesystem have a sane minimum
// read buffer. It must have the capacity for the fixed parts of any request
// header (Linux uses the request header and the FUSEWriteIn header for this
@@ -143,58 +164,82 @@ func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.R
}
// readLocked implements the reading of the fuse device while locked with DeviceFD.mu.
+//
+// Preconditions: dst is large enough for any reasonable request.
func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- if fd.queue.Empty() {
- return 0, syserror.ErrWouldBlock
- }
+ var req *Request
- var readCursor uint32
- var bytesRead int64
- for {
- req := fd.queue.Front()
- if dst.NumBytes() < int64(req.hdr.Len) {
- // The request is too large. Cannot process it. All requests must be smaller than the
- // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT
- // handshake.
- errno := -int32(syscall.EIO)
- if req.hdr.Opcode == linux.FUSE_SETXATTR {
- errno = -int32(syscall.E2BIG)
- }
+ // Find the first valid request.
+ // For the normal case this loop only execute once.
+ for !fd.queue.Empty() {
+ req = fd.queue.Front()
- // Return the error to the calling task.
- if err := fd.sendError(ctx, errno, req); err != nil {
- return 0, err
- }
+ if int64(req.hdr.Len)+int64(len(req.payload)) <= dst.NumBytes() {
+ break
+ }
- // We're done with this request.
- fd.queue.Remove(req)
+ // The request is too large. Cannot process it. All requests must be smaller than the
+ // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT
+ // handshake.
+ errno := -int32(syscall.EIO)
+ if req.hdr.Opcode == linux.FUSE_SETXATTR {
+ errno = -int32(syscall.E2BIG)
+ }
- // Restart the read as this request was invalid.
- log.Warningf("fuse.DeviceFD.Read: request found was too large. Restarting read.")
- return fd.readLocked(ctx, dst, opts)
+ // Return the error to the calling task.
+ if err := fd.sendError(ctx, errno, req.hdr.Unique); err != nil {
+ return 0, err
}
- n, err := dst.CopyOut(ctx, req.data[readCursor:])
+ // We're done with this request.
+ fd.queue.Remove(req)
+ req = nil
+ }
+
+ if req == nil {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // We already checked the size: dst must be able to fit the whole request.
+ // Now we write the marshalled header, the payload,
+ // and the potential additional payload
+ // to the user memory IOSequence.
+
+ n, err := dst.CopyOut(ctx, req.data)
+ if err != nil {
+ return 0, err
+ }
+ if n != len(req.data) {
+ return 0, syserror.EIO
+ }
+
+ if req.hdr.Opcode == linux.FUSE_WRITE {
+ written, err := dst.DropFirst(n).CopyOut(ctx, req.payload)
if err != nil {
return 0, err
}
- readCursor += uint32(n)
- bytesRead += int64(n)
-
- if readCursor >= req.hdr.Len {
- // Fully done with this req, remove it from the queue.
- fd.queue.Remove(req)
- break
+ if written != len(req.payload) {
+ return 0, syserror.EIO
}
+ n += int(written)
}
- return bytesRead, nil
+ // Fully done with this req, remove it from the queue.
+ fd.queue.Remove(req)
+
+ // Remove noReply ones from map of requests expecting a reply.
+ if req.noReply {
+ fd.numActiveRequests -= 1
+ delete(fd.completions, req.hdr.Unique)
+ }
+
+ return int64(n), nil
}
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
// Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
- if !fd.mounted {
+ if fd.fs == nil {
return 0, syserror.EPERM
}
@@ -211,10 +256,15 @@ func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.
// writeLocked implements writing to the fuse device while locked with DeviceFD.mu.
func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
// Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
- if !fd.mounted {
+ if fd.fs == nil {
return 0, syserror.EPERM
}
+ // Return ENODEV if the filesystem is umounted.
+ if fd.fs.umounted {
+ return 0, syserror.ENODEV
+ }
+
var cn, n int64
hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
@@ -276,7 +326,8 @@ func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opt
fut, ok := fd.completions[hdr.Unique]
if !ok {
- // Server sent us a response for a request we never sent?
+ // Server sent us a response for a request we never sent,
+ // or for which we already received a reply (e.g. aborted), an unlikely event.
return 0, syserror.EINVAL
}
@@ -307,8 +358,23 @@ func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opt
// Readiness implements vfs.FileDescriptionImpl.Readiness.
func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ return fd.readinessLocked(mask)
+}
+
+// readinessLocked implements checking the readiness of the fuse device while
+// locked with DeviceFD.mu.
+func (fd *DeviceFD) readinessLocked(mask waiter.EventMask) waiter.EventMask {
var ready waiter.EventMask
- ready |= waiter.EventOut // FD is always writable
+
+ if fd.fs.umounted {
+ ready |= waiter.EventErr
+ return ready & mask
+ }
+
+ // FD is always writable.
+ ready |= waiter.EventOut
if !fd.queue.Empty() {
// Have reqs available, FD is readable.
ready |= waiter.EventIn
@@ -330,7 +396,7 @@ func (fd *DeviceFD) EventUnregister(e *waiter.Entry) {
// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
// Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
- if !fd.mounted {
+ if fd.fs == nil {
return 0, syserror.EPERM
}
@@ -338,59 +404,59 @@ func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64
}
// sendResponse sends a response to the waiting task (if any).
+//
+// Preconditions: fd.mu must be held.
func (fd *DeviceFD) sendResponse(ctx context.Context, fut *futureResponse) error {
- // See if the running task need to perform some action before returning.
- // Since we just finished writing the future, we can be sure that
- // getResponse generates a populated response.
- if err := fd.noReceiverAction(ctx, fut.getResponse()); err != nil {
- return err
- }
+ // Signal the task waiting on a response if any.
+ defer close(fut.ch)
// Signal that the queue is no longer full.
select {
case fd.fullQueueCh <- struct{}{}:
default:
}
- fd.numActiveRequests -= 1
+ fd.numActiveRequests--
+
+ if fut.async {
+ return fd.asyncCallBack(ctx, fut.getResponse())
+ }
- // Signal the task waiting on a response.
- close(fut.ch)
return nil
}
-// sendError sends an error response to the waiting task (if any).
-func (fd *DeviceFD) sendError(ctx context.Context, errno int32, req *Request) error {
+// sendError sends an error response to the waiting task (if any) by calling sendResponse().
+//
+// Preconditions: fd.mu must be held.
+func (fd *DeviceFD) sendError(ctx context.Context, errno int32, unique linux.FUSEOpID) error {
// Return the error to the calling task.
outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
respHdr := linux.FUSEHeaderOut{
Len: outHdrLen,
Error: errno,
- Unique: req.hdr.Unique,
+ Unique: unique,
}
fut, ok := fd.completions[respHdr.Unique]
if !ok {
- // Server sent us a response for a request we never sent?
+ // A response for a request we never sent,
+ // or for which we already received a reply (e.g. aborted).
return syserror.EINVAL
}
delete(fd.completions, respHdr.Unique)
fut.hdr = &respHdr
- if err := fd.sendResponse(ctx, fut); err != nil {
- return err
- }
-
- return nil
+ return fd.sendResponse(ctx, fut)
}
-// noReceiverAction has the calling kernel.Task do some action if its known that no
-// receiver is going to be waiting on the future channel. This is to be used by:
-// FUSE_INIT.
-func (fd *DeviceFD) noReceiverAction(ctx context.Context, r *Response) error {
- if r.opcode == linux.FUSE_INIT {
+// asyncCallBack executes pre-defined callback function for async requests.
+// Currently used by: FUSE_INIT.
+func (fd *DeviceFD) asyncCallBack(ctx context.Context, r *Response) error {
+ switch r.opcode {
+ case linux.FUSE_INIT:
creds := auth.CredentialsFromContext(ctx)
rootUserNs := kernel.KernelFromContext(ctx).RootUserNamespace()
return fd.fs.conn.InitRecv(r, creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, rootUserNs))
+ // TODO(gvisor.dev/issue/3247): support async read: correctly process the response.
}
return nil
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index 1ffe7ccd2..5986133e9 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -16,7 +16,6 @@ package fuse
import (
"fmt"
- "io"
"math/rand"
"testing"
@@ -28,17 +27,12 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// echoTestOpcode is the Opcode used during testing. The server used in tests
// will simply echo the payload back with the appropriate headers.
const echoTestOpcode linux.FUSEOpcode = 1000
-type testPayload struct {
- data uint32
-}
-
// TestFUSECommunication tests that the communication layer between the Sentry and the
// FUSE server daemon works as expected.
func TestFUSECommunication(t *testing.T) {
@@ -327,102 +321,3 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F
}
}
}
-
-func setup(t *testing.T) *testutil.System {
- k, err := testutil.Boot()
- if err != nil {
- t.Fatalf("Error creating kernel: %v", err)
- }
-
- ctx := k.SupervisorContext()
- creds := auth.CredentialsFromContext(ctx)
-
- k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
- AllowUserList: true,
- AllowUserMount: true,
- })
-
- mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
- if err != nil {
- t.Fatalf("NewMountNamespace(): %v", err)
- }
-
- return testutil.NewSystem(ctx, t, k.VFS(), mntns)
-}
-
-// newTestConnection creates a fuse connection that the sentry can communicate with
-// and the FD for the server to communicate with.
-func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) {
- vfsObj := &vfs.VirtualFilesystem{}
- fuseDev := &DeviceFD{}
-
- if err := vfsObj.Init(system.Ctx); err != nil {
- return nil, nil, err
- }
-
- vd := vfsObj.NewAnonVirtualDentry("genCountFD")
- defer vd.DecRef(system.Ctx)
- if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil {
- return nil, nil, err
- }
-
- fsopts := filesystemOptions{
- maxActiveRequests: maxActiveRequests,
- }
- fs, err := NewFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd)
- if err != nil {
- return nil, nil, err
- }
-
- return fs.conn, &fuseDev.vfsfd, nil
-}
-
-// SizeBytes implements marshal.Marshallable.SizeBytes.
-func (t *testPayload) SizeBytes() int {
- return 4
-}
-
-// MarshalBytes implements marshal.Marshallable.MarshalBytes.
-func (t *testPayload) MarshalBytes(dst []byte) {
- usermem.ByteOrder.PutUint32(dst[:4], t.data)
-}
-
-// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
-func (t *testPayload) UnmarshalBytes(src []byte) {
- *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])}
-}
-
-// Packed implements marshal.Marshallable.Packed.
-func (t *testPayload) Packed() bool {
- return true
-}
-
-// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
-func (t *testPayload) MarshalUnsafe(dst []byte) {
- t.MarshalBytes(dst)
-}
-
-// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
-func (t *testPayload) UnmarshalUnsafe(src []byte) {
- t.UnmarshalBytes(src)
-}
-
-// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
- panic("not implemented")
-}
-
-// CopyOut implements marshal.Marshallable.CopyOut.
-func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
- panic("not implemented")
-}
-
-// CopyIn implements marshal.Marshallable.CopyIn.
-func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- panic("not implemented")
-}
-
-// WriteTo implements io.WriterTo.WriteTo.
-func (t *testPayload) WriteTo(w io.Writer) (int64, error) {
- panic("not implemented")
-}
diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go
new file mode 100644
index 000000000..8f220a04b
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/directory.go
@@ -0,0 +1,105 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type directoryFD struct {
+ fileDescription
+}
+
+// Allocate implements directoryFD.Allocate.
+func (*directoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.EISDIR
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (*directoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (*directoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (*directoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (*directoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EISDIR
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (dir *directoryFD) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback) error {
+ fusefs := dir.inode().fs
+ task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx)
+
+ in := linux.FUSEReadIn{
+ Fh: dir.Fh,
+ Offset: uint64(atomic.LoadInt64(&dir.off)),
+ Size: linux.FUSE_PAGE_SIZE,
+ Flags: dir.statusFlags(),
+ }
+
+ // TODO(gVisor.dev/issue/3404): Support FUSE_READDIRPLUS.
+ req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in)
+ if err != nil {
+ return err
+ }
+
+ res, err := fusefs.conn.Call(task, req)
+ if err != nil {
+ return err
+ }
+ if err := res.Error(); err != nil {
+ return err
+ }
+
+ var out linux.FUSEDirents
+ if err := res.UnmarshalPayload(&out); err != nil {
+ return err
+ }
+
+ for _, fuseDirent := range out.Dirents {
+ nextOff := int64(fuseDirent.Meta.Off)
+ dirent := vfs.Dirent{
+ Name: fuseDirent.Name,
+ Type: uint8(fuseDirent.Meta.Type),
+ Ino: fuseDirent.Meta.Ino,
+ NextOff: nextOff,
+ }
+
+ if err := callback.Handle(dirent); err != nil {
+ return err
+ }
+ atomic.StoreInt64(&dir.off, nextOff)
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/fuse/file.go b/pkg/sentry/fsimpl/fuse/file.go
new file mode 100644
index 000000000..83f2816b7
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/file.go
@@ -0,0 +1,133 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// fileDescription implements vfs.FileDescriptionImpl for fuse.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ // the file handle used in userspace.
+ Fh uint64
+
+ // Nonseekable is indicate cannot perform seek on a file.
+ Nonseekable bool
+
+ // DirectIO suggest fuse to use direct io operation.
+ DirectIO bool
+
+ // OpenFlag is the flag returned by open.
+ OpenFlag uint32
+
+ // off is the file offset.
+ off int64
+}
+
+func (fd *fileDescription) dentry() *kernfs.Dentry {
+ return fd.vfsfd.Dentry().Impl().(*kernfs.Dentry)
+}
+
+func (fd *fileDescription) inode() *inode {
+ return fd.dentry().Inode().(*inode)
+}
+
+func (fd *fileDescription) filesystem() *vfs.Filesystem {
+ return fd.vfsfd.VirtualDentry().Mount().Filesystem()
+}
+
+func (fd *fileDescription) statusFlags() uint32 {
+ return fd.vfsfd.StatusFlags()
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *fileDescription) Release(ctx context.Context) {
+ // no need to release if FUSE server doesn't implement Open.
+ conn := fd.inode().fs.conn
+ if conn.noOpen {
+ return
+ }
+
+ in := linux.FUSEReleaseIn{
+ Fh: fd.Fh,
+ Flags: fd.statusFlags(),
+ }
+ // TODO(gvisor.dev/issue/3245): add logic when we support file lock owner.
+ var opcode linux.FUSEOpcode
+ if fd.inode().Mode().IsDir() {
+ opcode = linux.FUSE_RELEASEDIR
+ } else {
+ opcode = linux.FUSE_RELEASE
+ }
+ kernelTask := kernel.TaskFromContext(ctx)
+ // ignoring errors and FUSE server reply is analogous to Linux's behavior.
+ req, err := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in)
+ if err != nil {
+ // No way to invoke Call() with an errored request.
+ return
+ }
+ // The reply will be ignored since no callback is defined in asyncCallBack().
+ conn.CallAsync(kernelTask, req)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, nil
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return 0, nil
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return 0, nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, nil
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := fd.filesystem()
+ inode := fd.inode()
+ return inode.Stat(ctx, fs, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ fs := fd.filesystem()
+ creds := auth.CredentialsFromContext(ctx)
+ return fd.inode().setAttr(ctx, fs, creds, opts, true, fd.Fh)
+}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 810819ae4..65786e42a 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -16,24 +16,36 @@
package fuse
import (
+ "math"
"strconv"
+ "sync"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// Name is the default filesystem name.
const Name = "fuse"
+// maxActiveRequestsDefault is the default setting controlling the upper bound
+// on the number of active requests at any given time.
+const maxActiveRequestsDefault = 10000
+
// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
type FilesystemType struct{}
+// +stateify savable
type filesystemOptions struct {
// userID specifies the numeric uid of the mount owner.
// This option should not be specified by the filesystem owner.
@@ -56,9 +68,16 @@ type filesystemOptions struct {
// exist at any time. Any further requests will block when trying to
// Call the server.
maxActiveRequests uint64
+
+ // maxRead is the max number of bytes to read,
+ // specified as "max_read" in fs parameters.
+ // If not specified by user, use math.MaxUint32 as default value.
+ maxRead uint32
}
// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
type filesystem struct {
kernfs.Filesystem
devMinor uint32
@@ -69,6 +88,9 @@ type filesystem struct {
// opts is the options the fusefs is initialized with.
opts *filesystemOptions
+
+ // umounted is true if filesystem.Release() has been called.
+ umounted bool
}
// Name implements vfs.FilesystemType.Name.
@@ -142,14 +164,29 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Set the maxInFlightRequests option.
fsopts.maxActiveRequests = maxActiveRequestsDefault
+ if maxReadStr, ok := mopts["max_read"]; ok {
+ delete(mopts, "max_read")
+ maxRead, err := strconv.ParseUint(maxReadStr, 10, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid max_read: max_read=%s", fsType.Name(), maxReadStr)
+ return nil, nil, syserror.EINVAL
+ }
+ if maxRead < fuseMinMaxRead {
+ maxRead = fuseMinMaxRead
+ }
+ fsopts.maxRead = uint32(maxRead)
+ } else {
+ fsopts.maxRead = math.MaxUint32
+ }
+
// Check for unparsed options.
if len(mopts) != 0 {
- log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts)
+ log.Warningf("%s.GetFilesystem: unsupported or unknown options: %v", fsType.Name(), mopts)
return nil, nil, syserror.EINVAL
}
// Create a new FUSE filesystem.
- fs, err := NewFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd)
+ fs, err := newFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd)
if err != nil {
log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err)
return nil, nil, err
@@ -165,26 +202,28 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
// root is the fusefs root directory.
- root := fs.newInode(creds, fsopts.rootMode)
+ root := fs.newRootInode(creds, fsopts.rootMode)
return fs.VFSFilesystem(), root.VFSDentry(), nil
}
-// NewFUSEFilesystem creates a new FUSE filesystem.
-func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) {
- fs := &filesystem{
- devMinor: devMinor,
- opts: opts,
- }
-
- conn, err := newFUSEConnection(ctx, device, opts.maxActiveRequests)
+// newFUSEFilesystem creates a new FUSE filesystem.
+func newFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) {
+ conn, err := newFUSEConnection(ctx, device, opts)
if err != nil {
log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err)
return nil, syserror.EINVAL
}
- fs.conn = conn
fuseFD := device.Impl().(*DeviceFD)
+
+ fs := &filesystem{
+ devMinor: devMinor,
+ opts: opts,
+ conn: conn,
+ }
+
+ fs.VFSFilesystem().IncRef()
fuseFD.fs = fs
return fs, nil
@@ -192,11 +231,22 @@ func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOpt
// Release implements vfs.FilesystemImpl.Release.
func (fs *filesystem) Release(ctx context.Context) {
+ fs.conn.fd.mu.Lock()
+
+ fs.umounted = true
+ fs.conn.Abort(ctx)
+ // Notify all the waiters on this fd.
+ fs.conn.fd.waitQueue.Notify(waiter.EventIn)
+
+ fs.conn.fd.mu.Unlock()
+
fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
fs.Filesystem.Release(ctx)
}
// inode implements kernfs.Inode.
+//
+// +stateify savable
type inode struct {
inodeRefs
kernfs.InodeAttrs
@@ -205,14 +255,50 @@ type inode struct {
kernfs.InodeNotSymlink
kernfs.OrderedChildren
+ dentry kernfs.Dentry
+
+ // the owning filesystem. fs is immutable.
+ fs *filesystem
+
+ // metaDataMu protects the metadata of this inode.
+ metadataMu sync.Mutex
+
+ nodeID uint64
+
locks vfs.FileLocks
- dentry kernfs.Dentry
+ // size of the file.
+ size uint64
+
+ // attributeVersion is the version of inode's attributes.
+ attributeVersion uint64
+
+ // attributeTime is the remaining vaild time of attributes.
+ attributeTime uint64
+
+ // version of the inode.
+ version uint64
+
+ // link is result of following a symbolic link.
+ link string
+}
+
+func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
+ i := &inode{fs: fs}
+ i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755)
+ i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ i.EnableLeakCheck()
+ i.dentry.Init(i)
+ i.nodeID = 1
+
+ return &i.dentry
}
-func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
- i := &inode{}
- i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
+func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) *kernfs.Dentry {
+ i := &inode{fs: fs, nodeID: nodeID}
+ creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)}
+ i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode))
+ atomic.StoreUint64(&i.size, attr.Size)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
i.EnableLeakCheck()
i.dentry.Init(i)
@@ -221,14 +307,292 @@ func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *ke
}
// Open implements kernfs.Inode.Open.
-func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
- SeekEnd: kernfs.SeekEndStaticEntries,
- })
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ isDir := i.InodeAttrs.Mode().IsDir()
+ // return error if specified to open directory but inode is not a directory.
+ if !isDir && opts.Mode.IsDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if opts.Flags&linux.O_LARGEFILE == 0 && atomic.LoadUint64(&i.size) > linux.MAX_NON_LFS {
+ return nil, syserror.EOVERFLOW
+ }
+
+ var fd *fileDescription
+ var fdImpl vfs.FileDescriptionImpl
+ if isDir {
+ directoryFD := &directoryFD{}
+ fd = &(directoryFD.fileDescription)
+ fdImpl = directoryFD
+ } else {
+ regularFD := &regularFileFD{}
+ fd = &(regularFD.fileDescription)
+ fdImpl = regularFD
+ }
+ // FOPEN_KEEP_CACHE is the defualt flag for noOpen.
+ fd.OpenFlag = linux.FOPEN_KEEP_CACHE
+
+ // Only send open request when FUSE server support open or is opening a directory.
+ if !i.fs.conn.noOpen || isDir {
+ kernelTask := kernel.TaskFromContext(ctx)
+ if kernelTask == nil {
+ log.Warningf("fusefs.Inode.Open: couldn't get kernel task from context")
+ return nil, syserror.EINVAL
+ }
+
+ // Build the request.
+ var opcode linux.FUSEOpcode
+ if isDir {
+ opcode = linux.FUSE_OPENDIR
+ } else {
+ opcode = linux.FUSE_OPEN
+ }
+
+ in := linux.FUSEOpenIn{Flags: opts.Flags & ^uint32(linux.O_CREAT|linux.O_EXCL|linux.O_NOCTTY)}
+ if !i.fs.conn.atomicOTrunc {
+ in.Flags &= ^uint32(linux.O_TRUNC)
+ }
+
+ req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in)
+ if err != nil {
+ return nil, err
+ }
+
+ // Send the request and receive the reply.
+ res, err := i.fs.conn.Call(kernelTask, req)
+ if err != nil {
+ return nil, err
+ }
+ if err := res.Error(); err == syserror.ENOSYS && !isDir {
+ i.fs.conn.noOpen = true
+ } else if err != nil {
+ return nil, err
+ } else {
+ out := linux.FUSEOpenOut{}
+ if err := res.UnmarshalPayload(&out); err != nil {
+ return nil, err
+ }
+
+ // Process the reply.
+ fd.OpenFlag = out.OpenFlag
+ if isDir {
+ fd.OpenFlag &= ^uint32(linux.FOPEN_DIRECT_IO)
+ }
+
+ fd.Fh = out.Fh
+ }
+ }
+
+ // TODO(gvisor.dev/issue/3234): invalidate mmap after implemented it for FUSE Inode
+ fd.DirectIO = fd.OpenFlag&linux.FOPEN_DIRECT_IO != 0
+ fdOptions := &vfs.FileDescriptionOptions{}
+ if fd.OpenFlag&linux.FOPEN_NONSEEKABLE != 0 {
+ fdOptions.DenyPRead = true
+ fdOptions.DenyPWrite = true
+ fd.Nonseekable = true
+ }
+
+ // If we don't send SETATTR before open (which is indicated by atomicOTrunc)
+ // and O_TRUNC is set, update the inode's version number and clean existing data
+ // by setting the file size to 0.
+ if i.fs.conn.atomicOTrunc && opts.Flags&linux.O_TRUNC != 0 {
+ i.fs.conn.mu.Lock()
+ i.fs.conn.attributeVersion++
+ i.attributeVersion = i.fs.conn.attributeVersion
+ atomic.StoreUint64(&i.size, 0)
+ i.fs.conn.mu.Unlock()
+ i.attributeTime = 0
+ }
+
+ if err := fd.vfsfd.Init(fdImpl, opts.Flags, rp.Mount(), d.VFSDentry(), fdOptions); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Lookup implements kernfs.Inode.Lookup.
+func (i *inode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
+ in := linux.FUSELookupIn{Name: name}
+ return i.newEntry(ctx, name, 0, linux.FUSE_LOOKUP, &in)
+}
+
+// IterDirents implements kernfs.Inode.IterDirents.
+func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ return offset, nil
+}
+
+// Valid implements kernfs.Inode.Valid.
+func (*inode) Valid(ctx context.Context) bool {
+ return true
+}
+
+// NewFile implements kernfs.Inode.NewFile.
+func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*kernfs.Dentry, error) {
+ kernelTask := kernel.TaskFromContext(ctx)
+ if kernelTask == nil {
+ log.Warningf("fusefs.Inode.NewFile: couldn't get kernel task from context", i.nodeID)
+ return nil, syserror.EINVAL
+ }
+ in := linux.FUSECreateIn{
+ CreateMeta: linux.FUSECreateMeta{
+ Flags: opts.Flags,
+ Mode: uint32(opts.Mode) | linux.S_IFREG,
+ Umask: uint32(kernelTask.FSContext().Umask()),
+ },
+ Name: name,
+ }
+ return i.newEntry(ctx, name, linux.S_IFREG, linux.FUSE_CREATE, &in)
+}
+
+// NewNode implements kernfs.Inode.NewNode.
+func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*kernfs.Dentry, error) {
+ in := linux.FUSEMknodIn{
+ MknodMeta: linux.FUSEMknodMeta{
+ Mode: uint32(opts.Mode),
+ Rdev: linux.MakeDeviceID(uint16(opts.DevMajor), opts.DevMinor),
+ Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()),
+ },
+ Name: name,
+ }
+ return i.newEntry(ctx, name, opts.Mode.FileType(), linux.FUSE_MKNOD, &in)
+}
+
+// NewSymlink implements kernfs.Inode.NewSymlink.
+func (i *inode) NewSymlink(ctx context.Context, name, target string) (*kernfs.Dentry, error) {
+ in := linux.FUSESymLinkIn{
+ Name: name,
+ Target: target,
+ }
+ return i.newEntry(ctx, name, linux.S_IFLNK, linux.FUSE_SYMLINK, &in)
+}
+
+// Unlink implements kernfs.Inode.Unlink.
+func (i *inode) Unlink(ctx context.Context, name string, child *kernfs.Dentry) error {
+ kernelTask := kernel.TaskFromContext(ctx)
+ if kernelTask == nil {
+ log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
+ return syserror.EINVAL
+ }
+ in := linux.FUSEUnlinkIn{Name: name}
+ req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in)
+ if err != nil {
+ return err
+ }
+ res, err := i.fs.conn.Call(kernelTask, req)
+ if err != nil {
+ return err
+ }
+ // only return error, discard res.
+ if err := res.Error(); err != nil {
+ return err
+ }
+ return i.dentry.RemoveChildLocked(name, child)
+}
+
+// NewDir implements kernfs.Inode.NewDir.
+func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*kernfs.Dentry, error) {
+ in := linux.FUSEMkdirIn{
+ MkdirMeta: linux.FUSEMkdirMeta{
+ Mode: uint32(opts.Mode),
+ Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()),
+ },
+ Name: name,
+ }
+ return i.newEntry(ctx, name, linux.S_IFDIR, linux.FUSE_MKDIR, &in)
+}
+
+// RmDir implements kernfs.Inode.RmDir.
+func (i *inode) RmDir(ctx context.Context, name string, child *kernfs.Dentry) error {
+ fusefs := i.fs
+ task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx)
+
+ in := linux.FUSERmDirIn{Name: name}
+ req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in)
+ if err != nil {
+ return err
+ }
+
+ res, err := i.fs.conn.Call(task, req)
+ if err != nil {
+ return err
+ }
+ if err := res.Error(); err != nil {
+ return err
+ }
+
+ return i.dentry.RemoveChildLocked(name, child)
+}
+
+// newEntry calls FUSE server for entry creation and allocates corresponding entry according to response.
+// Shared by FUSE_MKNOD, FUSE_MKDIR, FUSE_SYMLINK, FUSE_LINK and FUSE_LOOKUP.
+func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMode, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*kernfs.Dentry, error) {
+ kernelTask := kernel.TaskFromContext(ctx)
+ if kernelTask == nil {
+ log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
+ return nil, syserror.EINVAL
+ }
+ req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload)
+ if err != nil {
+ return nil, err
+ }
+ res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return nil, err
}
- return fd.VFSFileDescription(), nil
+ if err := res.Error(); err != nil {
+ return nil, err
+ }
+ out := linux.FUSEEntryOut{}
+ if err := res.UnmarshalPayload(&out); err != nil {
+ return nil, err
+ }
+ if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) {
+ return nil, syserror.EIO
+ }
+ child := i.fs.newInode(out.NodeID, out.Attr)
+ return child, nil
+}
+
+// Getlink implements kernfs.Inode.Getlink.
+func (i *inode) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ path, err := i.Readlink(ctx, mnt)
+ return vfs.VirtualDentry{}, path, err
+}
+
+// Readlink implements kernfs.Inode.Readlink.
+func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
+ if i.Mode().FileType()&linux.S_IFLNK == 0 {
+ return "", syserror.EINVAL
+ }
+ if len(i.link) == 0 {
+ kernelTask := kernel.TaskFromContext(ctx)
+ if kernelTask == nil {
+ log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context")
+ return "", syserror.EINVAL
+ }
+ req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{})
+ if err != nil {
+ return "", err
+ }
+ res, err := i.fs.conn.Call(kernelTask, req)
+ if err != nil {
+ return "", err
+ }
+ i.link = string(res.data[res.hdr.SizeBytes():])
+ if !mnt.Options().ReadOnly {
+ i.attributeTime = 0
+ }
+ }
+ return i.link, nil
+}
+
+// getFUSEAttr returns a linux.FUSEAttr of this inode stored in local cache.
+// TODO(gvisor.dev/issue/3679): Add support for other fields.
+func (i *inode) getFUSEAttr() linux.FUSEAttr {
+ return linux.FUSEAttr{
+ Ino: i.Ino(),
+ Size: atomic.LoadUint64(&i.size),
+ Mode: uint32(i.Mode()),
+ }
}
// statFromFUSEAttr makes attributes from linux.FUSEAttr to linux.Statx. The
@@ -284,50 +648,92 @@ func statFromFUSEAttr(attr linux.FUSEAttr, mask, devMinor uint32) linux.Statx {
return stat
}
-// Stat implements kernfs.Inode.Stat.
-func (i *inode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- fusefs := fs.Impl().(*filesystem)
- conn := fusefs.conn
- task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx)
+// getAttr gets the attribute of this inode by issuing a FUSE_GETATTR request
+// or read from local cache. It updates the corresponding attributes if
+// necessary.
+func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions, flags uint32, fh uint64) (linux.FUSEAttr, error) {
+ attributeVersion := atomic.LoadUint64(&i.fs.conn.attributeVersion)
+
+ // TODO(gvisor.dev/issue/3679): send the request only if
+ // - invalid local cache for fields specified in the opts.Mask
+ // - forced update
+ // - i.attributeTime expired
+ // If local cache is still valid, return local cache.
+ // Currently we always send a request,
+ // and we always set the metadata with the new result,
+ // unless attributeVersion has changed.
+
+ task := kernel.TaskFromContext(ctx)
if task == nil {
log.Warningf("couldn't get kernel task from context")
- return linux.Statx{}, syserror.EINVAL
+ return linux.FUSEAttr{}, syserror.EINVAL
}
- var in linux.FUSEGetAttrIn
- // We don't set any attribute in the request, because in VFS2 fstat(2) will
- // finally be translated into vfs.FilesystemImpl.StatAt() (see
- // pkg/sentry/syscalls/linux/vfs2/stat.go), resulting in the same flow
- // as stat(2). Thus GetAttrFlags and Fh variable will never be used in VFS2.
- req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.Ino(), linux.FUSE_GETATTR, &in)
+ creds := auth.CredentialsFromContext(ctx)
+
+ in := linux.FUSEGetAttrIn{
+ GetAttrFlags: flags,
+ Fh: fh,
+ }
+ req, err := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in)
if err != nil {
- return linux.Statx{}, err
+ return linux.FUSEAttr{}, err
}
- res, err := conn.Call(task, req)
+ res, err := i.fs.conn.Call(task, req)
if err != nil {
- return linux.Statx{}, err
+ return linux.FUSEAttr{}, err
}
if err := res.Error(); err != nil {
- return linux.Statx{}, err
+ return linux.FUSEAttr{}, err
}
var out linux.FUSEGetAttrOut
if err := res.UnmarshalPayload(&out); err != nil {
- return linux.Statx{}, err
+ return linux.FUSEAttr{}, err
+ }
+
+ // Local version is newer, return the local one.
+ // Skip the update.
+ if attributeVersion != 0 && atomic.LoadUint64(&i.attributeVersion) > attributeVersion {
+ return i.getFUSEAttr(), nil
}
- // Set all metadata into kernfs.InodeAttrs.
- if err := i.SetStat(ctx, fs, creds, vfs.SetStatOptions{
- Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, fusefs.devMinor),
+ // Set the metadata of kernfs.InodeAttrs.
+ if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{
+ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor),
}); err != nil {
+ return linux.FUSEAttr{}, err
+ }
+
+ // Set the size if no error (after SetStat() check).
+ atomic.StoreUint64(&i.size, out.Attr.Size)
+
+ return out.Attr, nil
+}
+
+// reviseAttr attempts to update the attributes for internal purposes
+// by calling getAttr with a pre-specified mask.
+// Used by read, write, lseek.
+func (i *inode) reviseAttr(ctx context.Context, flags uint32, fh uint64) error {
+ // Never need atime for internal purposes.
+ _, err := i.getAttr(ctx, i.fs.VFSFilesystem(), vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS &^ linux.STATX_ATIME,
+ }, flags, fh)
+ return err
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (i *inode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ attr, err := i.getAttr(ctx, fs, opts, 0, 0)
+ if err != nil {
return linux.Statx{}, err
}
- return statFromFUSEAttr(out.Attr, opts.Mask, fusefs.devMinor), nil
+ return statFromFUSEAttr(attr, opts.Mask, i.fs.devMinor), nil
}
-// DecRef implements kernfs.Inode.
+// DecRef implements kernfs.Inode.DecRef.
func (i *inode) DecRef(context.Context) {
i.inodeRefs.DecRef(i.Destroy)
}
@@ -337,3 +743,84 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e
// TODO(gvisor.dev/issues/3413): Complete the implementation of statfs.
return vfs.GenericStatFS(linux.FUSE_SUPER_MAGIC), nil
}
+
+// fattrMaskFromStats converts vfs.SetStatOptions.Stat.Mask to linux stats mask
+// aligned with the attribute mask defined in include/linux/fs.h.
+func fattrMaskFromStats(mask uint32) uint32 {
+ var fuseAttrMask uint32
+ maskMap := map[uint32]uint32{
+ linux.STATX_MODE: linux.FATTR_MODE,
+ linux.STATX_UID: linux.FATTR_UID,
+ linux.STATX_GID: linux.FATTR_GID,
+ linux.STATX_SIZE: linux.FATTR_SIZE,
+ linux.STATX_ATIME: linux.FATTR_ATIME,
+ linux.STATX_MTIME: linux.FATTR_MTIME,
+ linux.STATX_CTIME: linux.FATTR_CTIME,
+ }
+ for statxMask, fattrMask := range maskMap {
+ if mask&statxMask != 0 {
+ fuseAttrMask |= fattrMask
+ }
+ }
+ return fuseAttrMask
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ return i.setAttr(ctx, fs, creds, opts, false, 0)
+}
+
+func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions, useFh bool, fh uint64) error {
+ conn := i.fs.conn
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ log.Warningf("couldn't get kernel task from context")
+ return syserror.EINVAL
+ }
+
+ // We should retain the original file type when assigning new mode.
+ fileType := uint16(i.Mode()) & linux.S_IFMT
+ fattrMask := fattrMaskFromStats(opts.Stat.Mask)
+ if useFh {
+ fattrMask |= linux.FATTR_FH
+ }
+ in := linux.FUSESetAttrIn{
+ Valid: fattrMask,
+ Fh: fh,
+ Size: opts.Stat.Size,
+ Atime: uint64(opts.Stat.Atime.Sec),
+ Mtime: uint64(opts.Stat.Mtime.Sec),
+ Ctime: uint64(opts.Stat.Ctime.Sec),
+ AtimeNsec: opts.Stat.Atime.Nsec,
+ MtimeNsec: opts.Stat.Mtime.Nsec,
+ CtimeNsec: opts.Stat.Ctime.Nsec,
+ Mode: uint32(fileType | opts.Stat.Mode),
+ UID: opts.Stat.UID,
+ GID: opts.Stat.GID,
+ }
+ req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in)
+ if err != nil {
+ return err
+ }
+
+ res, err := conn.Call(task, req)
+ if err != nil {
+ return err
+ }
+ if err := res.Error(); err != nil {
+ return err
+ }
+ out := linux.FUSEGetAttrOut{}
+ if err := res.UnmarshalPayload(&out); err != nil {
+ return err
+ }
+
+ // Set the metadata of kernfs.InodeAttrs.
+ if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{
+ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor),
+ }); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go
new file mode 100644
index 000000000..625d1547f
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/read_write.go
@@ -0,0 +1,242 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "io"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// ReadInPages sends FUSE_READ requests for the size after round it up to
+// a multiple of page size, blocks on it for reply, processes the reply
+// and returns the payload (or joined payloads) as a byte slice.
+// This is used for the general purpose reading.
+// We do not support direct IO (which read the exact number of bytes)
+// at this moment.
+func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off uint64, size uint32) ([][]byte, uint32, error) {
+ attributeVersion := atomic.LoadUint64(&fs.conn.attributeVersion)
+
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ log.Warningf("fusefs.Read: couldn't get kernel task from context")
+ return nil, 0, syserror.EINVAL
+ }
+
+ // Round up to a multiple of page size.
+ readSize, _ := usermem.PageRoundUp(uint64(size))
+
+ // One request cannnot exceed either maxRead or maxPages.
+ maxPages := fs.conn.maxRead >> usermem.PageShift
+ if maxPages > uint32(fs.conn.maxPages) {
+ maxPages = uint32(fs.conn.maxPages)
+ }
+
+ var outs [][]byte
+ var sizeRead uint32
+
+ // readSize is a multiple of usermem.PageSize.
+ // Always request bytes as a multiple of pages.
+ pagesRead, pagesToRead := uint32(0), uint32(readSize>>usermem.PageShift)
+
+ // Reuse the same struct for unmarshalling to avoid unnecessary memory allocation.
+ in := linux.FUSEReadIn{
+ Fh: fd.Fh,
+ LockOwner: 0, // TODO(gvisor.dev/issue/3245): file lock
+ ReadFlags: 0, // TODO(gvisor.dev/issue/3245): |= linux.FUSE_READ_LOCKOWNER
+ Flags: fd.statusFlags(),
+ }
+
+ // This loop is intended for fragmented read where the bytes to read is
+ // larger than either the maxPages or maxRead.
+ // For the majority of reads with normal size, this loop should only
+ // execute once.
+ for pagesRead < pagesToRead {
+ pagesCanRead := pagesToRead - pagesRead
+ if pagesCanRead > maxPages {
+ pagesCanRead = maxPages
+ }
+
+ in.Offset = off + (uint64(pagesRead) << usermem.PageShift)
+ in.Size = pagesCanRead << usermem.PageShift
+
+ req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // TODO(gvisor.dev/issue/3247): support async read.
+
+ res, err := fs.conn.Call(t, req)
+ if err != nil {
+ return nil, 0, err
+ }
+ if err := res.Error(); err != nil {
+ return nil, 0, err
+ }
+
+ // Not enough bytes in response,
+ // either we reached EOF,
+ // or the FUSE server sends back a response
+ // that cannot even fit the hdr.
+ if len(res.data) <= res.hdr.SizeBytes() {
+ // We treat both case as EOF here for now
+ // since there is no reliable way to detect
+ // the over-short hdr case.
+ break
+ }
+
+ // Directly using the slice to avoid extra copy.
+ out := res.data[res.hdr.SizeBytes():]
+
+ outs = append(outs, out)
+ sizeRead += uint32(len(out))
+
+ pagesRead += pagesCanRead
+ }
+
+ defer fs.ReadCallback(ctx, fd, off, size, sizeRead, attributeVersion)
+
+ // No bytes returned: offset >= EOF.
+ if len(outs) == 0 {
+ return nil, 0, io.EOF
+ }
+
+ return outs, sizeRead, nil
+}
+
+// ReadCallback updates several information after receiving a read response.
+// Due to readahead, sizeRead can be larger than size.
+func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off uint64, size uint32, sizeRead uint32, attributeVersion uint64) {
+ // TODO(gvisor.dev/issue/3247): support async read.
+ // If this is called by an async read, correctly process it.
+ // May need to update the signature.
+
+ i := fd.inode()
+ // TODO(gvisor.dev/issue/1193): Invalidate or update atime.
+
+ // Reached EOF.
+ if sizeRead < size {
+ // TODO(gvisor.dev/issue/3630): If we have writeback cache, then we need to fill this hole.
+ // Might need to update the buf to be returned from the Read().
+
+ // Update existing size.
+ newSize := off + uint64(sizeRead)
+ fs.conn.mu.Lock()
+ if attributeVersion == i.attributeVersion && newSize < atomic.LoadUint64(&i.size) {
+ fs.conn.attributeVersion++
+ i.attributeVersion = i.fs.conn.attributeVersion
+ atomic.StoreUint64(&i.size, newSize)
+ }
+ fs.conn.mu.Unlock()
+ }
+}
+
+// Write sends FUSE_WRITE requests and return the bytes
+// written according to the response.
+//
+// Preconditions: len(data) == size.
+func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, size uint32, data []byte) (uint32, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ log.Warningf("fusefs.Read: couldn't get kernel task from context")
+ return 0, syserror.EINVAL
+ }
+
+ // One request cannnot exceed either maxWrite or maxPages.
+ maxWrite := uint32(fs.conn.maxPages) << usermem.PageShift
+ if maxWrite > fs.conn.maxWrite {
+ maxWrite = fs.conn.maxWrite
+ }
+
+ // Reuse the same struct for unmarshalling to avoid unnecessary memory allocation.
+ in := linux.FUSEWriteIn{
+ Fh: fd.Fh,
+ // TODO(gvisor.dev/issue/3245): file lock
+ LockOwner: 0,
+ // TODO(gvisor.dev/issue/3245): |= linux.FUSE_READ_LOCKOWNER
+ // TODO(gvisor.dev/issue/3237): |= linux.FUSE_WRITE_CACHE (not added yet)
+ WriteFlags: 0,
+ Flags: fd.statusFlags(),
+ }
+
+ var written uint32
+
+ // This loop is intended for fragmented write where the bytes to write is
+ // larger than either the maxWrite or maxPages or when bigWrites is false.
+ // Unless a small value for max_write is explicitly used, this loop
+ // is expected to execute only once for the majority of the writes.
+ for written < size {
+ toWrite := size - written
+
+ // Limit the write size to one page.
+ // Note that the bigWrites flag is obsolete,
+ // latest libfuse always sets it on.
+ if !fs.conn.bigWrites && toWrite > usermem.PageSize {
+ toWrite = usermem.PageSize
+ }
+
+ // Limit the write size to maxWrite.
+ if toWrite > maxWrite {
+ toWrite = maxWrite
+ }
+
+ in.Offset = off + uint64(written)
+ in.Size = toWrite
+
+ req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in)
+ if err != nil {
+ return 0, err
+ }
+
+ req.payload = data[written : written+toWrite]
+
+ // TODO(gvisor.dev/issue/3247): support async write.
+
+ res, err := fs.conn.Call(t, req)
+ if err != nil {
+ return 0, err
+ }
+ if err := res.Error(); err != nil {
+ return 0, err
+ }
+
+ out := linux.FUSEWriteOut{}
+ if err := res.UnmarshalPayload(&out); err != nil {
+ return 0, err
+ }
+
+ // Write more than requested? EIO.
+ if out.Size > toWrite {
+ return 0, syserror.EIO
+ }
+
+ written += out.Size
+
+ // Break if short write. Not necessarily an error.
+ if out.Size != toWrite {
+ break
+ }
+ }
+
+ return written, nil
+}
diff --git a/pkg/sentry/fsimpl/fuse/regular_file.go b/pkg/sentry/fsimpl/fuse/regular_file.go
new file mode 100644
index 000000000..5bdd096c3
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/regular_file.go
@@ -0,0 +1,230 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "io"
+ "math"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type regularFileFD struct {
+ fileDescription
+
+ // off is the file offset.
+ off int64
+ // offMu protects off.
+ offMu sync.Mutex
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ size := dst.NumBytes()
+ if size == 0 {
+ // Early return if count is 0.
+ return 0, nil
+ } else if size > math.MaxUint32 {
+ // FUSE only supports uint32 for size.
+ // Overflow.
+ return 0, syserror.EINVAL
+ }
+
+ // TODO(gvisor.dev/issue/3678): Add direct IO support.
+
+ inode := fd.inode()
+
+ // Reading beyond EOF, update file size if outdated.
+ if uint64(offset+size) > atomic.LoadUint64(&inode.size) {
+ if err := inode.reviseAttr(ctx, linux.FUSE_GETATTR_FH, fd.Fh); err != nil {
+ return 0, err
+ }
+ // If the offset after update is still too large, return error.
+ if uint64(offset) >= atomic.LoadUint64(&inode.size) {
+ return 0, io.EOF
+ }
+ }
+
+ // Truncate the read with updated file size.
+ fileSize := atomic.LoadUint64(&inode.size)
+ if uint64(offset+size) > fileSize {
+ size = int64(fileSize) - offset
+ }
+
+ buffers, n, err := inode.fs.ReadInPages(ctx, fd, uint64(offset), uint32(size))
+ if err != nil {
+ return 0, err
+ }
+
+ // TODO(gvisor.dev/issue/3237): support indirect IO (e.g. caching),
+ // store the bytes that were read ahead.
+
+ // Update the number of bytes to copy for short read.
+ if n < uint32(size) {
+ size = int64(n)
+ }
+
+ // Copy the bytes read to the dst.
+ // This loop is intended for fragmented reads.
+ // For the majority of reads, this loop only execute once.
+ var copied int64
+ for _, buffer := range buffers {
+ toCopy := int64(len(buffer))
+ if copied+toCopy > size {
+ toCopy = size - copied
+ }
+ cp, err := dst.DropFirst64(copied).CopyOut(ctx, buffer[:toCopy])
+ if err != nil {
+ return 0, err
+ }
+ if int64(cp) != toCopy {
+ return 0, syserror.EIO
+ }
+ copied += toCopy
+ }
+
+ return copied, nil
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// pwrite returns the number of bytes written, final offset and error. The
+// final offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
+ if offset < 0 {
+ return 0, offset, syserror.EINVAL
+ }
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
+ inode := fd.inode()
+ inode.metadataMu.Lock()
+ defer inode.metadataMu.Unlock()
+
+ // If the file is opened with O_APPEND, update offset to file size.
+ // Note: since our Open() implements the interface of kernfs,
+ // and kernfs currently does not support O_APPEND, this will never
+ // be true before we switch out from kernfs.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Locking inode.metadataMu is sufficient for reading size
+ offset = int64(inode.size)
+ }
+
+ srclen := src.NumBytes()
+
+ if srclen > math.MaxUint32 {
+ // FUSE only supports uint32 for size.
+ // Overflow.
+ return 0, offset, syserror.EINVAL
+ }
+ if end := offset + srclen; end < offset {
+ // Overflow.
+ return 0, offset, syserror.EINVAL
+ }
+
+ srclen, err = vfs.CheckLimit(ctx, offset, srclen)
+ if err != nil {
+ return 0, offset, err
+ }
+
+ if srclen == 0 {
+ // Return before causing any side effects.
+ return 0, offset, nil
+ }
+
+ src = src.TakeFirst64(srclen)
+
+ // TODO(gvisor.dev/issue/3237): Add cache support:
+ // buffer cache. Ideally we write from src to our buffer cache first.
+ // The slice passed to fs.Write() should be a slice from buffer cache.
+ data := make([]byte, srclen)
+ // Reason for making a copy here: connection.Call() blocks on kerneltask,
+ // which in turn acquires mm.activeMu lock. Functions like CopyInTo() will
+ // attemp to acquire the mm.activeMu lock as well -> deadlock.
+ // We must finish reading from the userspace memory before
+ // t.Block() deactivates it.
+ cp, err := src.CopyIn(ctx, data)
+ if err != nil {
+ return 0, offset, err
+ }
+ if int64(cp) != srclen {
+ return 0, offset, syserror.EIO
+ }
+
+ n, err := fd.inode().fs.Write(ctx, fd, uint64(offset), uint32(srclen), data)
+ if err != nil {
+ return 0, offset, err
+ }
+
+ if n == 0 {
+ // We have checked srclen != 0 previously.
+ // If err == nil, then it's a short write and we return EIO.
+ return 0, offset, syserror.EIO
+ }
+
+ written = int64(n)
+ finalOff = offset + written
+
+ if finalOff > int64(inode.size) {
+ atomic.StoreUint64(&inode.size, uint64(finalOff))
+ atomic.AddUint64(&inode.fs.conn.attributeVersion, 1)
+ }
+
+ return
+}
diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go
new file mode 100644
index 000000000..7fa00569b
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/request_response.go
@@ -0,0 +1,229 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// fuseInitRes is a variable-length wrapper of linux.FUSEInitOut. The FUSE
+// server may implement an older version of FUSE protocol, which contains a
+// linux.FUSEInitOut with less attributes.
+//
+// Dynamically-sized objects cannot be marshalled.
+type fuseInitRes struct {
+ marshal.StubMarshallable
+
+ // initOut contains the response from the FUSE server.
+ initOut linux.FUSEInitOut
+
+ // initLen is the total length of bytes of the response.
+ initLen uint32
+}
+
+// UnmarshalBytes deserializes src to the initOut attribute in a fuseInitRes.
+func (r *fuseInitRes) UnmarshalBytes(src []byte) {
+ out := &r.initOut
+
+ // Introduced before FUSE kernel version 7.13.
+ out.Major = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ out.Minor = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ out.MaxReadahead = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ out.Flags = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ out.MaxBackground = uint16(usermem.ByteOrder.Uint16(src[:2]))
+ src = src[2:]
+ out.CongestionThreshold = uint16(usermem.ByteOrder.Uint16(src[:2]))
+ src = src[2:]
+ out.MaxWrite = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+
+ // Introduced in FUSE kernel version 7.23.
+ if len(src) >= 4 {
+ out.TimeGran = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ }
+ // Introduced in FUSE kernel version 7.28.
+ if len(src) >= 2 {
+ out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2]))
+ src = src[2:]
+ }
+}
+
+// SizeBytes is the size of the payload of the FUSE_INIT response.
+func (r *fuseInitRes) SizeBytes() int {
+ return int(r.initLen)
+}
+
+// Ordinary requests have even IDs, while interrupts IDs are odd.
+// Used to increment the unique ID for each FUSE request.
+var reqIDStep uint64 = 2
+
+// Request represents a FUSE operation request that hasn't been sent to the
+// server yet.
+//
+// +stateify savable
+type Request struct {
+ requestEntry
+
+ id linux.FUSEOpID
+ hdr *linux.FUSEHeaderIn
+ data []byte
+
+ // payload for this request: extra bytes to write after
+ // the data slice. Used by FUSE_WRITE.
+ payload []byte
+
+ // If this request is async.
+ async bool
+ // If we don't care its response.
+ // Manually set by the caller.
+ noReply bool
+}
+
+// NewRequest creates a new request that can be sent to the FUSE server.
+func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) {
+ conn.fd.mu.Lock()
+ defer conn.fd.mu.Unlock()
+ conn.fd.nextOpID += linux.FUSEOpID(reqIDStep)
+
+ hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes()
+ hdr := linux.FUSEHeaderIn{
+ Len: uint32(hdrLen + payload.SizeBytes()),
+ Opcode: opcode,
+ Unique: conn.fd.nextOpID,
+ NodeID: ino,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ PID: pid,
+ }
+
+ buf := make([]byte, hdr.Len)
+
+ // TODO(gVisor.dev/issue/3698): Use the unsafe version once go_marshal is safe to use again.
+ hdr.MarshalBytes(buf[:hdrLen])
+ payload.MarshalBytes(buf[hdrLen:])
+
+ return &Request{
+ id: hdr.Unique,
+ hdr: &hdr,
+ data: buf,
+ }, nil
+}
+
+// futureResponse represents an in-flight request, that may or may not have
+// completed yet. Convert it to a resolved Response by calling Resolve, but note
+// that this may block.
+//
+// +stateify savable
+type futureResponse struct {
+ opcode linux.FUSEOpcode
+ ch chan struct{}
+ hdr *linux.FUSEHeaderOut
+ data []byte
+
+ // If this request is async.
+ async bool
+}
+
+// newFutureResponse creates a future response to a FUSE request.
+func newFutureResponse(req *Request) *futureResponse {
+ return &futureResponse{
+ opcode: req.hdr.Opcode,
+ ch: make(chan struct{}),
+ async: req.async,
+ }
+}
+
+// resolve blocks the task until the server responds to its corresponding request,
+// then returns a resolved response.
+func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) {
+ // Return directly for async requests.
+ if f.async {
+ return nil, nil
+ }
+
+ if err := t.Block(f.ch); err != nil {
+ return nil, err
+ }
+
+ return f.getResponse(), nil
+}
+
+// getResponse creates a Response from the data the futureResponse has.
+func (f *futureResponse) getResponse() *Response {
+ return &Response{
+ opcode: f.opcode,
+ hdr: *f.hdr,
+ data: f.data,
+ }
+}
+
+// Response represents an actual response from the server, including the
+// response payload.
+//
+// +stateify savable
+type Response struct {
+ opcode linux.FUSEOpcode
+ hdr linux.FUSEHeaderOut
+ data []byte
+}
+
+// Error returns the error of the FUSE call.
+func (r *Response) Error() error {
+ errno := r.hdr.Error
+ if errno >= 0 {
+ return nil
+ }
+
+ sysErrNo := syscall.Errno(-errno)
+ return error(sysErrNo)
+}
+
+// DataLen returns the size of the response without the header.
+func (r *Response) DataLen() uint32 {
+ return r.hdr.Len - uint32(r.hdr.SizeBytes())
+}
+
+// UnmarshalPayload unmarshals the response data into m.
+func (r *Response) UnmarshalPayload(m marshal.Marshallable) error {
+ hdrLen := r.hdr.SizeBytes()
+ haveDataLen := r.hdr.Len - uint32(hdrLen)
+ wantDataLen := uint32(m.SizeBytes())
+
+ if haveDataLen < wantDataLen {
+ return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen)
+ }
+
+ // The response data is empty unless there is some payload. And so, doesn't
+ // need to be unmarshalled.
+ if r.data == nil {
+ return nil
+ }
+
+ // TODO(gVisor.dev/issue/3698): Use the unsafe version once go_marshal is safe to use again.
+ m.UnmarshalBytes(r.data[hdrLen:])
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go
new file mode 100644
index 000000000..e1d9e3365
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/utils_test.go
@@ -0,0 +1,132 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "io"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func setup(t *testing.T) *testutil.System {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("Error creating kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+
+ k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ AllowUserMount: true,
+ })
+
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{})
+ if err != nil {
+ t.Fatalf("NewMountNamespace(): %v", err)
+ }
+
+ return testutil.NewSystem(ctx, t, k.VFS(), mntns)
+}
+
+// newTestConnection creates a fuse connection that the sentry can communicate with
+// and the FD for the server to communicate with.
+func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) {
+ vfsObj := &vfs.VirtualFilesystem{}
+ fuseDev := &DeviceFD{}
+
+ if err := vfsObj.Init(system.Ctx); err != nil {
+ return nil, nil, err
+ }
+
+ vd := vfsObj.NewAnonVirtualDentry("genCountFD")
+ defer vd.DecRef(system.Ctx)
+ if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, nil, err
+ }
+
+ fsopts := filesystemOptions{
+ maxActiveRequests: maxActiveRequests,
+ }
+ fs, err := newFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return fs.conn, &fuseDev.vfsfd, nil
+}
+
+type testPayload struct {
+ marshal.StubMarshallable
+ data uint32
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (t *testPayload) SizeBytes() int {
+ return 4
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (t *testPayload) MarshalBytes(dst []byte) {
+ usermem.ByteOrder.PutUint32(dst[:4], t.data)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (t *testPayload) UnmarshalBytes(src []byte) {
+ *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])}
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (t *testPayload) Packed() bool {
+ return true
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (t *testPayload) MarshalUnsafe(dst []byte) {
+ t.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (t *testPayload) UnmarshalUnsafe(src []byte) {
+ t.UnmarshalBytes(src)
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (t *testPayload) CopyOutN(task marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
+ panic("not implemented")
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (t *testPayload) CopyOut(task marshal.CopyContext, addr usermem.Addr) (int, error) {
+ panic("not implemented")
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (t *testPayload) CopyIn(task marshal.CopyContext, addr usermem.Addr) (int, error) {
+ panic("not implemented")
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (t *testPayload) WriteTo(w io.Writer) (int64, error) {
+ panic("not implemented")
+}
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 91d2ae199..18c884b59 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -117,11 +117,12 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
d.syntheticChildren++
}
+// +stateify savable
type directoryFD struct {
fileDescription
vfs.DirectoryFileDescriptionDefaultImpl
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
off int64
dirents []vfs.Dirent
}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 5d0f487db..94d96261b 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -1026,7 +1026,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
// step is required even if !d.cachedMetadataAuthoritative() because
// d.mappings has to be updated.
// d.metadataMu has already been acquired if trunc == true.
- d.updateFileSizeLocked(0)
+ d.updateSizeLocked(0)
if d.cachedMetadataAuthoritative() {
d.touchCMtimeLocked()
@@ -1311,6 +1311,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if !renamed.isDir() {
return syserror.EISDIR
}
+ if genericIsAncestorDentry(replaced, renamed) {
+ return syserror.ENOTEMPTY
+ }
} else {
if rp.MustBeDir() || renamed.isDir() {
return syserror.ENOTDIR
@@ -1361,14 +1364,15 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// with reference counts and queue oldParent for checkCachingLocked if the
// parent isn't actually changing.
if oldParent != newParent {
+ oldParent.decRefLocked()
ds = appendDentry(ds, oldParent)
newParent.IncRef()
if renamed.isSynthetic() {
oldParent.syntheticChildren--
newParent.syntheticChildren++
}
+ renamed.parent = newParent
}
- renamed.parent = newParent
renamed.name = newName
if newParent.children == nil {
newParent.children = make(map[string]*dentry)
@@ -1412,11 +1416,11 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
return err
}
- if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil {
- fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ err = d.setStat(ctx, rp.Credentials(), &opts, rp.Mount())
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ if err != nil {
return err
}
- fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
@@ -1491,7 +1495,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return fs.unlinkAt(ctx, rp, false /* dir */)
}
-// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
var ds *[]*dentry
fs.renameMu.RLock()
@@ -1519,8 +1523,8 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
return nil, syserror.ECONNREFUSED
}
-// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt.
+func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
@@ -1528,11 +1532,11 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
if err != nil {
return nil, err
}
- return d.listxattr(ctx, rp.Credentials(), size)
+ return d.listXattr(ctx, rp.Credentials(), size)
}
-// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
+func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
@@ -1540,11 +1544,11 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
if err != nil {
return "", err
}
- return d.getxattr(ctx, rp.Credentials(), &opts)
+ return d.getXattr(ctx, rp.Credentials(), &opts)
}
-// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
-func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt.
+func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error {
var ds *[]*dentry
fs.renameMu.RLock()
d, err := fs.resolveLocked(ctx, rp, &ds)
@@ -1552,18 +1556,18 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
return err
}
- if err := d.setxattr(ctx, rp.Credentials(), &opts); err != nil {
- fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ err = d.setXattr(ctx, rp.Credentials(), &opts)
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ if err != nil {
return err
}
- fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
-// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
-func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt.
+func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
var ds *[]*dentry
fs.renameMu.RLock()
d, err := fs.resolveLocked(ctx, rp, &ds)
@@ -1571,11 +1575,11 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath,
fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
return err
}
- if err := d.removexattr(ctx, rp.Credentials(), name); err != nil {
- fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ err = d.removeXattr(ctx, rp.Credentials(), name)
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ if err != nil {
return err
}
- fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 57bff1789..8608471f8 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -62,9 +62,13 @@ import (
const Name = "9p"
// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
type FilesystemType struct{}
// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
type filesystem struct {
vfsfs vfs.Filesystem
@@ -77,7 +81,7 @@ type filesystem struct {
iopts InternalFilesystemOptions
// client is the client used by this filesystem. client is immutable.
- client *p9.Client
+ client *p9.Client `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
// clock is a realtime clock used to set timestamps in file operations.
clock ktime.Clock
@@ -95,7 +99,7 @@ type filesystem struct {
// reference count (such that it is usable as vfs.ResolvingPath.Start() or
// is reachable from its children), or if it is a child dentry (such that
// it is reachable from its parent).
- renameMu sync.RWMutex
+ renameMu sync.RWMutex `state:"nosave"`
// cachedDentries contains all dentries with 0 references. (Due to race
// conditions, it may also contain dentries with non-zero references.)
@@ -107,7 +111,7 @@ type filesystem struct {
// syncableDentries contains all dentries in this filesystem for which
// !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs.
// These fields are protected by syncMu.
- syncMu sync.Mutex
+ syncMu sync.Mutex `state:"nosave"`
syncableDentries map[*dentry]struct{}
specialFileFDs map[*specialFileFD]struct{}
@@ -120,6 +124,8 @@ type filesystem struct {
// dentries, it comes from QID.Path from the 9P server. Synthetic dentries
// have have their inodeNumber generated sequentially, with the MSB reserved to
// prevent conflicts with regular dentries.
+//
+// +stateify savable
type inodeNumber uint64
// Reserve MSB for synthetic mounts.
@@ -132,6 +138,7 @@ func inoFromPath(path uint64) inodeNumber {
return inodeNumber(path &^ syntheticInoMask)
}
+// +stateify savable
type filesystemOptions struct {
// "Standard" 9P options.
fd int
@@ -177,6 +184,8 @@ type filesystemOptions struct {
// InteropMode controls the client's interaction with other remote filesystem
// users.
+//
+// +stateify savable
type InteropMode uint32
const (
@@ -195,11 +204,7 @@ const (
// and consistent with Linux's semantics (in particular, it is not always
// possible for clients to set arbitrary atimes and mtimes depending on the
// remote filesystem implementation, and never possible for clients to set
- // arbitrary ctimes.) If a dentry containing a client-defined atime or
- // mtime is evicted from cache, client timestamps will be sent to the
- // remote filesystem on a best-effort basis to attempt to ensure that
- // timestamps will be preserved when another dentry representing the same
- // file is instantiated.
+ // arbitrary ctimes.)
InteropModeExclusive InteropMode = iota
// InteropModeWritethrough is appropriate when there are read-only users of
@@ -239,6 +244,8 @@ const (
// InternalFilesystemOptions may be passed as
// vfs.GetFilesystemOptions.InternalData to FilesystemType.GetFilesystem.
+//
+// +stateify savable
type InternalFilesystemOptions struct {
// If LeakConnection is true, do not close the connection to the server
// when the Filesystem is released. This is necessary for deployments in
@@ -538,6 +545,8 @@ func (fs *filesystem) Release(ctx context.Context) {
}
// dentry implements vfs.DentryImpl.
+//
+// +stateify savable
type dentry struct {
vfsd vfs.Dentry
@@ -567,7 +576,7 @@ type dentry struct {
// If file.isNil(), this dentry represents a synthetic file, i.e. a file
// that does not exist on the remote filesystem. As of this writing, the
// only files that can be synthetic are sockets, pipes, and directories.
- file p9file
+ file p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
// If deleted is non-zero, the file represented by this dentry has been
// deleted. deleted is accessed using atomic memory operations.
@@ -579,7 +588,7 @@ type dentry struct {
cached bool
dentryEntry
- dirMu sync.Mutex
+ dirMu sync.Mutex `state:"nosave"`
// If this dentry represents a directory, children contains:
//
@@ -611,7 +620,7 @@ type dentry struct {
// To mutate:
// - Lock metadataMu and use atomic operations to update because we might
// have atomic readers that don't hold the lock.
- metadataMu sync.Mutex
+ metadataMu sync.Mutex `state:"nosave"`
ino inodeNumber // immutable
mode uint32 // type is immutable, perms are mutable
uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
@@ -642,7 +651,7 @@ type dentry struct {
// other metadata fields.
nlink uint32
- mapsMu sync.Mutex
+ mapsMu sync.Mutex `state:"nosave"`
// If this dentry represents a regular file, mappings tracks mappings of
// the file into memmap.MappingSpaces. mappings is protected by mapsMu.
@@ -666,12 +675,12 @@ type dentry struct {
// either p9.File transitions from closed (isNil() == true) to open
// (isNil() == false), it may be mutated with handleMu locked, but cannot
// be closed until the dentry is destroyed.
- handleMu sync.RWMutex
- readFile p9file
- writeFile p9file
+ handleMu sync.RWMutex `state:"nosave"`
+ readFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ writeFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
hostFD int32
- dataMu sync.RWMutex
+ dataMu sync.RWMutex `state:"nosave"`
// If this dentry represents a regular file that is client-cached, cache
// maps offsets into the cached file to offsets into
@@ -837,7 +846,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
atomic.StoreUint32(&d.nlink, uint32(attr.NLink))
}
if mask.Size {
- d.updateFileSizeLocked(attr.Size)
+ d.updateSizeLocked(attr.Size)
}
}
@@ -991,7 +1000,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// d.size should be kept up to date, and privatized
// copy-on-write mappings of truncated pages need to be
// invalidated, even if InteropModeShared is in effect.
- d.updateFileSizeLocked(stat.Size)
+ d.updateSizeLocked(stat.Size)
}
}
if d.fs.opts.interop == InteropModeShared {
@@ -1028,8 +1037,31 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
return nil
}
+// doAllocate performs an allocate operation on d. Note that d.metadataMu will
+// be held when allocate is called.
+func (d *dentry) doAllocate(ctx context.Context, offset, length uint64, allocate func() error) error {
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
+ // Allocating a smaller size is a noop.
+ size := offset + length
+ if d.cachedMetadataAuthoritative() && size <= d.size {
+ return nil
+ }
+
+ err := allocate()
+ if err != nil {
+ return err
+ }
+ d.updateSizeLocked(size)
+ if d.cachedMetadataAuthoritative() {
+ d.touchCMtimeLocked()
+ }
+ return nil
+}
+
// Preconditions: d.metadataMu must be locked.
-func (d *dentry) updateFileSizeLocked(newSize uint64) {
+func (d *dentry) updateSizeLocked(newSize uint64) {
d.dataMu.Lock()
oldSize := d.size
atomic.StoreUint64(&d.size, newSize)
@@ -1067,6 +1099,21 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes)
return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
}
+func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error {
+ // We only support xattrs prefixed with "user." (see b/148380782). Currently,
+ // there is no need to expose any other xattrs through a gofer.
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ kuid := auth.KUID(atomic.LoadUint32(&d.uid))
+ kgid := auth.KGID(atomic.LoadUint32(&d.gid))
+ if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil {
+ return err
+ }
+ return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name)
+}
+
func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error {
return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&child.uid)))
}
@@ -1300,30 +1347,19 @@ func (d *dentry) destroyLocked(ctx context.Context) {
d.handleMu.Unlock()
if !d.file.isNil() {
- if !d.isDeleted() {
- // Write dirty timestamps back to the remote filesystem.
- atimeDirty := atomic.LoadUint32(&d.atimeDirty) != 0
- mtimeDirty := atomic.LoadUint32(&d.mtimeDirty) != 0
- if atimeDirty || mtimeDirty {
- atime := atomic.LoadInt64(&d.atime)
- mtime := atomic.LoadInt64(&d.mtime)
- if err := d.file.setAttr(ctx, p9.SetAttrMask{
- ATime: atimeDirty,
- ATimeNotSystemTime: atimeDirty,
- MTime: mtimeDirty,
- MTimeNotSystemTime: mtimeDirty,
- }, p9.SetAttr{
- ATimeSeconds: uint64(atime / 1e9),
- ATimeNanoSeconds: uint64(atime % 1e9),
- MTimeSeconds: uint64(mtime / 1e9),
- MTimeNanoSeconds: uint64(mtime % 1e9),
- }); err != nil {
- log.Warningf("gofer.dentry.destroyLocked: failed to write dirty timestamps back: %v", err)
- }
- }
+ // Note that it's possible that d.atimeDirty or d.mtimeDirty are true,
+ // i.e. client and server timestamps may differ (because e.g. a client
+ // write was serviced by the page cache, and only written back to the
+ // remote file later). Ideally, we'd write client timestamps back to
+ // the remote filesystem so that timestamps for a new dentry
+ // instantiated for the same file would remain coherent. Unfortunately,
+ // this turns out to be too expensive in many cases, so for now we
+ // don't do this.
+ if err := d.file.close(ctx); err != nil {
+ log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err)
}
- d.file.close(ctx)
d.file = p9file{}
+
// Remove d from the set of syncable dentries.
d.fs.syncMu.Lock()
delete(d.fs.syncableDentries, d)
@@ -1351,9 +1387,7 @@ func (d *dentry) setDeleted() {
atomic.StoreUint32(&d.deleted, 1)
}
-// We only support xattrs prefixed with "user." (see b/148380782). Currently,
-// there is no need to expose any other xattrs through a gofer.
-func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
+func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
if d.file.isNil() || !d.userXattrSupported() {
return nil, nil
}
@@ -1363,6 +1397,7 @@ func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size ui
}
xattrs := make([]string, 0, len(xattrMap))
for x := range xattrMap {
+ // We only support xattrs in the user.* namespace.
if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) {
xattrs = append(xattrs, x)
}
@@ -1370,51 +1405,33 @@ func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size ui
return xattrs, nil
}
-func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) {
+func (d *dentry) getXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) {
if d.file.isNil() {
return "", syserror.ENODATA
}
- if err := d.checkPermissions(creds, vfs.MayRead); err != nil {
+ if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil {
return "", err
}
- if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
- return "", syserror.EOPNOTSUPP
- }
- if !d.userXattrSupported() {
- return "", syserror.ENODATA
- }
return d.file.getXattr(ctx, opts.Name, opts.Size)
}
-func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetxattrOptions) error {
+func (d *dentry) setXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetXattrOptions) error {
if d.file.isNil() {
return syserror.EPERM
}
- if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
+ if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil {
return err
}
- if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
- return syserror.EOPNOTSUPP
- }
- if !d.userXattrSupported() {
- return syserror.EPERM
- }
return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags)
}
-func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name string) error {
+func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name string) error {
if d.file.isNil() {
return syserror.EPERM
}
- if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
+ if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil {
return err
}
- if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
- return syserror.EOPNOTSUPP
- }
- if !d.userXattrSupported() {
- return syserror.EPERM
- }
return d.file.removeXattr(ctx, name)
}
@@ -1623,12 +1640,14 @@ func (d *dentry) decLinks() {
// fileDescription is embedded by gofer implementations of
// vfs.FileDescriptionImpl.
+//
+// +stateify savable
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.LockFD
- lockLogging sync.Once
+ lockLogging sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
}
func (fd *fileDescription) filesystem() *filesystem {
@@ -1666,30 +1685,30 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return nil
}
-// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
-func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
- return fd.dentry().listxattr(ctx, auth.CredentialsFromContext(ctx), size)
+// ListXattr implements vfs.FileDescriptionImpl.ListXattr.
+func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) {
+ return fd.dentry().listXattr(ctx, auth.CredentialsFromContext(ctx), size)
}
-// Getxattr implements vfs.FileDescriptionImpl.Getxattr.
-func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) {
- return fd.dentry().getxattr(ctx, auth.CredentialsFromContext(ctx), &opts)
+// GetXattr implements vfs.FileDescriptionImpl.GetXattr.
+func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) {
+ return fd.dentry().getXattr(ctx, auth.CredentialsFromContext(ctx), &opts)
}
-// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
-func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
+// SetXattr implements vfs.FileDescriptionImpl.SetXattr.
+func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error {
d := fd.dentry()
- if err := d.setxattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil {
+ if err := d.setXattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil {
return err
}
d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
-// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
-func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
+// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr.
+func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error {
d := fd.dentry()
- if err := d.removexattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil {
+ if err := d.removeXattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil {
return err
}
d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
index 104157512..a9ebe1206 100644
--- a/pkg/sentry/fsimpl/gofer/handle.go
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -25,6 +25,8 @@ import (
// handle represents a remote "open file descriptor", consisting of an opened
// fid (p9.File) and optionally a host file descriptor.
+//
+// These are explicitly not savable.
type handle struct {
file p9file
fd int32 // -1 if unavailable
diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go
index 87f0b877f..21b4a96fe 100644
--- a/pkg/sentry/fsimpl/gofer/p9file.go
+++ b/pkg/sentry/fsimpl/gofer/p9file.go
@@ -127,6 +127,13 @@ func (f p9file) close(ctx context.Context) error {
return err
}
+func (f p9file) setAttrClose(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAttr) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.SetAttrClose(valid, attr)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) {
ctx.UninterruptibleSleepStart(false)
fdobj, qid, iounit, err := f.file.Open(flags)
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index a2e9342d5..eeaf6e444 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -39,11 +39,12 @@ func (d *dentry) isRegularFile() bool {
return d.fileType() == linux.S_IFREG
}
+// +stateify savable
type regularFileFD struct {
fileDescription
// off is the file offset. off is protected by mu.
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
off int64
}
@@ -79,28 +80,11 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error {
// Allocate implements vfs.FileDescriptionImpl.Allocate.
func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
d := fd.dentry()
- d.metadataMu.Lock()
- defer d.metadataMu.Unlock()
-
- // Allocating a smaller size is a noop.
- size := offset + length
- if d.cachedMetadataAuthoritative() && size <= d.size {
- return nil
- }
-
- d.handleMu.RLock()
- err := d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
- d.handleMu.RUnlock()
- if err != nil {
- return err
- }
- d.dataMu.Lock()
- atomic.StoreUint64(&d.size, size)
- d.dataMu.Unlock()
- if d.cachedMetadataAuthoritative() {
- d.touchCMtimeLocked()
- }
- return nil
+ return d.doAllocate(ctx, offset, length, func() error {
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ return d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
+ })
}
// PRead implements vfs.FileDescriptionImpl.PRead.
@@ -915,6 +899,8 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) {
// dentryPlatformFile is only used when a host FD representing the remote file
// is available (i.e. dentry.hostFD >= 0), and that FD is used for application
// memory mappings (i.e. !filesystem.opts.forcePageCache).
+//
+// +stateify savable
type dentryPlatformFile struct {
*dentry
@@ -927,7 +913,7 @@ type dentryPlatformFile struct {
hostFileMapper fsutil.HostFileMapper
// hostFileMapperInitOnce is used to lazily initialize hostFileMapper.
- hostFileMapperInitOnce sync.Once
+ hostFileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
}
// IncRef implements memmap.File.IncRef.
diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go
index 85d2bee72..326b940a7 100644
--- a/pkg/sentry/fsimpl/gofer/socket.go
+++ b/pkg/sentry/fsimpl/gofer/socket.go
@@ -36,12 +36,14 @@ func (d *dentry) isSocket() bool {
// An endpoint's lifetime is the time between when filesystem.BoundEndpointAt()
// is called and either BoundEndpoint.BidirectionalConnect or
// BoundEndpoint.UnidirectionalConnect is called.
+//
+// +stateify savable
type endpoint struct {
// dentry is the filesystem dentry which produced this endpoint.
dentry *dentry
// file is the p9 file that contains a single unopened fid.
- file p9.File
+ file p9.File `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
// path is the sentry path where this endpoint is bound.
path string
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 3c39aa9b7..71581736c 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -33,11 +34,13 @@ import (
// special files, and (when filesystemOptions.regularFilesUseSpecialFileFD is
// in effect) regular files. specialFileFD differs from regularFileFD by using
// per-FD handles instead of shared per-dentry handles, and never buffering I/O.
+//
+// +stateify savable
type specialFileFD struct {
fileDescription
// handle is used for file I/O. handle is immutable.
- handle handle
+ handle handle `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
// isRegularFile is true if this FD represents a regular file which is only
// possible when filesystemOptions.regularFilesUseSpecialFileFD is in
@@ -55,7 +58,7 @@ type specialFileFD struct {
queue waiter.Queue
// If seekable is true, off is the file offset. off is protected by mu.
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
off int64
}
@@ -135,6 +138,16 @@ func (fd *specialFileFD) EventUnregister(e *waiter.Entry) {
fd.fileDescription.EventUnregister(e)
}
+func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ if fd.isRegularFile {
+ d := fd.dentry()
+ return d.doAllocate(ctx, offset, length, func() error {
+ return fd.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
+ })
+ }
+ return fd.FileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length)
+}
+
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
if fd.seekable && offset < 0 {
@@ -235,11 +248,12 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
d.touchCMtime()
}
buf := make([]byte, src.NumBytes())
- // Don't do partial writes if we get a partial read from src.
- if _, err := src.CopyIn(ctx, buf); err != nil {
- return 0, offset, err
+ copied, copyErr := src.CopyIn(ctx, buf)
+ if copied == 0 && copyErr != nil {
+ // Only return the error if we didn't get any data.
+ return 0, offset, copyErr
}
- n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
+ n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:copied])), uint64(offset))
if err == syserror.EAGAIN {
err = syserror.ErrWouldBlock
}
@@ -256,7 +270,10 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
atomic.StoreUint64(&d.size, uint64(offset))
}
}
- return int64(n), offset, err
+ if err != nil {
+ return int64(n), offset, err
+ }
+ return int64(n), offset, copyErr
}
// Write implements vfs.FileDescriptionImpl.Write.
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index be1c88c82..56bcf9bdb 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -49,6 +49,7 @@ go_library(
"//pkg/fspath",
"//pkg/iovec",
"//pkg/log",
+ "//pkg/marshal/primitive",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
@@ -72,7 +73,6 @@ go_library(
"//pkg/unet",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 7561f821c..ffe4ddb32 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -58,7 +58,7 @@ func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (
canMap: fileType == linux.S_IFREG,
}
i.pf.inode = i
- i.refs.EnableLeakCheck()
+ i.EnableLeakCheck()
// Non-seekable files can't be memory mapped, assert this.
if !i.seekable && i.canMap {
@@ -126,7 +126,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
// For simplicity, fileDescription.offset is set to 0. Technically, we
// should only set to 0 on files that are not seekable (sockets, pipes,
// etc.), and use the offset from the host fd otherwise when importing.
- return i.open(ctx, d.VFSDentry(), mnt, flags)
+ return i.open(ctx, d, mnt, flags)
}
// ImportFD sets up and returns a vfs.FileDescription from a donated fd.
@@ -137,14 +137,16 @@ func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs
}
// filesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
type filesystemType struct{}
-// GetFilesystem implements FilesystemType.GetFilesystem.
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (filesystemType) GetFilesystem(context.Context, *vfs.VirtualFilesystem, *auth.Credentials, string, vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
panic("host.filesystemType.GetFilesystem should never be called")
}
-// Name implements FilesystemType.Name.
+// Name implements vfs.FilesystemType.Name.
func (filesystemType) Name() string {
return "none"
}
@@ -166,6 +168,8 @@ func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) {
}
// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
type filesystem struct {
kernfs.Filesystem
@@ -185,6 +189,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
}
// inode implements kernfs.Inode.
+//
+// +stateify savable
type inode struct {
kernfs.InodeNoStatFS
kernfs.InodeNotDirectory
@@ -193,7 +199,7 @@ type inode struct {
locks vfs.FileLocks
// When the reference count reaches zero, the host fd is closed.
- refs inodeRefs
+ inodeRefs
// hostFD contains the host fd that this file was originally created from,
// which must be available at time of restore.
@@ -233,7 +239,7 @@ type inode struct {
canMap bool
// mapsMu protects mappings.
- mapsMu sync.Mutex
+ mapsMu sync.Mutex `state:"nosave"`
// If canMap is true, mappings tracks mappings of hostFD into
// memmap.MappingSpaces.
@@ -243,7 +249,7 @@ type inode struct {
pf inodePlatformFile
}
-// CheckPermissions implements kernfs.Inode.
+// CheckPermissions implements kernfs.Inode.CheckPermissions.
func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
var s syscall.Stat_t
if err := syscall.Fstat(i.hostFD, &s); err != nil {
@@ -252,7 +258,7 @@ func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, a
return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(s.Mode), auth.KUID(s.Uid), auth.KGID(s.Gid))
}
-// Mode implements kernfs.Inode.
+// Mode implements kernfs.Inode.Mode.
func (i *inode) Mode() linux.FileMode {
var s syscall.Stat_t
if err := syscall.Fstat(i.hostFD, &s); err != nil {
@@ -263,7 +269,7 @@ func (i *inode) Mode() linux.FileMode {
return linux.FileMode(s.Mode)
}
-// Stat implements kernfs.Inode.
+// Stat implements kernfs.Inode.Stat.
func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
if opts.Mask&linux.STATX__RESERVED != 0 {
return linux.Statx{}, syserror.EINVAL
@@ -376,7 +382,7 @@ func (i *inode) fstat(fs *filesystem) (linux.Statx, error) {
}, nil
}
-// SetStat implements kernfs.Inode.
+// SetStat implements kernfs.Inode.SetStat.
func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
s := &opts.Stat
@@ -435,19 +441,9 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
return nil
}
-// IncRef implements kernfs.Inode.
-func (i *inode) IncRef() {
- i.refs.IncRef()
-}
-
-// TryIncRef implements kernfs.Inode.
-func (i *inode) TryIncRef() bool {
- return i.refs.TryIncRef()
-}
-
-// DecRef implements kernfs.Inode.
+// DecRef implements kernfs.Inode.DecRef.
func (i *inode) DecRef(ctx context.Context) {
- i.refs.DecRef(func() {
+ i.inodeRefs.DecRef(func() {
if i.wouldBlock {
fdnotifier.RemoveFD(int32(i.hostFD))
}
@@ -457,16 +453,16 @@ func (i *inode) DecRef(ctx context.Context) {
})
}
-// Open implements kernfs.Inode.
-func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+// Open implements kernfs.Inode.Open.
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
// Once created, we cannot re-open a socket fd through /proc/[pid]/fd/.
if i.Mode().FileType() == linux.S_IFSOCK {
return nil, syserror.ENXIO
}
- return i.open(ctx, vfsd, rp.Mount(), opts.Flags)
+ return i.open(ctx, d, rp.Mount(), opts.Flags)
}
-func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, error) {
+func (i *inode) open(ctx context.Context, d *kernfs.Dentry, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, error) {
var s syscall.Stat_t
if err := syscall.Fstat(i.hostFD, &s); err != nil {
return nil, err
@@ -490,17 +486,17 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u
return nil, err
}
// Currently, we only allow Unix sockets to be imported.
- return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d, &i.locks)
+ return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d.VFSDentry(), &i.locks)
case syscall.S_IFREG, syscall.S_IFIFO, syscall.S_IFCHR:
if i.isTTY {
fd := &TTYFileDescription{
fileDescription: fileDescription{inode: i},
- termios: linux.DefaultSlaveTermios,
+ termios: linux.DefaultReplicaTermios,
}
fd.LockFD.Init(&i.locks)
vfsfd := &fd.vfsfd
- if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return vfsfd, nil
@@ -509,7 +505,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u
fd := &fileDescription{inode: i}
fd.LockFD.Init(&i.locks)
vfsfd := &fd.vfsfd
- if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return vfsfd, nil
@@ -521,6 +517,8 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u
}
// fileDescription is embedded by host fd implementations of FileDescriptionImpl.
+//
+// +stateify savable
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -535,40 +533,35 @@ type fileDescription struct {
inode *inode
// offsetMu protects offset.
- offsetMu sync.Mutex
+ offsetMu sync.Mutex `state:"nosave"`
// offset specifies the current file offset. It is only meaningful when
// inode.seekable is true.
offset int64
}
-// SetStat implements vfs.FileDescriptionImpl.
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
creds := auth.CredentialsFromContext(ctx)
return f.inode.SetStat(ctx, f.vfsfd.Mount().Filesystem(), creds, opts)
}
-// Stat implements vfs.FileDescriptionImpl.
+// Stat implements vfs.FileDescriptionImpl.Stat.
func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
return f.inode.Stat(ctx, f.vfsfd.Mount().Filesystem(), opts)
}
-// Release implements vfs.FileDescriptionImpl.
+// Release implements vfs.FileDescriptionImpl.Release.
func (f *fileDescription) Release(context.Context) {
// noop
}
-// Allocate implements vfs.FileDescriptionImpl.
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error {
- if !f.inode.seekable {
- return syserror.ESPIPE
- }
-
- // TODO(gvisor.dev/issue/3589): Implement Allocate for non-pipe hostfds.
- return syserror.EOPNOTSUPP
+ return unix.Fallocate(f.inode.hostFD, uint32(mode), int64(offset), int64(length))
}
-// PRead implements FileDescriptionImpl.
+// PRead implements vfs.FileDescriptionImpl.PRead.
func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
i := f.inode
if !i.seekable {
@@ -578,7 +571,7 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off
return readFromHostFD(ctx, i.hostFD, dst, offset, opts.Flags)
}
-// Read implements FileDescriptionImpl.
+// Read implements vfs.FileDescriptionImpl.Read.
func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
i := f.inode
if !i.seekable {
@@ -615,7 +608,7 @@ func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, off
return int64(n), err
}
-// PWrite implements FileDescriptionImpl.
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
if !f.inode.seekable {
return 0, syserror.ESPIPE
@@ -624,7 +617,7 @@ func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, of
return f.writeToHostFD(ctx, src, offset, opts.Flags)
}
-// Write implements FileDescriptionImpl.
+// Write implements vfs.FileDescriptionImpl.Write.
func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
i := f.inode
if !i.seekable {
@@ -672,7 +665,7 @@ func (f *fileDescription) writeToHostFD(ctx context.Context, src usermem.IOSeque
return int64(n), err
}
-// Seek implements FileDescriptionImpl.
+// Seek implements vfs.FileDescriptionImpl.Seek.
//
// Note that we do not support seeking on directories, since we do not even
// allow directory fds to be imported at all.
@@ -737,13 +730,13 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i
return f.offset, nil
}
-// Sync implements FileDescriptionImpl.
+// Sync implements vfs.FileDescriptionImpl.Sync.
func (f *fileDescription) Sync(context.Context) error {
// TODO(gvisor.dev/issue/1897): Currently, we always sync everything.
return unix.Fsync(f.inode.hostFD)
}
-// ConfigureMMap implements FileDescriptionImpl.
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error {
if !f.inode.canMap {
return syserror.ENODEV
diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go
index 65d3af38c..b51a17bed 100644
--- a/pkg/sentry/fsimpl/host/mmap.go
+++ b/pkg/sentry/fsimpl/host/mmap.go
@@ -27,11 +27,13 @@ import (
// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef.
//
// inodePlatformFile should only be used if inode.canMap is true.
+//
+// +stateify savable
type inodePlatformFile struct {
*inode
// fdRefsMu protects fdRefs.
- fdRefsMu sync.Mutex
+ fdRefsMu sync.Mutex `state:"nosave"`
// fdRefs counts references on memmap.File offsets. It is used solely for
// memory accounting.
@@ -41,7 +43,7 @@ type inodePlatformFile struct {
fileMapper fsutil.HostFileMapper
// fileMapperInitOnce is used to lazily initialize fileMapper.
- fileMapperInitOnce sync.Once
+ fileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
}
// IncRef implements memmap.File.IncRef.
diff --git a/pkg/sentry/fsimpl/host/socket_unsafe.go b/pkg/sentry/fsimpl/host/socket_unsafe.go
index 35ded24bc..c0bf45f08 100644
--- a/pkg/sentry/fsimpl/host/socket_unsafe.go
+++ b/pkg/sentry/fsimpl/host/socket_unsafe.go
@@ -63,10 +63,10 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (
controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC
if n > length {
- return length, n, msg.Controllen, controlTrunc, err
+ return length, n, msg.Controllen, controlTrunc, nil
}
- return n, n, msg.Controllen, controlTrunc, err
+ return n, n, msg.Controllen, controlTrunc, nil
}
// fdWriteVec sends from bufs to fd.
diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go
index 7a9be4b97..f5c596fec 100644
--- a/pkg/sentry/fsimpl/host/tty.go
+++ b/pkg/sentry/fsimpl/host/tty.go
@@ -17,6 +17,7 @@ package host
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -25,11 +26,12 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// TTYFileDescription implements vfs.FileDescriptionImpl for a host file
// descriptor that wraps a TTY FD.
+//
+// +stateify savable
type TTYFileDescription struct {
fileDescription
@@ -76,7 +78,7 @@ func (t *TTYFileDescription) Release(ctx context.Context) {
t.fileDescription.Release(ctx)
}
-// PRead implements vfs.FileDescriptionImpl.
+// PRead implements vfs.FileDescriptionImpl.PRead.
//
// Reading from a TTY is only allowed for foreground process groups. Background
// process groups will either get EIO or a SIGTTIN.
@@ -94,7 +96,7 @@ func (t *TTYFileDescription) PRead(ctx context.Context, dst usermem.IOSequence,
return t.fileDescription.PRead(ctx, dst, offset, opts)
}
-// Read implements vfs.FileDescriptionImpl.
+// Read implements vfs.FileDescriptionImpl.Read.
//
// Reading from a TTY is only allowed for foreground process groups. Background
// process groups will either get EIO or a SIGTTIN.
@@ -112,7 +114,7 @@ func (t *TTYFileDescription) Read(ctx context.Context, dst usermem.IOSequence, o
return t.fileDescription.Read(ctx, dst, opts)
}
-// PWrite implements vfs.FileDescriptionImpl.
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (t *TTYFileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
t.mu.Lock()
defer t.mu.Unlock()
@@ -127,7 +129,7 @@ func (t *TTYFileDescription) PWrite(ctx context.Context, src usermem.IOSequence,
return t.fileDescription.PWrite(ctx, src, offset, opts)
}
-// Write implements vfs.FileDescriptionImpl.
+// Write implements vfs.FileDescriptionImpl.Write.
func (t *TTYFileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
t.mu.Lock()
defer t.mu.Unlock()
@@ -142,7 +144,7 @@ func (t *TTYFileDescription) Write(ctx context.Context, src usermem.IOSequence,
return t.fileDescription.Write(ctx, src, opts)
}
-// Ioctl implements vfs.FileDescriptionImpl.
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
task := kernel.TaskFromContext(ctx)
if task == nil {
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index 637dca70c..5e91e0536 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -83,6 +83,7 @@ go_library(
"slot_list.go",
"static_directory_refs.go",
"symlink.go",
+ "synthetic_directory.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 1ee089620..b929118b1 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -56,9 +56,9 @@ func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint
}
// Open implements Inode.Open.
-func (f *DynamicBytesFile) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+func (f *DynamicBytesFile) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &DynamicBytesFD{}
- if err := fd.Init(rp.Mount(), vfsd, f.data, &f.locks, opts.Flags); err != nil {
+ if err := fd.Init(rp.Mount(), d, f.data, &f.locks, opts.Flags); err != nil {
return nil, err
}
return &fd.vfsfd, nil
@@ -87,12 +87,12 @@ type DynamicBytesFD struct {
}
// Init initializes a DynamicBytesFD.
-func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *vfs.FileLocks, flags uint32) error {
+func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *Dentry, data vfs.DynamicBytesSource, locks *vfs.FileLocks, flags uint32) error {
fd.LockFD.Init(locks)
- if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(fd, flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
return err
}
- fd.inode = d.Impl().(*Dentry).inode
+ fd.inode = d.inode
fd.SetDataSource(data)
return nil
}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 6518ff5cd..0a4cd4057 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -29,6 +29,8 @@ import (
)
// SeekEndConfig describes the SEEK_END behaviour for FDs.
+//
+// +stateify savable
type SeekEndConfig int
// Constants related to SEEK_END behaviour for FDs.
@@ -41,6 +43,8 @@ const (
)
// GenericDirectoryFDOptions contains configuration for a GenericDirectoryFD.
+//
+// +stateify savable
type GenericDirectoryFDOptions struct {
SeekEnd SeekEndConfig
}
@@ -56,6 +60,8 @@ type GenericDirectoryFDOptions struct {
// Must be initialize with Init before first use.
//
// Lock ordering: mu => children.mu.
+//
+// +stateify savable
type GenericDirectoryFD struct {
vfs.FileDescriptionDefaultImpl
vfs.DirectoryFileDescriptionDefaultImpl
@@ -68,7 +74,7 @@ type GenericDirectoryFD struct {
children *OrderedChildren
// mu protects the fields below.
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
// off is the current directory offset. Protected by "mu".
off int64
@@ -76,12 +82,12 @@ type GenericDirectoryFD struct {
// NewGenericDirectoryFD creates a new GenericDirectoryFD and returns its
// dentry.
-func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions, fdOpts GenericDirectoryFDOptions) (*GenericDirectoryFD, error) {
+func NewGenericDirectoryFD(m *vfs.Mount, d *Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions, fdOpts GenericDirectoryFDOptions) (*GenericDirectoryFD, error) {
fd := &GenericDirectoryFD{}
if err := fd.Init(children, locks, opts, fdOpts); err != nil {
return nil, err
}
- if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(fd, opts.Flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return fd, nil
@@ -195,8 +201,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
// these.
childIdx := fd.off - 2
for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() {
- inode := it.Dentry.Impl().(*Dentry).inode
- stat, err := inode.Stat(ctx, fd.filesystem(), opts)
+ stat, err := it.Dentry.inode.Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 0e3011689..5cc1c4281 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -37,8 +37,7 @@ import (
// * !rp.Done().
//
// Postcondition: Caller must call fs.processDeferredDecRefs*.
-func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, mayFollowSymlinks bool) (*vfs.Dentry, error) {
- d := vfsd.Impl().(*Dentry)
+func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, mayFollowSymlinks bool) (*Dentry, error) {
if !d.isDir() {
return nil, syserror.ENOTDIR
}
@@ -55,20 +54,20 @@ afterSymlink:
// calls d_revalidate(), but walk_component() => handle_dots() does not.
if name == "." {
rp.Advance()
- return vfsd, nil
+ return d, nil
}
if name == ".." {
- if isRoot, err := rp.CheckRoot(ctx, vfsd); err != nil {
+ if isRoot, err := rp.CheckRoot(ctx, d.VFSDentry()); err != nil {
return nil, err
} else if isRoot || d.parent == nil {
rp.Advance()
- return vfsd, nil
+ return d, nil
}
- if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, d.parent.VFSDentry()); err != nil {
return nil, err
}
rp.Advance()
- return &d.parent.vfsd, nil
+ return d.parent, nil
}
if len(name) > linux.NAME_MAX {
return nil, syserror.ENAMETOOLONG
@@ -79,7 +78,7 @@ afterSymlink:
if err != nil {
return nil, err
}
- if err := rp.CheckMount(ctx, &next.vfsd); err != nil {
+ if err := rp.CheckMount(ctx, next.VFSDentry()); err != nil {
return nil, err
}
// Resolve any symlink at current path component.
@@ -102,7 +101,7 @@ afterSymlink:
goto afterSymlink
}
rp.Advance()
- return &next.vfsd, nil
+ return next, nil
}
// revalidateChildLocked must be called after a call to parent.vfsd.Child(name)
@@ -122,25 +121,21 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
if !child.inode.Valid(ctx) {
delete(parent.children, name)
vfsObj.InvalidateDentry(ctx, &child.vfsd)
- fs.deferDecRef(&child.vfsd) // Reference from Lookup.
+ fs.deferDecRef(child) // Reference from Lookup.
child = nil
}
}
if child == nil {
- // Dentry isn't cached; it either doesn't exist or failed
- // revalidation. Attempt to resolve it via Lookup.
- //
- // FIXME(gvisor.dev/issue/1193): Inode.Lookup() should return
- // *(kernfs.)Dentry, not *vfs.Dentry, since (kernfs.)Filesystem assumes
- // that all dentries in the filesystem are (kernfs.)Dentry and performs
- // vfs.DentryImpl casts accordingly.
- childVFSD, err := parent.inode.Lookup(ctx, name)
+ // Dentry isn't cached; it either doesn't exist or failed revalidation.
+ // Attempt to resolve it via Lookup.
+ c, err := parent.inode.Lookup(ctx, name)
if err != nil {
return nil, err
}
- // Reference on childVFSD dropped by a corresponding Valid.
- child = childVFSD.Impl().(*Dentry)
- parent.insertChildLocked(name, child)
+ // Reference on c (provided by Lookup) will be dropped when the dentry
+ // fails validation.
+ parent.InsertChildLocked(name, c)
+ child = c
}
return child, nil
}
@@ -153,20 +148,19 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
// Preconditions: Filesystem.mu must be locked for at least reading.
//
// Postconditions: Caller must call fs.processDeferredDecRefs*.
-func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) {
- vfsd := rp.Start()
+func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*Dentry, error) {
+ d := rp.Start().Impl().(*Dentry)
for !rp.Done() {
var err error
- vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */)
+ d, err = fs.stepExistingLocked(ctx, rp, d, true /* mayFollowSymlinks */)
if err != nil {
- return nil, nil, err
+ return nil, err
}
}
- d := vfsd.Impl().(*Dentry)
if rp.MustBeDir() && !d.isDir() {
- return nil, nil, syserror.ENOTDIR
+ return nil, syserror.ENOTDIR
}
- return vfsd, d.inode, nil
+ return d, nil
}
// walkParentDirLocked resolves all but the last path component of rp to an
@@ -181,20 +175,19 @@ func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingP
// * !rp.Done().
//
// Postconditions: Caller must call fs.processDeferredDecRefs*.
-func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) {
- vfsd := rp.Start()
+func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*Dentry, error) {
+ d := rp.Start().Impl().(*Dentry)
for !rp.Final() {
var err error
- vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */)
+ d, err = fs.stepExistingLocked(ctx, rp, d, true /* mayFollowSymlinks */)
if err != nil {
- return nil, nil, err
+ return nil, err
}
}
- d := vfsd.Impl().(*Dentry)
if !d.isDir() {
- return nil, nil, syserror.ENOTDIR
+ return nil, syserror.ENOTDIR
}
- return vfsd, d.inode, nil
+ return d, nil
}
// checkCreateLocked checks that a file named rp.Component() may be created in
@@ -202,10 +195,9 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
//
// Preconditions:
// * Filesystem.mu must be locked for at least reading.
-// * parentInode == parentVFSD.Impl().(*Dentry).Inode.
// * isDir(parentInode) == true.
-func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) {
- if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *Dentry) (string, error) {
+ if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return "", err
}
pc := rp.Component()
@@ -215,11 +207,10 @@ func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *v
if len(pc) > linux.NAME_MAX {
return "", syserror.ENAMETOOLONG
}
- // FIXME(gvisor.dev/issue/1193): Data race due to not holding dirMu.
- if _, ok := parentVFSD.Impl().(*Dentry).children[pc]; ok {
+ if _, ok := parent.children[pc]; ok {
return "", syserror.EEXIST
}
- if parentVFSD.IsDead() {
+ if parent.VFSDentry().IsDead() {
return "", syserror.ENOENT
}
return pc, nil
@@ -228,8 +219,8 @@ func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *v
// checkDeleteLocked checks that the file represented by vfsd may be deleted.
//
// Preconditions: Filesystem.mu must be locked for at least reading.
-func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error {
- parent := vfsd.Impl().(*Dentry).parent
+func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry) error {
+ parent := d.parent
if parent == nil {
return syserror.EBUSY
}
@@ -258,11 +249,11 @@ func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
defer fs.processDeferredDecRefs(ctx)
defer fs.mu.RUnlock()
- _, inode, err := fs.walkExistingLocked(ctx, rp)
+ d, err := fs.walkExistingLocked(ctx, rp)
if err != nil {
return err
}
- return inode.CheckPermissions(ctx, creds, ats)
+ return d.inode.CheckPermissions(ctx, creds, ats)
}
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
@@ -270,20 +261,20 @@ func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
fs.mu.RLock()
defer fs.processDeferredDecRefs(ctx)
defer fs.mu.RUnlock()
- vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+ d, err := fs.walkExistingLocked(ctx, rp)
if err != nil {
return nil, err
}
if opts.CheckSearchable {
- d := vfsd.Impl().(*Dentry)
if !d.isDir() {
return nil, syserror.ENOTDIR
}
- if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
+ if err := d.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
}
+ vfsd := d.VFSDentry()
vfsd.IncRef() // Ownership transferred to caller.
return vfsd, nil
}
@@ -293,12 +284,12 @@ func (fs *Filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
fs.mu.RLock()
defer fs.processDeferredDecRefs(ctx)
defer fs.mu.RUnlock()
- vfsd, _, err := fs.walkParentDirLocked(ctx, rp)
+ d, err := fs.walkParentDirLocked(ctx, rp)
if err != nil {
return nil, err
}
- vfsd.IncRef() // Ownership transferred to caller.
- return vfsd, nil
+ d.IncRef() // Ownership transferred to caller.
+ return d.VFSDentry(), nil
}
// LinkAt implements vfs.FilesystemImpl.LinkAt.
@@ -308,12 +299,15 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
}
fs.mu.Lock()
defer fs.mu.Unlock()
- parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ parent, err := fs.walkParentDirLocked(ctx, rp)
fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
- pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+ pc, err := checkCreateLocked(ctx, rp, parent)
if err != nil {
return err
}
@@ -330,11 +324,11 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return syserror.EPERM
}
- childVFSD, err := parentInode.NewLink(ctx, pc, d.inode)
+ child, err := parent.inode.NewLink(ctx, pc, d.inode)
if err != nil {
return err
}
- parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry))
+ parent.InsertChildLocked(pc, child)
return nil
}
@@ -345,12 +339,15 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
}
fs.mu.Lock()
defer fs.mu.Unlock()
- parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ parent, err := fs.walkParentDirLocked(ctx, rp)
fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
- pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+ pc, err := checkCreateLocked(ctx, rp, parent)
if err != nil {
return err
}
@@ -358,11 +355,14 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return err
}
defer rp.Mount().EndWrite()
- childVFSD, err := parentInode.NewDir(ctx, pc, opts)
+ child, err := parent.inode.NewDir(ctx, pc, opts)
if err != nil {
- return err
+ if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
+ return err
+ }
+ child = newSyntheticDirectory(rp.Credentials(), opts.Mode)
}
- parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry))
+ parent.InsertChildLocked(pc, child)
return nil
}
@@ -373,12 +373,15 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
}
fs.mu.Lock()
defer fs.mu.Unlock()
- parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ parent, err := fs.walkParentDirLocked(ctx, rp)
fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
- pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+ pc, err := checkCreateLocked(ctx, rp, parent)
if err != nil {
return err
}
@@ -386,11 +389,11 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return err
}
defer rp.Mount().EndWrite()
- newVFSD, err := parentInode.NewNode(ctx, pc, opts)
+ newD, err := parent.inode.NewNode(ctx, pc, opts)
if err != nil {
return err
}
- parentVFSD.Impl().(*Dentry).InsertChild(pc, newVFSD.Impl().(*Dentry))
+ parent.InsertChildLocked(pc, newD)
return nil
}
@@ -406,28 +409,27 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// Do not create new file.
if opts.Flags&linux.O_CREAT == 0 {
fs.mu.RLock()
- vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+ d, err := fs.walkExistingLocked(ctx, rp)
if err != nil {
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
return nil, err
}
- if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
+ if err := d.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
return nil, err
}
- inode.IncRef()
- defer inode.DecRef(ctx)
+ d.inode.IncRef()
+ defer d.inode.DecRef(ctx)
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
- return inode.Open(ctx, rp, vfsd, opts)
+ return d.inode.Open(ctx, rp, d, opts)
}
// May create new file.
mustCreate := opts.Flags&linux.O_EXCL != 0
- vfsd := rp.Start()
- inode := vfsd.Impl().(*Dentry).inode
+ d := rp.Start().Impl().(*Dentry)
fs.mu.Lock()
unlocked := false
unlock := func() {
@@ -444,22 +446,22 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if mustCreate {
return nil, syserror.EEXIST
}
- if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
+ if err := d.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
- inode.IncRef()
- defer inode.DecRef(ctx)
+ d.inode.IncRef()
+ defer d.inode.DecRef(ctx)
unlock()
- return inode.Open(ctx, rp, vfsd, opts)
+ return d.inode.Open(ctx, rp, d, opts)
}
afterTrailingSymlink:
- parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ parent, err := fs.walkParentDirLocked(ctx, rp)
fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return nil, err
}
// Check for search permission in the parent directory.
- if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
+ if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
// Reject attempts to open directories with O_CREAT.
@@ -474,10 +476,10 @@ afterTrailingSymlink:
return nil, syserror.ENAMETOOLONG
}
// Determine whether or not we need to create a file.
- childVFSD, err := fs.stepExistingLocked(ctx, rp, parentVFSD, false /* mayFollowSymlinks */)
+ child, err := fs.stepExistingLocked(ctx, rp, parent, false /* mayFollowSymlinks */)
if err == syserror.ENOENT {
// Already checked for searchability above; now check for writability.
- if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
+ if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -485,16 +487,18 @@ afterTrailingSymlink:
}
defer rp.Mount().EndWrite()
// Create and open the child.
- childVFSD, err = parentInode.NewFile(ctx, pc, opts)
+ child, err := parent.inode.NewFile(ctx, pc, opts)
if err != nil {
return nil, err
}
- child := childVFSD.Impl().(*Dentry)
- parentVFSD.Impl().(*Dentry).InsertChild(pc, child)
+ // FIXME(gvisor.dev/issue/1193): Race between checking existence with
+ // fs.stepExistingLocked and parent.InsertChild. If possible, we should hold
+ // dirMu from one to the other.
+ parent.InsertChild(pc, child)
child.inode.IncRef()
defer child.inode.DecRef(ctx)
unlock()
- return child.inode.Open(ctx, rp, childVFSD, opts)
+ return child.inode.Open(ctx, rp, child, opts)
}
if err != nil {
return nil, err
@@ -503,7 +507,6 @@ afterTrailingSymlink:
if mustCreate {
return nil, syserror.EEXIST
}
- child := childVFSD.Impl().(*Dentry)
if rp.ShouldFollowSymlink() && child.isSymlink() {
targetVD, targetPathname, err := child.inode.Getlink(ctx, rp.Mount())
if err != nil {
@@ -530,22 +533,22 @@ afterTrailingSymlink:
child.inode.IncRef()
defer child.inode.DecRef(ctx)
unlock()
- return child.inode.Open(ctx, rp, &child.vfsd, opts)
+ return child.inode.Open(ctx, rp, child, opts)
}
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
fs.mu.RLock()
- d, inode, err := fs.walkExistingLocked(ctx, rp)
+ d, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
if err != nil {
return "", err
}
- if !d.Impl().(*Dentry).isSymlink() {
+ if !d.isSymlink() {
return "", syserror.EINVAL
}
- return inode.Readlink(ctx)
+ return d.inode.Readlink(ctx, rp.Mount())
}
// RenameAt implements vfs.FilesystemImpl.RenameAt.
@@ -562,11 +565,10 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// Resolve the destination directory first to verify that it's on this
// Mount.
- dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp)
+ dstDir, err := fs.walkParentDirLocked(ctx, rp)
if err != nil {
return err
}
- dstDir := dstDirVFSD.Impl().(*Dentry)
mnt := rp.Mount()
if mnt != oldParentVD.Mount() {
return syserror.EXDEV
@@ -584,16 +586,15 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err != nil {
return err
}
- srcVFSD := &src.vfsd
// Can we remove the src dentry?
- if err := checkDeleteLocked(ctx, rp, srcVFSD); err != nil {
+ if err := checkDeleteLocked(ctx, rp, src); err != nil {
return err
}
// Can we create the dst dentry?
var dst *Dentry
- pc, err := checkCreateLocked(ctx, rp, dstDirVFSD, dstDirInode)
+ pc, err := checkCreateLocked(ctx, rp, dstDir)
switch err {
case nil:
// Ok, continue with rename as replacement.
@@ -604,14 +605,14 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
dst = dstDir.children[pc]
if dst == nil {
- panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDirVFSD))
+ panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDir))
}
default:
return err
}
var dstVFSD *vfs.Dentry
if dst != nil {
- dstVFSD = &dst.vfsd
+ dstVFSD = dst.VFSDentry()
}
mntns := vfs.MountNamespaceFromContext(ctx)
@@ -627,17 +628,18 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
defer dstDir.dirMu.Unlock()
}
+ srcVFSD := src.VFSDentry()
if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil {
return err
}
- replaced, err := srcDir.inode.Rename(ctx, src.name, pc, srcVFSD, dstDirVFSD)
+ replaced, err := srcDir.inode.Rename(ctx, src.name, pc, src, dstDir)
if err != nil {
virtfs.AbortRenameDentry(srcVFSD, dstVFSD)
return err
}
delete(srcDir.children, src.name)
if srcDir != dstDir {
- fs.deferDecRef(srcDirVFSD)
+ fs.deferDecRef(srcDir)
dstDir.IncRef()
}
src.parent = dstDir
@@ -646,7 +648,11 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
dstDir.children = make(map[string]*Dentry)
}
dstDir.children[pc] = src
- virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaced)
+ var replaceVFSD *vfs.Dentry
+ if replaced != nil {
+ replaceVFSD = replaced.VFSDentry()
+ }
+ virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaceVFSD)
return nil
}
@@ -654,7 +660,8 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
fs.mu.Lock()
defer fs.mu.Unlock()
- vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+
+ d, err := fs.walkExistingLocked(ctx, rp)
fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
@@ -663,14 +670,13 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
defer rp.Mount().EndWrite()
- if err := checkDeleteLocked(ctx, rp, vfsd); err != nil {
+ if err := checkDeleteLocked(ctx, rp, d); err != nil {
return err
}
- d := vfsd.Impl().(*Dentry)
if !d.isDir() {
return syserror.ENOTDIR
}
- if inode.HasChildren() {
+ if d.inode.HasChildren() {
return syserror.ENOTEMPTY
}
virtfs := rp.VirtualFilesystem()
@@ -680,10 +686,12 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
mntns := vfs.MountNamespaceFromContext(ctx)
defer mntns.DecRef(ctx)
+ vfsd := d.VFSDentry()
if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
return err
}
- if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil {
+
+ if err := parentDentry.inode.RmDir(ctx, d.name, d); err != nil {
virtfs.AbortDeleteDentry(vfsd)
return err
}
@@ -694,7 +702,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
fs.mu.RLock()
- _, inode, err := fs.walkExistingLocked(ctx, rp)
+ d, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
if err != nil {
@@ -703,31 +711,31 @@ func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
if opts.Stat.Mask == 0 {
return nil
}
- return inode.SetStat(ctx, fs.VFSFilesystem(), rp.Credentials(), opts)
+ return d.inode.SetStat(ctx, fs.VFSFilesystem(), rp.Credentials(), opts)
}
// StatAt implements vfs.FilesystemImpl.StatAt.
func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
fs.mu.RLock()
- _, inode, err := fs.walkExistingLocked(ctx, rp)
+ d, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
if err != nil {
return linux.Statx{}, err
}
- return inode.Stat(ctx, fs.VFSFilesystem(), opts)
+ return d.inode.Stat(ctx, fs.VFSFilesystem(), opts)
}
// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
fs.mu.RLock()
- _, inode, err := fs.walkExistingLocked(ctx, rp)
+ d, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
if err != nil {
return linux.Statfs{}, err
}
- return inode.StatFS(ctx, fs.VFSFilesystem())
+ return d.inode.StatFS(ctx, fs.VFSFilesystem())
}
// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
@@ -737,12 +745,15 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
}
fs.mu.Lock()
defer fs.mu.Unlock()
- parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ parent, err := fs.walkParentDirLocked(ctx, rp)
fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
}
- pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ pc, err := checkCreateLocked(ctx, rp, parent)
if err != nil {
return err
}
@@ -750,11 +761,11 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
return err
}
defer rp.Mount().EndWrite()
- childVFSD, err := parentInode.NewSymlink(ctx, pc, target)
+ child, err := parent.inode.NewSymlink(ctx, pc, target)
if err != nil {
return err
}
- parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry))
+ parent.InsertChildLocked(pc, child)
return nil
}
@@ -762,7 +773,8 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
fs.mu.Lock()
defer fs.mu.Unlock()
- vfsd, _, err := fs.walkExistingLocked(ctx, rp)
+
+ d, err := fs.walkExistingLocked(ctx, rp)
fs.processDeferredDecRefsLocked(ctx)
if err != nil {
return err
@@ -771,10 +783,9 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
defer rp.Mount().EndWrite()
- if err := checkDeleteLocked(ctx, rp, vfsd); err != nil {
+ if err := checkDeleteLocked(ctx, rp, d); err != nil {
return err
}
- d := vfsd.Impl().(*Dentry)
if d.isDir() {
return syserror.EISDIR
}
@@ -784,10 +795,11 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
defer parentDentry.dirMu.Unlock()
mntns := vfs.MountNamespaceFromContext(ctx)
defer mntns.DecRef(ctx)
+ vfsd := d.VFSDentry()
if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
return err
}
- if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil {
+ if err := parentDentry.inode.Unlink(ctx, d.name, d); err != nil {
virtfs.AbortDeleteDentry(vfsd)
return err
}
@@ -795,25 +807,25 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return nil
}
-// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
fs.mu.RLock()
- _, inode, err := fs.walkExistingLocked(ctx, rp)
+ d, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
if err != nil {
return nil, err
}
- if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
+ if err := d.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
}
return nil, syserror.ECONNREFUSED
}
-// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt.
+func (fs *Filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
fs.mu.RLock()
- _, _, err := fs.walkExistingLocked(ctx, rp)
+ _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
if err != nil {
@@ -823,10 +835,10 @@ func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
return nil, syserror.ENOTSUP
}
-// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
+func (fs *Filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) {
fs.mu.RLock()
- _, _, err := fs.walkExistingLocked(ctx, rp)
+ _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
if err != nil {
@@ -836,10 +848,10 @@ func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
return "", syserror.ENOTSUP
}
-// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
-func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt.
+func (fs *Filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error {
fs.mu.RLock()
- _, _, err := fs.walkExistingLocked(ctx, rp)
+ _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
if err != nil {
@@ -849,10 +861,10 @@ func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
return syserror.ENOTSUP
}
-// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
-func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt.
+func (fs *Filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
fs.mu.RLock()
- _, _, err := fs.walkExistingLocked(ctx, rp)
+ _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
fs.processDeferredDecRefs(ctx)
if err != nil {
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index c0b863ba4..49210e748 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -31,6 +31,8 @@ import (
// count for inodes, performing no extra actions when references are obtained or
// released. This is suitable for simple file inodes that don't reference any
// resources.
+//
+// +stateify savable
type InodeNoopRefCount struct {
}
@@ -50,30 +52,32 @@ func (InodeNoopRefCount) TryIncRef() bool {
// InodeDirectoryNoNewChildren partially implements the Inode interface.
// InodeDirectoryNoNewChildren represents a directory inode which does not
// support creation of new children.
+//
+// +stateify savable
type InodeDirectoryNoNewChildren struct{}
// NewFile implements Inode.NewFile.
-func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*Dentry, error) {
return nil, syserror.EPERM
}
// NewDir implements Inode.NewDir.
-func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*Dentry, error) {
return nil, syserror.EPERM
}
// NewLink implements Inode.NewLink.
-func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*Dentry, error) {
return nil, syserror.EPERM
}
// NewSymlink implements Inode.NewSymlink.
-func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*Dentry, error) {
return nil, syserror.EPERM
}
// NewNode implements Inode.NewNode.
-func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*Dentry, error) {
return nil, syserror.EPERM
}
@@ -81,6 +85,8 @@ func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOpt
// inodeDirectory and inodeDynamicDirectory sub interfaces. Inodes that do not
// represent directories can embed this to provide no-op implementations for
// directory-related functions.
+//
+// +stateify savable
type InodeNotDirectory struct {
}
@@ -90,47 +96,47 @@ func (InodeNotDirectory) HasChildren() bool {
}
// NewFile implements Inode.NewFile.
-func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*Dentry, error) {
panic("NewFile called on non-directory inode")
}
// NewDir implements Inode.NewDir.
-func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*Dentry, error) {
panic("NewDir called on non-directory inode")
}
// NewLink implements Inode.NewLinkink.
-func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*Dentry, error) {
panic("NewLink called on non-directory inode")
}
// NewSymlink implements Inode.NewSymlink.
-func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*Dentry, error) {
panic("NewSymlink called on non-directory inode")
}
// NewNode implements Inode.NewNode.
-func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*Dentry, error) {
panic("NewNode called on non-directory inode")
}
// Unlink implements Inode.Unlink.
-func (InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error {
+func (InodeNotDirectory) Unlink(context.Context, string, *Dentry) error {
panic("Unlink called on non-directory inode")
}
// RmDir implements Inode.RmDir.
-func (InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error {
+func (InodeNotDirectory) RmDir(context.Context, string, *Dentry) error {
panic("RmDir called on non-directory inode")
}
// Rename implements Inode.Rename.
-func (InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) {
+func (InodeNotDirectory) Rename(context.Context, string, string, *Dentry, *Dentry) (*Dentry, error) {
panic("Rename called on non-directory inode")
}
// Lookup implements Inode.Lookup.
-func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*Dentry, error) {
panic("Lookup called on non-directory inode")
}
@@ -149,10 +155,12 @@ func (InodeNotDirectory) Valid(context.Context) bool {
// dymanic entries (i.e. entries that are not "hashed" into the
// vfs.Dentry.children) can embed this to provide no-op implementations for
// functions related to dynamic entries.
+//
+// +stateify savable
type InodeNoDynamicLookup struct{}
// Lookup implements Inode.Lookup.
-func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*Dentry, error) {
return nil, syserror.ENOENT
}
@@ -169,10 +177,12 @@ func (InodeNoDynamicLookup) Valid(ctx context.Context) bool {
// InodeNotSymlink partially implements the Inode interface, specifically the
// inodeSymlink sub interface. All inodes that are not symlinks may embed this
// to return the appropriate errors from symlink-related functions.
+//
+// +stateify savable
type InodeNotSymlink struct{}
// Readlink implements Inode.Readlink.
-func (InodeNotSymlink) Readlink(context.Context) (string, error) {
+func (InodeNotSymlink) Readlink(context.Context, *vfs.Mount) (string, error) {
return "", syserror.EINVAL
}
@@ -186,6 +196,8 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry,
// inode attributes.
//
// Must be initialized by Init prior to first use.
+//
+// +stateify savable
type InodeAttrs struct {
devMajor uint32
devMinor uint32
@@ -256,12 +268,29 @@ func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (li
// SetStat implements Inode.SetStat.
func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ return a.SetInodeStat(ctx, fs, creds, opts)
+}
+
+// SetInodeStat sets the corresponding attributes from opts to InodeAttrs.
+// This function can be used by other kernfs-based filesystem implementation to
+// sets the unexported attributes into kernfs.InodeAttrs.
+func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
if opts.Stat.Mask == 0 {
return nil
}
- if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
+
+ // Note that not all fields are modifiable. For example, the file type and
+ // inode numbers are immutable after node creation. Setting the size is often
+ // allowed by kernfs files but does not do anything. If some other behavior is
+ // needed, the embedder should consider extending SetStat.
+ //
+ // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
+ if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 {
return syserror.EPERM
}
+ if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() {
+ return syserror.EISDIR
+ }
if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
return err
}
@@ -284,13 +313,6 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
atomic.StoreUint32(&a.gid, stat.GID)
}
- // Note that not all fields are modifiable. For example, the file type and
- // inode numbers are immutable after node creation.
-
- // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
- // Also, STATX_SIZE will need some special handling, because read-only static
- // files should return EIO for truncate operations.
-
return nil
}
@@ -320,13 +342,16 @@ func (a *InodeAttrs) DecLinks() {
}
}
+// +stateify savable
type slot struct {
Name string
- Dentry *vfs.Dentry
+ Dentry *Dentry
slotEntry
}
// OrderedChildrenOptions contains initialization options for OrderedChildren.
+//
+// +stateify savable
type OrderedChildrenOptions struct {
// Writable indicates whether vfs.FilesystemImpl methods implemented by
// OrderedChildren may modify the tracked children. This applies to
@@ -342,12 +367,14 @@ type OrderedChildrenOptions struct {
// directories.
//
// Must be initialize with Init before first use.
+//
+// +stateify savable
type OrderedChildren struct {
// Can children be modified by user syscalls? It set to false, interface
// methods that would modify the children return EPERM. Immutable.
writable bool
- mu sync.RWMutex
+ mu sync.RWMutex `state:"nosave"`
order slotList
set map[string]*slot
}
@@ -380,7 +407,7 @@ func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint3
if child.isDir() {
links++
}
- if err := o.Insert(name, child.VFSDentry()); err != nil {
+ if err := o.Insert(name, child); err != nil {
panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d))
}
d.InsertChild(name, child)
@@ -397,7 +424,7 @@ func (o *OrderedChildren) HasChildren() bool {
// Insert inserts child into o. This ignores the writability of o, as this is
// not part of the vfs.FilesystemImpl interface, and is a lower-level operation.
-func (o *OrderedChildren) Insert(name string, child *vfs.Dentry) error {
+func (o *OrderedChildren) Insert(name string, child *Dentry) error {
o.mu.Lock()
defer o.mu.Unlock()
if _, ok := o.set[name]; ok {
@@ -421,10 +448,10 @@ func (o *OrderedChildren) removeLocked(name string) {
}
// Precondition: caller must hold o.mu for writing.
-func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs.Dentry {
+func (o *OrderedChildren) replaceChildLocked(name string, new *Dentry) *Dentry {
if s, ok := o.set[name]; ok {
// Existing slot with given name, simply replace the dentry.
- var old *vfs.Dentry
+ var old *Dentry
old, s.Dentry = s.Dentry, new
return old
}
@@ -440,7 +467,7 @@ func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs.
}
// Precondition: caller must hold o.mu for reading or writing.
-func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) error {
+func (o *OrderedChildren) checkExistingLocked(name string, child *Dentry) error {
s, ok := o.set[name]
if !ok {
return syserror.ENOENT
@@ -452,7 +479,7 @@ func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) er
}
// Unlink implements Inode.Unlink.
-func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.Dentry) error {
+func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *Dentry) error {
if !o.writable {
return syserror.EPERM
}
@@ -468,12 +495,13 @@ func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.De
}
// Rmdir implements Inode.Rmdir.
-func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *vfs.Dentry) error {
+func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *Dentry) error {
// We're not responsible for checking that child is a directory, that it's
// empty, or updating any link counts; so this is the same as unlink.
return o.Unlink(ctx, name, child)
}
+// +stateify savable
type renameAcrossDifferentImplementationsError struct{}
func (renameAcrossDifferentImplementationsError) Error() string {
@@ -489,8 +517,8 @@ func (renameAcrossDifferentImplementationsError) Error() string {
// that will support Rename.
//
// Postcondition: reference on any replaced dentry transferred to caller.
-func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (*vfs.Dentry, error) {
- dst, ok := dstDir.Impl().(*Dentry).inode.(interface{}).(*OrderedChildren)
+func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *Dentry) (*Dentry, error) {
+ dst, ok := dstDir.inode.(interface{}).(*OrderedChildren)
if !ok {
return nil, renameAcrossDifferentImplementationsError{}
}
@@ -532,12 +560,14 @@ func (o *OrderedChildren) nthLocked(i int64) *slot {
}
// InodeSymlink partially implements Inode interface for symlinks.
+//
+// +stateify savable
type InodeSymlink struct {
InodeNotDirectory
}
// Open implements Inode.Open.
-func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
return nil, syserror.ELOOP
}
@@ -564,6 +594,7 @@ var _ Inode = (*StaticDirectory)(nil)
func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*Dentry, fdOpts GenericDirectoryFDOptions) *Dentry {
inode := &StaticDirectory{}
inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts)
+ inode.EnableLeakCheck()
dentry := &Dentry{}
dentry.Init(inode)
@@ -584,9 +615,9 @@ func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint3
s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm)
}
-// Open implements kernfs.Inode.
-func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &s.locks, &opts, s.fdOpts)
+// Open implements kernfs.Inode.Open.
+func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := NewGenericDirectoryFD(rp.Mount(), d, &s.OrderedChildren, &s.locks, &opts, s.fdOpts)
if err != nil {
return nil, err
}
@@ -598,21 +629,25 @@ func (*StaticDirectory) SetStat(context.Context, *vfs.Filesystem, *auth.Credenti
return syserror.EPERM
}
-// DecRef implements kernfs.Inode.
+// DecRef implements kernfs.Inode.DecRef.
func (s *StaticDirectory) DecRef(context.Context) {
s.StaticDirectoryRefs.DecRef(s.Destroy)
}
// AlwaysValid partially implements kernfs.inodeDynamicLookup.
+//
+// +stateify savable
type AlwaysValid struct{}
-// Valid implements kernfs.inodeDynamicLookup.
+// Valid implements kernfs.inodeDynamicLookup.Valid.
func (*AlwaysValid) Valid(context.Context) bool {
return true
}
// InodeNoStatFS partially implements the Inode interface, where the client
// filesystem doesn't support statfs(2).
+//
+// +stateify savable
type InodeNoStatFS struct{}
// StatFS implements Inode.StatFS.
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 88fcd54aa..6d3d79333 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -29,7 +29,7 @@
//
// Reference Model:
//
-// Kernfs dentries represents named pointers to inodes. Dentries and inode have
+// Kernfs dentries represents named pointers to inodes. Dentries and inodes have
// independent lifetimes and reference counts. A child dentry unconditionally
// holds a reference on its parent directory's dentry. A dentry also holds a
// reference on the inode it points to. Multiple dentries can point to the same
@@ -60,20 +60,23 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory
// filesystem. Concrete implementations are expected to embed this in their own
// Filesystem type.
+//
+// +stateify savable
type Filesystem struct {
vfsfs vfs.Filesystem
- droppedDentriesMu sync.Mutex
+ droppedDentriesMu sync.Mutex `state:"nosave"`
// droppedDentries is a list of dentries waiting to be DecRef()ed. This is
// used to defer dentry destruction until mu can be acquired for
// writing. Protected by droppedDentriesMu.
- droppedDentries []*vfs.Dentry
+ droppedDentries []*Dentry
// mu synchronizes the lifetime of Dentries on this filesystem. Holding it
// for reading guarantees continued existence of any resolved dentries, but
@@ -96,7 +99,7 @@ type Filesystem struct {
// defer fs.mu.RUnlock()
// ...
// fs.deferDecRef(dentry)
- mu sync.RWMutex
+ mu sync.RWMutex `state:"nosave"`
// nextInoMinusOne is used to to allocate inode numbers on this
// filesystem. Must be accessed by atomic operations.
@@ -107,7 +110,7 @@ type Filesystem struct {
// processDeferredDecRefs{,Locked}. See comment on Filesystem.mu.
//
// Precondition: d must not already be pending destruction.
-func (fs *Filesystem) deferDecRef(d *vfs.Dentry) {
+func (fs *Filesystem) deferDecRef(d *Dentry) {
fs.droppedDentriesMu.Lock()
fs.droppedDentries = append(fs.droppedDentries, d)
fs.droppedDentriesMu.Unlock()
@@ -159,6 +162,8 @@ const (
// to, and child dentries hold a reference on their parent.
//
// Must be initialized by Init prior to first use.
+//
+// +stateify savable
type Dentry struct {
DentryRefs
@@ -172,7 +177,11 @@ type Dentry struct {
name string
// dirMu protects children and the names of child Dentries.
- dirMu sync.Mutex
+ //
+ // Note that holding fs.mu for writing is not sufficient;
+ // revalidateChildLocked(), which is a very hot path, may modify children with
+ // fs.mu acquired for reading only.
+ dirMu sync.Mutex `state:"nosave"`
children map[string]*Dentry
inode Inode
@@ -239,24 +248,25 @@ func (d *Dentry) Watches() *vfs.Watches {
func (d *Dentry) OnZeroWatches(context.Context) {}
// InsertChild inserts child into the vfs dentry cache with the given name under
-// this dentry. This does not update the directory inode, so calling this on
-// its own isn't sufficient to insert a child into a directory. InsertChild
-// updates the link count on d if required.
+// this dentry. This does not update the directory inode, so calling this on its
+// own isn't sufficient to insert a child into a directory.
//
// Precondition: d must represent a directory inode.
func (d *Dentry) InsertChild(name string, child *Dentry) {
d.dirMu.Lock()
- d.insertChildLocked(name, child)
+ d.InsertChildLocked(name, child)
d.dirMu.Unlock()
}
-// insertChildLocked is equivalent to InsertChild, with additional
+// InsertChildLocked is equivalent to InsertChild, with additional
// preconditions.
//
-// Precondition: d.dirMu must be locked.
-func (d *Dentry) insertChildLocked(name string, child *Dentry) {
+// Preconditions:
+// * d must represent a directory inode.
+// * d.dirMu must be locked.
+func (d *Dentry) InsertChildLocked(name string, child *Dentry) {
if !d.isDir() {
- panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d))
+ panic(fmt.Sprintf("InsertChildLocked called on non-directory Dentry: %+v.", d))
}
d.IncRef() // DecRef in child's Dentry.destroy.
child.parent = d
@@ -267,6 +277,36 @@ func (d *Dentry) insertChildLocked(name string, child *Dentry) {
d.children[name] = child
}
+// RemoveChild removes child from the vfs dentry cache. This does not update the
+// directory inode or modify the inode to be unlinked. So calling this on its own
+// isn't sufficient to remove a child from a directory.
+//
+// Precondition: d must represent a directory inode.
+func (d *Dentry) RemoveChild(name string, child *Dentry) error {
+ d.dirMu.Lock()
+ defer d.dirMu.Unlock()
+ return d.RemoveChildLocked(name, child)
+}
+
+// RemoveChildLocked is equivalent to RemoveChild, with additional
+// preconditions.
+//
+// Precondition: d.dirMu must be locked.
+func (d *Dentry) RemoveChildLocked(name string, child *Dentry) error {
+ if !d.isDir() {
+ panic(fmt.Sprintf("RemoveChild called on non-directory Dentry: %+v.", d))
+ }
+ c, ok := d.children[name]
+ if !ok {
+ return syserror.ENOENT
+ }
+ if c != child {
+ panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! Child: %+v, vfs: %+v", c, child))
+ }
+ delete(d.children, name)
+ return nil
+}
+
// Inode returns the dentry's inode.
func (d *Dentry) Inode() Inode {
return d.inode
@@ -287,7 +327,6 @@ func (d *Dentry) Inode() Inode {
//
// - Checking that dentries passed to methods are of the appropriate file type.
// - Checking permissions.
-// - Updating link and reference counts.
//
// Specific responsibilities of implementations are documented below.
type Inode interface {
@@ -297,7 +336,8 @@ type Inode interface {
inodeRefs
// Methods related to node metadata. A generic implementation is provided by
- // InodeAttrs.
+ // InodeAttrs. Note that a concrete filesystem using kernfs is responsible for
+ // managing link counts.
inodeMetadata
// Method for inodes that represent symlink. InodeNotSymlink provides a
@@ -315,11 +355,11 @@ type Inode interface {
// Open creates a file description for the filesystem object represented by
// this inode. The returned file description should hold a reference on the
- // inode for its lifetime.
+ // dentry for its lifetime.
//
// Precondition: rp.Done(). vfsd.Impl() must be the kernfs Dentry containing
// the inode on which Open() is being called.
- Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error)
+ Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error)
// StatFS returns filesystem statistics for the client filesystem. This
// corresponds to vfs.FilesystemImpl.StatFSAt. If the client filesystem
@@ -369,30 +409,30 @@ type inodeDirectory interface {
HasChildren() bool
// NewFile creates a new regular file inode.
- NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error)
+ NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*Dentry, error)
// NewDir creates a new directory inode.
- NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error)
+ NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*Dentry, error)
// NewLink creates a new hardlink to a specified inode in this
// directory. Implementations should create a new kernfs Dentry pointing to
// target, and update target's link count.
- NewLink(ctx context.Context, name string, target Inode) (*vfs.Dentry, error)
+ NewLink(ctx context.Context, name string, target Inode) (*Dentry, error)
// NewSymlink creates a new symbolic link inode.
- NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error)
+ NewSymlink(ctx context.Context, name, target string) (*Dentry, error)
// NewNode creates a new filesystem node for a mknod syscall.
- NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error)
+ NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*Dentry, error)
// Unlink removes a child dentry from this directory inode.
- Unlink(ctx context.Context, name string, child *vfs.Dentry) error
+ Unlink(ctx context.Context, name string, child *Dentry) error
// RmDir removes an empty child directory from this directory
// inode. Implementations must update the parent directory's link count,
// if required. Implementations are not responsible for checking that child
// is a directory, checking for an empty directory.
- RmDir(ctx context.Context, name string, child *vfs.Dentry) error
+ RmDir(ctx context.Context, name string, child *Dentry) error
// Rename is called on the source directory containing an inode being
// renamed. child should point to the resolved child in the source
@@ -400,7 +440,7 @@ type inodeDirectory interface {
// should return the replaced dentry or nil otherwise.
//
// Precondition: Caller must serialize concurrent calls to Rename.
- Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (replaced *vfs.Dentry, err error)
+ Rename(ctx context.Context, oldname, newname string, child, dstDir *Dentry) (replaced *Dentry, err error)
}
type inodeDynamicLookup interface {
@@ -418,14 +458,14 @@ type inodeDynamicLookup interface {
//
// Lookup returns the child with an extra reference and the caller owns this
// reference.
- Lookup(ctx context.Context, name string) (*vfs.Dentry, error)
+ Lookup(ctx context.Context, name string) (*Dentry, error)
// Valid should return true if this inode is still valid, or needs to
// be resolved again by a call to Lookup.
Valid(ctx context.Context) bool
// IterDirents is used to iterate over dynamically created entries. It invokes
- // cb on each entry in the directory represented by the FileDescription.
+ // cb on each entry in the directory represented by the Inode.
// 'offset' is the offset for the entire IterDirents call, which may include
// results from the caller (e.g. "." and ".."). 'relOffset' is the offset
// inside the entries returned by this IterDirents invocation. In other words,
@@ -437,7 +477,7 @@ type inodeDynamicLookup interface {
type inodeSymlink interface {
// Readlink returns the target of a symbolic link. If an inode is not a
// symlink, the implementation should return EINVAL.
- Readlink(ctx context.Context) (string, error)
+ Readlink(ctx context.Context, mnt *vfs.Mount) (string, error)
// Getlink returns the target of a symbolic link, as used by path
// resolution:
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 675587c6b..e413242dc 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -52,7 +52,7 @@ func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System {
v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{})
+ mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.MountOptions{})
if err != nil {
t.Fatalf("Failed to create testfs root mount: %v", err)
}
@@ -121,8 +121,8 @@ func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMod
return &dir.dentry
}
-func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{
+func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndStaticEntries,
})
if err != nil {
@@ -162,8 +162,8 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte
return &dir.dentry
}
-func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{
+func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndStaticEntries,
})
if err != nil {
@@ -176,38 +176,36 @@ func (d *dir) DecRef(context.Context) {
d.dirRefs.DecRef(d.Destroy)
}
-func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) {
+func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*kernfs.Dentry, error) {
creds := auth.CredentialsFromContext(ctx)
dir := d.fs.newDir(creds, opts.Mode, nil)
- dirVFSD := dir.VFSDentry()
- if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil {
+ if err := d.OrderedChildren.Insert(name, dir); err != nil {
dir.DecRef(ctx)
return nil, err
}
d.IncLinks(1)
- return dirVFSD, nil
+ return dir, nil
}
-func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) {
+func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*kernfs.Dentry, error) {
creds := auth.CredentialsFromContext(ctx)
f := d.fs.newFile(creds, "")
- fVFSD := f.VFSDentry()
- if err := d.OrderedChildren.Insert(name, fVFSD); err != nil {
+ if err := d.OrderedChildren.Insert(name, f); err != nil {
f.DecRef(ctx)
return nil, err
}
- return fVFSD, nil
+ return f, nil
}
-func (*dir) NewLink(context.Context, string, kernfs.Inode) (*vfs.Dentry, error) {
+func (*dir) NewLink(context.Context, string, kernfs.Inode) (*kernfs.Dentry, error) {
return nil, syserror.EPERM
}
-func (*dir) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+func (*dir) NewSymlink(context.Context, string, string) (*kernfs.Dentry, error) {
return nil, syserror.EPERM
}
-func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*kernfs.Dentry, error) {
return nil, syserror.EPERM
}
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
index 64731a3e4..58a93eaac 100644
--- a/pkg/sentry/fsimpl/kernfs/symlink.go
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -24,6 +24,8 @@ import (
// StaticSymlink provides an Inode implementation for symlinks that point to
// a immutable target.
+//
+// +stateify savable
type StaticSymlink struct {
InodeAttrs
InodeNoopRefCount
@@ -51,8 +53,8 @@ func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor
s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777)
}
-// Readlink implements Inode.
-func (s *StaticSymlink) Readlink(_ context.Context) (string, error) {
+// Readlink implements Inode.Readlink.
+func (s *StaticSymlink) Readlink(_ context.Context, _ *vfs.Mount) (string, error) {
return s.target, nil
}
diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
new file mode 100644
index 000000000..ea7f073eb
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
@@ -0,0 +1,102 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// syntheticDirectory implements kernfs.Inode for a directory created by
+// MkdirAt(ForSyntheticMountpoint=true).
+//
+// +stateify savable
+type syntheticDirectory struct {
+ InodeAttrs
+ InodeNoStatFS
+ InodeNoopRefCount
+ InodeNoDynamicLookup
+ InodeNotSymlink
+ OrderedChildren
+
+ locks vfs.FileLocks
+}
+
+var _ Inode = (*syntheticDirectory)(nil)
+
+func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) *Dentry {
+ inode := &syntheticDirectory{}
+ inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm)
+ d := &Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm))
+ }
+ dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm)
+ dir.OrderedChildren.Init(OrderedChildrenOptions{
+ Writable: true,
+ })
+}
+
+// Open implements Inode.Open.
+func (dir *syntheticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := NewGenericDirectoryFD(rp.Mount(), d, &dir.OrderedChildren, &dir.locks, &opts, GenericDirectoryFDOptions{})
+ if err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// NewFile implements Inode.NewFile.
+func (dir *syntheticDirectory) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewDir implements Inode.NewDir.
+func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*Dentry, error) {
+ if !opts.ForSyntheticMountpoint {
+ return nil, syserror.EPERM
+ }
+ subdird := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask)
+ if err := dir.OrderedChildren.Insert(name, subdird); err != nil {
+ subdird.DecRef(ctx)
+ return nil, err
+ }
+ return subdird, nil
+}
+
+// NewLink implements Inode.NewLink.
+func (dir *syntheticDirectory) NewLink(ctx context.Context, name string, target Inode) (*Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewSymlink implements Inode.NewSymlink.
+func (dir *syntheticDirectory) NewSymlink(ctx context.Context, name, target string) (*Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewNode implements Inode.NewNode.
+func (dir *syntheticDirectory) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*Dentry, error) {
+ return nil, syserror.EPERM
+}
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 13735eb05..73b126669 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -81,6 +82,8 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
Start: d.parent.upperVD,
Path: fspath.Parse(d.name),
}
+ // Used during copy-up of memory-mapped regular files.
+ var mmapOpts *memmap.MMapOpts
cleanupUndoCopyUp := func() {
var err error
if ftype == linux.S_IFDIR {
@@ -89,7 +92,11 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
err = vfsObj.UnlinkAt(ctx, d.fs.creds, &newpop)
}
if err != nil {
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after copy-up error: %v", err)
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after copy-up error: %v", err))
+ }
+ if d.upperVD.Ok() {
+ d.upperVD.DecRef(ctx)
+ d.upperVD = vfs.VirtualDentry{}
}
}
switch ftype {
@@ -132,6 +139,25 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
break
}
}
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ if d.wrappedMappable != nil {
+ // We may have memory mappings of the file on the lower layer.
+ // Switch to mapping the file on the upper layer instead.
+ mmapOpts = &memmap.MMapOpts{
+ Perms: usermem.ReadWrite,
+ MaxPerms: usermem.ReadWrite,
+ }
+ if err := newFD.ConfigureMMap(ctx, mmapOpts); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ if mmapOpts.MappingIdentity != nil {
+ mmapOpts.MappingIdentity.DecRef(ctx)
+ }
+ // Don't actually switch Mappables until the end of copy-up; see
+ // below for why.
+ }
if err := newFD.SetStat(ctx, vfs.SetStatOptions{
Stat: linux.Statx{
Mask: linux.STATX_UID | linux.STATX_GID,
@@ -234,7 +260,10 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
panic(fmt.Sprintf("unexpected file type %o", ftype))
}
- // TODO(gvisor.dev/issue/1199): copy up xattrs
+ if err := d.copyXattrsLocked(ctx); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
// Update the dentry's device and inode numbers (except for directories,
// for which these remain overlay-assigned).
@@ -246,14 +275,10 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
Mask: linux.STATX_INO,
})
if err != nil {
- d.upperVD.DecRef(ctx)
- d.upperVD = vfs.VirtualDentry{}
cleanupUndoCopyUp()
return err
}
if upperStat.Mask&linux.STATX_INO == 0 {
- d.upperVD.DecRef(ctx)
- d.upperVD = vfs.VirtualDentry{}
cleanupUndoCopyUp()
return syserror.EREMOTE
}
@@ -262,6 +287,135 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
atomic.StoreUint64(&d.ino, upperStat.Ino)
}
+ if mmapOpts != nil && mmapOpts.Mappable != nil {
+ // Note that if mmapOpts != nil, then d.mapsMu is locked for writing
+ // (from the S_IFREG path above).
+
+ // Propagate mappings of d to the new Mappable. Remember which mappings
+ // we added so we can remove them on failure.
+ upperMappable := mmapOpts.Mappable
+ allAdded := make(map[memmap.MappableRange]memmap.MappingsOfRange)
+ for seg := d.lowerMappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ added := make(memmap.MappingsOfRange)
+ for m := range seg.Value() {
+ if err := upperMappable.AddMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable); err != nil {
+ for m := range added {
+ upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable)
+ }
+ for mr, mappings := range allAdded {
+ for m := range mappings {
+ upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, mr.Start, m.Writable)
+ }
+ }
+ return err
+ }
+ added[m] = struct{}{}
+ }
+ allAdded[seg.Range()] = added
+ }
+
+ // Switch to the new Mappable. We do this at the end of copy-up
+ // because:
+ //
+ // - We need to switch Mappables (by changing d.wrappedMappable) before
+ // invalidating Translations from the old Mappable (to pick up
+ // Translations from the new one).
+ //
+ // - We need to lock d.dataMu while changing d.wrappedMappable, but
+ // must invalidate Translations with d.dataMu unlocked (due to lock
+ // ordering).
+ //
+ // - Consequently, once we unlock d.dataMu, other threads may
+ // immediately observe the new (copied-up) Mappable, which we want to
+ // delay until copy-up is guaranteed to succeed.
+ d.dataMu.Lock()
+ lowerMappable := d.wrappedMappable
+ d.wrappedMappable = upperMappable
+ d.dataMu.Unlock()
+ d.lowerMappings.InvalidateAll(memmap.InvalidateOpts{})
+
+ // Remove mappings from the old Mappable.
+ for seg := d.lowerMappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ for m := range seg.Value() {
+ lowerMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable)
+ }
+ }
+ d.lowerMappings.RemoveAll()
+ }
+
atomic.StoreUint32(&d.copiedUp, 1)
return nil
}
+
+// copyXattrsLocked copies a subset of lower's extended attributes to upper.
+// Attributes that configure an overlay in the lower are not copied up.
+//
+// Preconditions: d.copyMu must be locked for writing.
+func (d *dentry) copyXattrsLocked(ctx context.Context) error {
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ lowerPop := &vfs.PathOperation{Root: d.lowerVDs[0], Start: d.lowerVDs[0]}
+ upperPop := &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD}
+
+ lowerXattrs, err := vfsObj.ListXattrAt(ctx, d.fs.creds, lowerPop, 0)
+ if err != nil {
+ if err == syserror.EOPNOTSUPP {
+ // There are no guarantees as to the contents of lowerXattrs.
+ return nil
+ }
+ ctx.Infof("failed to copy up xattrs because ListXattrAt failed: %v", err)
+ return err
+ }
+
+ for _, name := range lowerXattrs {
+ // Do not copy up overlay attributes.
+ if isOverlayXattr(name) {
+ continue
+ }
+
+ value, err := vfsObj.GetXattrAt(ctx, d.fs.creds, lowerPop, &vfs.GetXattrOptions{Name: name, Size: 0})
+ if err != nil {
+ ctx.Infof("failed to copy up xattrs because GetXattrAt failed: %v", err)
+ return err
+ }
+
+ if err := vfsObj.SetXattrAt(ctx, d.fs.creds, upperPop, &vfs.SetXattrOptions{Name: name, Value: value}); err != nil {
+ ctx.Infof("failed to copy up xattrs because SetXattrAt failed: %v", err)
+ return err
+ }
+ }
+ return nil
+}
+
+// copyUpDescendantsLocked ensures that all descendants of d are copied up.
+//
+// Preconditions:
+// * filesystem.renameMu must be locked.
+// * d.dirMu must be locked.
+// * d.isDir().
+func (d *dentry) copyUpDescendantsLocked(ctx context.Context, ds **[]*dentry) error {
+ dirents, err := d.getDirentsLocked(ctx)
+ if err != nil {
+ return err
+ }
+ for _, dirent := range dirents {
+ if dirent.Name == "." || dirent.Name == ".." {
+ continue
+ }
+ child, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds)
+ if err != nil {
+ return err
+ }
+ if err := child.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ if child.isDir() {
+ child.dirMu.Lock()
+ err := child.copyUpDescendantsLocked(ctx, ds)
+ child.dirMu.Unlock()
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go
index b1b292e83..df4492346 100644
--- a/pkg/sentry/fsimpl/overlay/directory.go
+++ b/pkg/sentry/fsimpl/overlay/directory.go
@@ -100,12 +100,13 @@ func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string
return whiteouts, readdirErr
}
+// +stateify savable
type directoryFD struct {
fileDescription
vfs.DirectoryFileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
off int64
dirents []vfs.Dirent
}
@@ -116,10 +117,12 @@ func (fd *directoryFD) Release(ctx context.Context) {
// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ d := fd.dentry()
+ defer d.InotifyWithParent(ctx, linux.IN_ACCESS, 0, vfs.PathEvent)
+
fd.mu.Lock()
defer fd.mu.Unlock()
- d := fd.dentry()
if fd.dirents == nil {
ds, err := d.getDirents(ctx)
if err != nil {
@@ -143,7 +146,14 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
defer d.fs.renameMu.RUnlock()
d.dirMu.Lock()
defer d.dirMu.Unlock()
+ return d.getDirentsLocked(ctx)
+}
+// Preconditions:
+// * filesystem.renameMu must be locked.
+// * d.dirMu must be locked.
+// * d.isDir().
+func (d *dentry) getDirentsLocked(ctx context.Context) ([]vfs.Dirent, error) {
if d.dirents != nil {
return d.dirents, nil
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index e720bfb0b..bd11372d5 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -15,6 +15,8 @@
package overlay
import (
+ "fmt"
+ "strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -27,10 +29,15 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// _OVL_XATTR_PREFIX is an extended attribute key prefix to identify overlayfs
+// attributes.
+// Linux: fs/overlayfs/overlayfs.h:OVL_XATTR_PREFIX
+const _OVL_XATTR_PREFIX = linux.XATTR_TRUSTED_PREFIX + "overlay."
+
// _OVL_XATTR_OPAQUE is an extended attribute key whose value is set to "y" for
// opaque directories.
// Linux: fs/overlayfs/overlayfs.h:OVL_XATTR_OPAQUE
-const _OVL_XATTR_OPAQUE = linux.XATTR_TRUSTED_PREFIX + "overlay.opaque"
+const _OVL_XATTR_OPAQUE = _OVL_XATTR_PREFIX + "opaque"
func isWhiteout(stat *linux.Statx) bool {
return stat.Mode&linux.S_IFMT == linux.S_IFCHR && stat.RdevMajor == 0 && stat.RdevMinor == 0
@@ -205,6 +212,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
lookupErr = err
return false
}
+ defer childVD.DecRef(ctx)
mask := uint32(linux.STATX_TYPE)
if !existsOnAnyLayer {
@@ -243,6 +251,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
}
// Update child to include this layer.
+ childVD.IncRef()
if isUpper {
child.upperVD = childVD
child.copiedUp = 1
@@ -267,10 +276,10 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
// Directories are merged with directories from lower layers if they
// are not explicitly opaque.
- opaqueVal, err := vfsObj.GetxattrAt(ctx, fs.creds, &vfs.PathOperation{
+ opaqueVal, err := vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{
Root: childVD,
Start: childVD,
- }, &vfs.GetxattrOptions{
+ }, &vfs.GetXattrOptions{
Name: _OVL_XATTR_OPAQUE,
Size: 1,
})
@@ -490,7 +499,13 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if err := create(parent, name, childLayer == lookupLayerUpperWhiteout); err != nil {
return err
}
+
parent.dirents = nil
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parent.watches.Notify(ctx, name, uint32(ev), 0 /* cookie */, vfs.InodeEvent, false /* unlinked */)
return nil
}
@@ -504,7 +519,7 @@ func (fs *filesystem) createWhiteout(ctx context.Context, vfsObj *vfs.VirtualFil
func (fs *filesystem) cleanupRecreateWhiteout(ctx context.Context, vfsObj *vfs.VirtualFilesystem, pop *vfs.PathOperation) {
if err := fs.createWhiteout(ctx, vfsObj, pop); err != nil {
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate whiteout after failed file creation: %v", err)
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate whiteout after failed file creation: %v", err))
}
}
@@ -616,12 +631,13 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
},
}); err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); cleanupErr != nil {
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after LinkAt metadata update failure: %v", cleanupErr)
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after LinkAt metadata update failure: %v", cleanupErr))
} else if haveUpperWhiteout {
fs.cleanupRecreateWhiteout(ctx, vfsObj, &newpop)
}
return err
}
+ old.watches.Notify(ctx, "", linux.IN_ATTRIB, 0 /* cookie */, vfs.InodeEvent, false /* unlinked */)
return nil
})
}
@@ -655,7 +671,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
},
}); err != nil {
if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil {
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr)
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr))
} else if haveUpperWhiteout {
fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
}
@@ -665,12 +681,12 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// There may be directories on lower layers (previously hidden by
// the whiteout) that the new directory should not be merged with.
// Mark it opaque to prevent merging.
- if err := vfsObj.SetxattrAt(ctx, fs.creds, &pop, &vfs.SetxattrOptions{
+ if err := vfsObj.SetXattrAt(ctx, fs.creds, &pop, &vfs.SetXattrOptions{
Name: _OVL_XATTR_OPAQUE,
Value: "y",
}); err != nil {
if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil {
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt set-opaque failure: %v", cleanupErr)
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt set-opaque failure: %v", cleanupErr))
} else {
fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
}
@@ -714,7 +730,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
},
}); err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr)
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr))
} else if haveUpperWhiteout {
fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
}
@@ -743,6 +759,9 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
start := rp.Start().Impl().(*dentry)
if rp.Done() {
+ if mayCreate && rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
if mustCreate {
return nil, syserror.EEXIST
}
@@ -766,6 +785,10 @@ afterTrailingSymlink:
if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
+ // Reject attempts to open directories with O_CREAT.
+ if mayCreate && rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
// Determine whether or not we need to create a file.
parent.dirMu.Lock()
child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
@@ -774,12 +797,11 @@ afterTrailingSymlink:
parent.dirMu.Unlock()
return fd, err
}
+ parent.dirMu.Unlock()
if err != nil {
- parent.dirMu.Unlock()
return nil, err
}
// Open existing child or follow symlink.
- parent.dirMu.Unlock()
if mustCreate {
return nil, syserror.EEXIST
}
@@ -794,6 +816,9 @@ afterTrailingSymlink:
start = parent
goto afterTrailingSymlink
}
+ if rp.MustBeDir() && !child.isDir() {
+ return nil, syserror.ENOTDIR
+ }
if mayWrite {
if err := child.copyUpLocked(ctx); err != nil {
return nil, err
@@ -925,7 +950,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
},
}); err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr)
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr))
} else if haveUpperWhiteout {
fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
}
@@ -936,7 +961,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
child, err := fs.getChildLocked(ctx, parent, childName, ds)
if err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr)
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr))
} else if haveUpperWhiteout {
fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
}
@@ -957,6 +982,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
// just can't open it anymore for some reason.
return nil, err
}
+ parent.watches.Notify(ctx, childName, linux.IN_CREATE, 0 /* cookie */, vfs.PathEvent, false /* unlinked */)
return &fd.vfsfd, nil
}
@@ -1002,9 +1028,224 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
defer mnt.EndWrite()
- // FIXME(gvisor.dev/issue/1199): Actually implement rename.
- _ = newParent
- return syserror.EXDEV
+ oldParent := oldParentVD.Dentry().Impl().(*dentry)
+ creds := rp.Credentials()
+ if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ // We need a dentry representing the renamed file since, if it's a
+ // directory, we need to check for write permission on it.
+ oldParent.dirMu.Lock()
+ defer oldParent.dirMu.Unlock()
+ renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds)
+ if err != nil {
+ return err
+ }
+ if err := vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&oldParent.mode)), auth.KUID(atomic.LoadUint32(&renamed.uid))); err != nil {
+ return err
+ }
+ if renamed.isDir() {
+ if renamed == newParent || genericIsAncestorDentry(renamed, newParent) {
+ return syserror.EINVAL
+ }
+ if oldParent != newParent {
+ if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ }
+ } else {
+ if opts.MustBeDir || rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ }
+
+ if oldParent != newParent {
+ if err := newParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ newParent.dirMu.Lock()
+ defer newParent.dirMu.Unlock()
+ }
+ if newParent.vfsd.IsDead() {
+ return syserror.ENOENT
+ }
+ replacedLayer, err := fs.lookupLayerLocked(ctx, newParent, newName)
+ if err != nil {
+ return err
+ }
+ var (
+ replaced *dentry
+ replacedVFSD *vfs.Dentry
+ whiteouts map[string]bool
+ )
+ if replacedLayer.existsInOverlay() {
+ replaced, err = fs.getChildLocked(ctx, newParent, newName, &ds)
+ if err != nil {
+ return err
+ }
+ replacedVFSD = &replaced.vfsd
+ if replaced.isDir() {
+ if !renamed.isDir() {
+ return syserror.EISDIR
+ }
+ if genericIsAncestorDentry(replaced, renamed) {
+ return syserror.ENOTEMPTY
+ }
+ replaced.dirMu.Lock()
+ defer replaced.dirMu.Unlock()
+ whiteouts, err = replaced.collectWhiteoutsForRmdirLocked(ctx)
+ if err != nil {
+ return err
+ }
+ } else {
+ if rp.MustBeDir() || renamed.isDir() {
+ return syserror.ENOTDIR
+ }
+ }
+ }
+
+ if oldParent == newParent && oldName == newName {
+ return nil
+ }
+
+ // renamed and oldParent need to be copied-up before they're renamed on the
+ // upper layer.
+ if err := renamed.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ // If renamed is a directory, all of its descendants need to be copied-up
+ // before they're renamed on the upper layer.
+ if renamed.isDir() {
+ if err := renamed.copyUpDescendantsLocked(ctx, &ds); err != nil {
+ return err
+ }
+ }
+ // newParent must be copied-up before it can contain renamed on the upper
+ // layer.
+ if err := newParent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ // If replaced exists, it doesn't need to be copied-up, but we do need to
+ // serialize with copy-up. Holding renameMu for writing should be
+ // sufficient, but out of an abundance of caution...
+ if replaced != nil {
+ replaced.copyMu.RLock()
+ defer replaced.copyMu.RUnlock()
+ }
+
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef(ctx)
+ if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil {
+ return err
+ }
+
+ newpop := vfs.PathOperation{
+ Root: newParent.upperVD,
+ Start: newParent.upperVD,
+ Path: fspath.Parse(newName),
+ }
+
+ needRecreateWhiteouts := false
+ cleanupRecreateWhiteouts := func() {
+ if !needRecreateWhiteouts {
+ return
+ }
+ for whiteoutName, whiteoutUpper := range whiteouts {
+ if !whiteoutUpper {
+ continue
+ }
+ if err := fs.createWhiteout(ctx, vfsObj, &vfs.PathOperation{
+ Root: replaced.upperVD,
+ Start: replaced.upperVD,
+ Path: fspath.Parse(whiteoutName),
+ }); err != nil && err != syserror.EEXIST {
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RenameAt failure: %v", err))
+ }
+ }
+ }
+ if renamed.isDir() {
+ if replacedLayer == lookupLayerUpper {
+ // Remove whiteouts from the directory being replaced.
+ needRecreateWhiteouts = true
+ for whiteoutName, whiteoutUpper := range whiteouts {
+ if !whiteoutUpper {
+ continue
+ }
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: replaced.upperVD,
+ Start: replaced.upperVD,
+ Path: fspath.Parse(whiteoutName),
+ }); err != nil {
+ cleanupRecreateWhiteouts()
+ vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
+ return err
+ }
+ }
+ } else if replacedLayer == lookupLayerUpperWhiteout {
+ // We need to explicitly remove the whiteout since otherwise rename
+ // on the upper layer will fail with ENOTDIR.
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); err != nil {
+ vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
+ return err
+ }
+ }
+ }
+
+ // Essentially no gVisor filesystem supports RENAME_WHITEOUT, so just do a
+ // regular rename and create the whiteout at the origin manually. Unlike
+ // RENAME_WHITEOUT, this isn't atomic with respect to other users of the
+ // upper filesystem, but this is already the case for virtually all other
+ // overlay filesystem operations too.
+ oldpop := vfs.PathOperation{
+ Root: oldParent.upperVD,
+ Start: oldParent.upperVD,
+ Path: fspath.Parse(oldName),
+ }
+ if err := vfsObj.RenameAt(ctx, creds, &oldpop, &newpop, &opts); err != nil {
+ cleanupRecreateWhiteouts()
+ vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
+ return err
+ }
+
+ // Below this point, the renamed dentry is now at newpop, and anything we
+ // replaced is gone forever. Commit the rename, update the overlay
+ // filesystem tree, and abandon attempts to recover from errors.
+ vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD)
+ delete(oldParent.children, oldName)
+ if replaced != nil {
+ ds = appendDentry(ds, replaced)
+ }
+ if oldParent != newParent {
+ newParent.dirents = nil
+ // This can't drop the last reference on oldParent because one is held
+ // by oldParentVD, so lock recursion is impossible.
+ oldParent.DecRef(ctx)
+ ds = appendDentry(ds, oldParent)
+ newParent.IncRef()
+ renamed.parent = newParent
+ }
+ renamed.name = newName
+ if newParent.children == nil {
+ newParent.children = make(map[string]*dentry)
+ }
+ newParent.children[newName] = renamed
+ oldParent.dirents = nil
+
+ if err := fs.createWhiteout(ctx, vfsObj, &oldpop); err != nil {
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout at origin after RenameAt: %v", err))
+ }
+ if renamed.isDir() {
+ if err := vfsObj.SetXattrAt(ctx, fs.creds, &newpop, &vfs.SetXattrOptions{
+ Name: _OVL_XATTR_OPAQUE,
+ Value: "y",
+ }); err != nil {
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to make renamed directory opaque: %v", err))
+ }
+ }
+
+ vfs.InotifyRename(ctx, &renamed.watches, &oldParent.watches, &newParent.watches, oldName, newName, renamed.isDir())
+ return nil
}
// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
@@ -1083,7 +1324,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
Start: child.upperVD,
Path: fspath.Parse(whiteoutName),
}); err != nil && err != syserror.EEXIST {
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err)
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err))
}
}
}
@@ -1113,15 +1354,14 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
// Don't attempt to recover from this: the original directory is
// already gone, so any dentries representing it are invalid, and
// creating a new directory won't undo that.
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during RmdirAt: %v", err)
- vfsObj.AbortDeleteDentry(&child.vfsd)
- return err
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout during RmdirAt: %v", err))
}
vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
delete(parent.children, name)
ds = appendDentry(ds, child)
parent.dirents = nil
+ parent.watches.Notify(ctx, name, linux.IN_DELETE|linux.IN_ISDIR, 0 /* cookie */, vfs.InodeEvent, true /* unlinked */)
return nil
}
@@ -1129,12 +1369,25 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
+ fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
return err
}
+ err = d.setStatLocked(ctx, rp, opts)
+ fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ if err != nil {
+ return err
+ }
+
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ d.InotifyWithParent(ctx, ev, 0 /* cookie */, vfs.InodeEvent)
+ }
+ return nil
+}
+// Precondition: d.fs.renameMu must be held for reading.
+func (d *dentry) setStatLocked(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
@@ -1229,7 +1482,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
},
}); err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after SymlinkAt metadata update failure: %v", cleanupErr)
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after SymlinkAt metadata update failure: %v", cleanupErr))
} else if haveUpperWhiteout {
fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
}
@@ -1322,70 +1575,175 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
}
if err := fs.createWhiteout(ctx, vfsObj, &pop); err != nil {
- ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during UnlinkAt: %v", err)
- if child != nil {
- vfsObj.AbortDeleteDentry(&child.vfsd)
- }
- return err
+ panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout during UnlinkAt: %v", err))
}
+ var cw *vfs.Watches
if child != nil {
vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
delete(parent.children, name)
ds = appendDentry(ds, child)
+ cw = &child.watches
}
+ vfs.InotifyRemoveChild(ctx, cw, &parent.watches, name)
parent.dirents = nil
return nil
}
-// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+// isOverlayXattr returns whether the given extended attribute configures the
+// overlay.
+func isOverlayXattr(name string) bool {
+ return strings.HasPrefix(name, _OVL_XATTR_PREFIX)
+}
+
+// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt.
+func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
- _, err := fs.resolveLocked(ctx, rp, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
}
- // TODO(gvisor.dev/issue/1199): Linux overlayfs actually allows listxattr,
- // but not any other xattr syscalls. For now we just reject all of them.
- return nil, syserror.ENOTSUP
+
+ return fs.listXattr(ctx, d, size)
}
-// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+func (fs *filesystem) listXattr(ctx context.Context, d *dentry, size uint64) ([]string, error) {
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ top := d.topLayer()
+ names, err := vfsObj.ListXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: top, Start: top}, size)
+ if err != nil {
+ return nil, err
+ }
+
+ // Filter out all overlay attributes.
+ n := 0
+ for _, name := range names {
+ if !isOverlayXattr(name) {
+ names[n] = name
+ n++
+ }
+ }
+ return names[:n], err
+}
+
+// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
+func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
- _, err := fs.resolveLocked(ctx, rp, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return "", err
}
- return "", syserror.ENOTSUP
+
+ return fs.getXattr(ctx, d, rp.Credentials(), &opts)
}
-// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
-func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+func (fs *filesystem) getXattr(ctx context.Context, d *dentry, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) {
+ if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil {
+ return "", err
+ }
+
+ // Return EOPNOTSUPP when fetching an overlay attribute.
+ // See fs/overlayfs/super.c:ovl_own_xattr_get().
+ if isOverlayXattr(opts.Name) {
+ return "", syserror.EOPNOTSUPP
+ }
+
+ // Analogous to fs/overlayfs/super.c:ovl_other_xattr_get().
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ top := d.topLayer()
+ return vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: top, Start: top}, opts)
+}
+
+// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt.
+func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
- _, err := fs.resolveLocked(ctx, rp, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ return err
+ }
+
+ err = fs.setXattrLocked(ctx, d, rp.Mount(), rp.Credentials(), &opts)
+ fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
if err != nil {
return err
}
- return syserror.ENOTSUP
+
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0 /* cookie */, vfs.InodeEvent)
+ return nil
+}
+
+// Precondition: fs.renameMu must be locked.
+func (fs *filesystem) setXattrLocked(ctx context.Context, d *dentry, mnt *vfs.Mount, creds *auth.Credentials, opts *vfs.SetXattrOptions) error {
+ if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil {
+ return err
+ }
+
+ // Return EOPNOTSUPP when setting an overlay attribute.
+ // See fs/overlayfs/super.c:ovl_own_xattr_set().
+ if isOverlayXattr(opts.Name) {
+ return syserror.EOPNOTSUPP
+ }
+
+ // Analogous to fs/overlayfs/super.c:ovl_other_xattr_set().
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ if err := d.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ return vfsObj.SetXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD}, opts)
}
-// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
-func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt.
+func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
- _, err := fs.resolveLocked(ctx, rp, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ return err
+ }
+
+ err = fs.removeXattrLocked(ctx, d, rp.Mount(), rp.Credentials(), name)
+ fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
if err != nil {
return err
}
- return syserror.ENOTSUP
+
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0 /* cookie */, vfs.InodeEvent)
+ return nil
+}
+
+// Precondition: fs.renameMu must be locked.
+func (fs *filesystem) removeXattrLocked(ctx context.Context, d *dentry, mnt *vfs.Mount, creds *auth.Credentials, name string) error {
+ if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil {
+ return err
+ }
+
+ // Like SetXattrAt, return EOPNOTSUPP when removing an overlay attribute.
+ // Linux passes the remove request to xattr_handler->set.
+ // See fs/xattr.c:vfs_removexattr().
+ if isOverlayXattr(name) {
+ return syserror.EOPNOTSUPP
+ }
+
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ if err := d.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ return vfsObj.RemoveXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD}, name)
}
// PrependPath implements vfs.FilesystemImpl.PrependPath.
diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go
index 268b32537..853aee951 100644
--- a/pkg/sentry/fsimpl/overlay/non_directory.go
+++ b/pkg/sentry/fsimpl/overlay/non_directory.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -38,6 +39,7 @@ func (d *dentry) readlink(ctx context.Context) (string, error) {
})
}
+// +stateify savable
type nonDirectoryFD struct {
fileDescription
@@ -46,7 +48,7 @@ type nonDirectoryFD struct {
// fileDescription.dentry().upperVD. cachedFlags is the last known value of
// cachedFD.StatusFlags(). copiedUp, cachedFD, and cachedFlags are
// protected by mu.
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
copiedUp bool
cachedFD *vfs.FileDescription
cachedFlags uint32
@@ -146,6 +148,16 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux
return stat, nil
}
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *nonDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return err
+ }
+ defer wrappedFD.DecRef(ctx)
+ return wrappedFD.Allocate(ctx, mode, offset, length)
+}
+
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
d := fd.dentry()
@@ -172,6 +184,9 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return err
}
d.updateAfterSetStatLocked(&opts)
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
+ }
return nil
}
@@ -256,10 +271,105 @@ func (fd *nonDirectoryFD) Sync(ctx context.Context) error {
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
- wrappedFD, err := fd.getCurrentFD(ctx)
+ if err := fd.ensureMappable(ctx, opts); err != nil {
+ return err
+ }
+ return vfs.GenericConfigureMMap(&fd.vfsfd, fd.dentry(), opts)
+}
+
+// ensureMappable ensures that fd.dentry().wrappedMappable is not nil.
+func (fd *nonDirectoryFD) ensureMappable(ctx context.Context, opts *memmap.MMapOpts) error {
+ d := fd.dentry()
+
+ // Fast path if we already have a Mappable for the current top layer.
+ if atomic.LoadUint32(&d.isMappable) != 0 {
+ return nil
+ }
+
+ // Only permit mmap of regular files, since other file types may have
+ // unpredictable behavior when mmapped (e.g. /dev/zero).
+ if atomic.LoadUint32(&d.mode)&linux.S_IFMT != linux.S_IFREG {
+ return syserror.ENODEV
+ }
+
+ // Get a Mappable for the current top layer.
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ d.copyMu.RLock()
+ defer d.copyMu.RUnlock()
+ if atomic.LoadUint32(&d.isMappable) != 0 {
+ return nil
+ }
+ wrappedFD, err := fd.currentFDLocked(ctx)
if err != nil {
return err
}
- defer wrappedFD.DecRef(ctx)
- return wrappedFD.ConfigureMMap(ctx, opts)
+ if err := wrappedFD.ConfigureMMap(ctx, opts); err != nil {
+ return err
+ }
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef(ctx)
+ opts.MappingIdentity = nil
+ }
+ // Use this Mappable for all mappings of this layer (unless we raced with
+ // another call to ensureMappable).
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ if d.wrappedMappable == nil {
+ d.wrappedMappable = opts.Mappable
+ atomic.StoreUint32(&d.isMappable, 1)
+ }
+ return nil
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ if err := d.wrappedMappable.AddMapping(ctx, ms, ar, offset, writable); err != nil {
+ return err
+ }
+ if !d.isCopiedUp() {
+ d.lowerMappings.AddMapping(ms, ar, offset, writable)
+ }
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ d.wrappedMappable.RemoveMapping(ctx, ms, ar, offset, writable)
+ if !d.isCopiedUp() {
+ d.lowerMappings.RemoveMapping(ms, ar, offset, writable)
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ if err := d.wrappedMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable); err != nil {
+ return err
+ }
+ if !d.isCopiedUp() {
+ d.lowerMappings.AddMapping(ms, dstAR, offset, writable)
+ }
+ return nil
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ d.dataMu.RLock()
+ defer d.dataMu.RUnlock()
+ return d.wrappedMappable.Translate(ctx, required, optional, at)
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (d *dentry) InvalidateUnsavable(ctx context.Context) error {
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ return d.wrappedMappable.InvalidateUnsavable(ctx)
}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index 00562667f..dfbccd05f 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -22,6 +22,10 @@
// filesystem.renameMu
// dentry.dirMu
// dentry.copyMu
+// *** "memmap.Mappable locks" below this point
+// dentry.mapsMu
+// *** "memmap.Mappable locks taken by Translate" below this point
+// dentry.dataMu
//
// Locking dentry.dirMu in multiple dentries requires that parent dentries are
// locked before child dentries, and that filesystem.renameMu is locked to
@@ -37,6 +41,7 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -46,6 +51,8 @@ import (
const Name = "overlay"
// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
type FilesystemType struct{}
// Name implements vfs.FilesystemType.Name.
@@ -55,6 +62,8 @@ func (FilesystemType) Name() string {
// FilesystemOptions may be passed as vfs.GetFilesystemOptions.InternalData to
// FilesystemType.GetFilesystem.
+//
+// +stateify savable
type FilesystemOptions struct {
// Callers passing FilesystemOptions to
// overlay.FilesystemType.GetFilesystem() are responsible for ensuring that
@@ -71,6 +80,8 @@ type FilesystemOptions struct {
}
// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
type filesystem struct {
vfsfs vfs.Filesystem
@@ -93,7 +104,7 @@ type filesystem struct {
// renameMu synchronizes renaming with non-renaming operations in order to
// ensure consistent lock ordering between dentry.dirMu in different
// dentries.
- renameMu sync.RWMutex
+ renameMu sync.RWMutex `state:"nosave"`
// lastDirIno is the last inode number assigned to a directory. lastDirIno
// is accessed using atomic memory operations.
@@ -106,16 +117,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fsoptsRaw := opts.InternalData
fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions)
if fsoptsRaw != nil && !haveFSOpts {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw)
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw)
return nil, nil, syserror.EINVAL
}
if haveFSOpts {
if len(fsopts.LowerRoots) == 0 {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty")
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty")
return nil, nil, syserror.EINVAL
}
if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified")
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified")
return nil, nil, syserror.EINVAL
}
// We don't enforce a maximum number of lower layers when not
@@ -132,7 +143,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
delete(mopts, "workdir")
upperPath := fspath.Parse(upperPathname)
if !upperPath.Absolute {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
return nil, nil, syserror.EINVAL
}
upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
@@ -144,13 +155,13 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
CheckSearchable: true,
})
if err != nil {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
return nil, nil, err
}
defer upperRoot.DecRef(ctx)
privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */)
if err != nil {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
return nil, nil, err
}
defer privateUpperRoot.DecRef(ctx)
@@ -158,24 +169,24 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
lowerPathnamesStr, ok := mopts["lowerdir"]
if !ok {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: missing required option lowerdir")
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir")
return nil, nil, syserror.EINVAL
}
delete(mopts, "lowerdir")
lowerPathnames := strings.Split(lowerPathnamesStr, ":")
const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK
if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified")
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified")
return nil, nil, syserror.EINVAL
}
if len(lowerPathnames) > maxLowerLayers {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers)
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers)
return nil, nil, syserror.EINVAL
}
for _, lowerPathname := range lowerPathnames {
lowerPath := fspath.Parse(lowerPathname)
if !lowerPath.Absolute {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname)
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname)
return nil, nil, syserror.EINVAL
}
lowerRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
@@ -187,13 +198,13 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
CheckSearchable: true,
})
if err != nil {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err)
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err)
return nil, nil, err
}
defer lowerRoot.DecRef(ctx)
privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */)
if err != nil {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err)
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err)
return nil, nil, err
}
defer privateLowerRoot.DecRef(ctx)
@@ -201,7 +212,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
}
if len(mopts) != 0 {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts)
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts)
return nil, nil, syserror.EINVAL
}
@@ -274,7 +285,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, syserror.EREMOTE
}
if isWhiteout(&rootStat) {
- ctx.Warningf("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout")
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout")
root.destroyLocked(ctx)
fs.vfsfs.DecRef(ctx)
return nil, nil, syserror.EINVAL
@@ -362,6 +373,8 @@ func (fs *filesystem) newDirIno() uint64 {
}
// dentry implements vfs.DentryImpl.
+//
+// +stateify savable
type dentry struct {
vfsd vfs.Dentry
@@ -394,7 +407,7 @@ type dentry struct {
// and dirents (if not nil) is a cache of dirents as returned by
// directoryFDs representing this directory. children is protected by
// dirMu.
- dirMu sync.Mutex
+ dirMu sync.Mutex `state:"nosave"`
children map[string]*dentry
dirents []vfs.Dirent
@@ -404,7 +417,7 @@ type dentry struct {
// If !upperVD.Ok(), it can transition to a valid vfs.VirtualDentry (i.e.
// be copied up) with copyMu locked for writing; otherwise, it is
// immutable. lowerVDs is always immutable.
- copyMu sync.RWMutex
+ copyMu sync.RWMutex `state:"nosave"`
upperVD vfs.VirtualDentry
lowerVDs []vfs.VirtualDentry
@@ -419,7 +432,43 @@ type dentry struct {
devMinor uint32
ino uint64
+ // If this dentry represents a regular file, then:
+ //
+ // - mapsMu is used to synchronize between copy-up and memmap.Mappable
+ // methods on dentry preceding mm.MemoryManager.activeMu in the lock order.
+ //
+ // - dataMu is used to synchronize between copy-up and
+ // dentry.(memmap.Mappable).Translate.
+ //
+ // - lowerMappings tracks memory mappings of the file. lowerMappings is
+ // used to invalidate mappings of the lower layer when the file is copied
+ // up to ensure that they remain coherent with subsequent writes to the
+ // file. (Note that, as of this writing, Linux overlayfs does not do this;
+ // this feature is a gVisor extension.) lowerMappings is protected by
+ // mapsMu.
+ //
+ // - If this dentry is copied-up, then wrappedMappable is the Mappable
+ // obtained from a call to the current top layer's
+ // FileDescription.ConfigureMMap(). Once wrappedMappable becomes non-nil
+ // (from a call to nonDirectoryFD.ensureMappable()), it cannot become nil.
+ // wrappedMappable is protected by mapsMu and dataMu.
+ //
+ // - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is
+ // accessed using atomic memory operations.
+ mapsMu sync.Mutex
+ lowerMappings memmap.MappingSet
+ dataMu sync.RWMutex
+ wrappedMappable memmap.Mappable
+ isMappable uint32
+
locks vfs.FileLocks
+
+ // watches is the set of inotify watches on the file repesented by this dentry.
+ //
+ // Note that hard links to the same file will not share the same set of
+ // watches, due to the fact that we do not have inode structures in this
+ // overlay implementation.
+ watches vfs.Watches
}
// newDentry creates a new dentry. The dentry initially has no references; it
@@ -479,6 +528,14 @@ func (d *dentry) checkDropLocked(ctx context.Context) {
if atomic.LoadInt64(&d.refs) != 0 {
return
}
+
+ // Make sure that we do not lose watches on dentries that have not been
+ // deleted. Note that overlayfs never calls VFS.InvalidateDentry(), so
+ // d.vfsd.IsDead() indicates that d was deleted.
+ if !d.vfsd.IsDead() && d.watches.Size() > 0 {
+ return
+ }
+
// Refs is still zero; destroy it.
d.destroyLocked(ctx)
return
@@ -507,6 +564,8 @@ func (d *dentry) destroyLocked(ctx context.Context) {
lowerVD.DecRef(ctx)
}
+ d.watches.HandleDeletion(ctx)
+
if d.parent != nil {
d.parent.dirMu.Lock()
if !d.vfsd.IsDead() {
@@ -525,19 +584,36 @@ func (d *dentry) destroyLocked(ctx context.Context) {
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
func (d *dentry) InotifyWithParent(ctx context.Context, events uint32, cookie uint32, et vfs.EventType) {
- // TODO(gvisor.dev/issue/1479): Implement inotify.
+ if d.isDir() {
+ events |= linux.IN_ISDIR
+ }
+
+ // overlayfs never calls VFS.InvalidateDentry(), so d.vfsd.IsDead() indicates
+ // that d was deleted.
+ deleted := d.vfsd.IsDead()
+
+ d.fs.renameMu.RLock()
+ // The ordering below is important, Linux always notifies the parent first.
+ if d.parent != nil {
+ d.parent.watches.Notify(ctx, d.name, events, cookie, et, deleted)
+ }
+ d.watches.Notify(ctx, "", events, cookie, et, deleted)
+ d.fs.renameMu.RUnlock()
}
// Watches implements vfs.DentryImpl.Watches.
func (d *dentry) Watches() *vfs.Watches {
- // TODO(gvisor.dev/issue/1479): Implement inotify.
- return nil
+ return &d.watches
}
// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
-//
-// TODO(gvisor.dev/issue/1479): Implement inotify.
-func (d *dentry) OnZeroWatches(context.Context) {}
+func (d *dentry) OnZeroWatches(ctx context.Context) {
+ if atomic.LoadInt64(&d.refs) == 0 {
+ d.fs.renameMu.Lock()
+ d.checkDropLocked(ctx)
+ d.fs.renameMu.Unlock()
+ }
+}
// iterLayers invokes yield on each layer comprising d, from top to bottom. If
// any call to yield returns false, iterLayer stops iteration.
@@ -570,6 +646,16 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes)
return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
}
+func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error {
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ kuid := auth.KUID(atomic.LoadUint32(&d.uid))
+ kgid := auth.KGID(atomic.LoadUint32(&d.gid))
+ if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil {
+ return err
+ }
+ return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name)
+}
+
// statInternalMask is the set of stat fields that is set by
// dentry.statInternalTo().
const statInternalMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
@@ -608,6 +694,8 @@ func (d *dentry) updateAfterSetStatLocked(opts *vfs.SetStatOptions) {
// fileDescription is embedded by overlay implementations of
// vfs.FileDescriptionImpl.
+//
+// +stateify savable
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -622,6 +710,48 @@ func (fd *fileDescription) dentry() *dentry {
return fd.vfsfd.Dentry().Impl().(*dentry)
}
+// ListXattr implements vfs.FileDescriptionImpl.ListXattr.
+func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) {
+ return fd.filesystem().listXattr(ctx, fd.dentry(), size)
+}
+
+// GetXattr implements vfs.FileDescriptionImpl.GetXattr.
+func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) {
+ return fd.filesystem().getXattr(ctx, fd.dentry(), auth.CredentialsFromContext(ctx), &opts)
+}
+
+// SetXattr implements vfs.FileDescriptionImpl.SetXattr.
+func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error {
+ fs := fd.filesystem()
+ d := fd.dentry()
+
+ fs.renameMu.RLock()
+ err := fs.setXattrLocked(ctx, d, fd.vfsfd.Mount(), auth.CredentialsFromContext(ctx), &opts)
+ fs.renameMu.RUnlock()
+ if err != nil {
+ return err
+ }
+
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr.
+func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error {
+ fs := fd.filesystem()
+ d := fd.dentry()
+
+ fs.renameMu.RLock()
+ err := fs.removeXattrLocked(ctx, d, fd.vfsfd.Mount(), auth.CredentialsFromContext(ctx), name)
+ fs.renameMu.RUnlock()
+ if err != nil {
+ return err
+ }
+
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index 7053ad6db..4e2da4810 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// +stateify savable
type filesystemType struct{}
// Name implements vfs.FilesystemType.Name.
@@ -43,6 +44,7 @@ func (filesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFile
panic("pipefs.filesystemType.GetFilesystem should never be called")
}
+// +stateify savable
type filesystem struct {
kernfs.Filesystem
@@ -76,6 +78,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
}
// inode implements kernfs.Inode.
+//
+// +stateify savable
type inode struct {
kernfs.InodeNotDirectory
kernfs.InodeNotSymlink
@@ -144,8 +148,8 @@ func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.
}
// Open implements kernfs.Inode.Open.
-func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags, &i.locks)
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return i.pipe.Open(ctx, rp.Mount(), d.VFSDentry(), opts.Flags, &i.locks)
}
// StatFS implements kernfs.Inode.StatFS.
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index a45b44440..2e086e34c 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -100,6 +100,7 @@ go_library(
"//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip/header",
+ "//pkg/tcpip/network/ipv4",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 03b5941b9..05d7948ea 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -41,6 +41,7 @@ func (FilesystemType) Name() string {
return Name
}
+// +stateify savable
type filesystem struct {
kernfs.Filesystem
@@ -84,6 +85,8 @@ func (fs *filesystem) Release(ctx context.Context) {
// dynamicInode is an overfitted interface for common Inodes with
// dynamicByteSource types used in procfs.
+//
+// +stateify savable
type dynamicInode interface {
kernfs.Inode
vfs.DynamicBytesSource
@@ -99,6 +102,7 @@ func (fs *filesystem) newDentry(creds *auth.Credentials, ino uint64, perm linux.
return d
}
+// +stateify savable
type staticFile struct {
kernfs.DynamicBytesFile
vfs.StaticData
@@ -118,10 +122,13 @@ func newStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64
// InternalData contains internal data passed in to the procfs mount via
// vfs.GetFilesystemOptions.InternalData.
+//
+// +stateify savable
type InternalData struct {
Cgroups map[string]string
}
+// +stateify savable
type implStatFS struct{}
// StatFS implements kernfs.Inode.StatFS.
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index d57d94dbc..47ecd941c 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -68,8 +68,8 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace,
return dentry
}
-// Lookup implements kernfs.inodeDynamicLookup.
-func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+// Lookup implements kernfs.inodeDynamicLookup.Lookup.
+func (i *subtasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
tid, err := strconv.ParseUint(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
@@ -82,12 +82,10 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, e
if subTask.ThreadGroup() != i.task.ThreadGroup() {
return nil, syserror.ENOENT
}
-
- subTaskDentry := i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers)
- return subTaskDentry.VFSDentry(), nil
+ return i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers), nil
}
-// IterDirents implements kernfs.inodeDynamicLookup.
+// IterDirents implements kernfs.inodeDynamicLookup.IterDirents.
func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
tasks := i.task.ThreadGroup().MemberIDs(i.pidns)
if len(tasks) == 0 {
@@ -118,6 +116,7 @@ func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallb
return offset, nil
}
+// +stateify savable
type subtasksFD struct {
kernfs.GenericDirectoryFD
@@ -155,21 +154,21 @@ func (fd *subtasksFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) erro
return fd.GenericDirectoryFD.SetStat(ctx, opts)
}
-// Open implements kernfs.Inode.
-func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+// Open implements kernfs.Inode.Open.
+func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &subtasksFD{task: i.task}
if err := fd.Init(&i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndZero,
}); err != nil {
return nil, err
}
- if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return fd.VFSFileDescription(), nil
}
-// Stat implements kernfs.Inode.
+// Stat implements kernfs.Inode.Stat.
func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts)
if err != nil {
@@ -181,12 +180,12 @@ func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs
return stat, nil
}
-// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
func (*subtasksInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
-// DecRef implements kernfs.Inode.
+// DecRef implements kernfs.Inode.DecRef.
func (i *subtasksInode) DecRef(context.Context) {
i.subtasksInodeRefs.DecRef(i.Destroy)
}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index dbdb5d929..a7cd6f57e 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -53,6 +53,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace
"auxv": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &auxvData{task: task}),
"cmdline": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
"comm": fs.newComm(task, fs.NextIno(), 0444),
+ "cwd": fs.newCwdSymlink(task, fs.NextIno()),
"environ": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
"exe": fs.newExeSymlink(task, fs.NextIno()),
"fd": fs.newFDDirInode(task),
@@ -106,9 +107,9 @@ func (i *taskInode) Valid(ctx context.Context) bool {
return i.task.ExitState() != kernel.TaskExitDead
}
-// Open implements kernfs.Inode.
-func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
+// Open implements kernfs.Inode.Open.
+func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndZero,
})
if err != nil {
@@ -117,18 +118,20 @@ func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.D
return fd.VFSFileDescription(), nil
}
-// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
func (*taskInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
-// DecRef implements kernfs.Inode.
+// DecRef implements kernfs.Inode.DecRef.
func (i *taskInode) DecRef(context.Context) {
i.taskInodeRefs.DecRef(i.Destroy)
}
// taskOwnedInode implements kernfs.Inode and overrides inode owner with task
// effective user and group.
+//
+// +stateify savable
type taskOwnedInode struct {
kernfs.Inode
@@ -168,7 +171,7 @@ func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.
return d
}
-// Stat implements kernfs.Inode.
+// Stat implements kernfs.Inode.Stat.
func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
stat, err := i.Inode.Stat(ctx, fs, opts)
if err != nil {
@@ -186,7 +189,7 @@ func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.
return stat, nil
}
-// CheckPermissions implements kernfs.Inode.
+// CheckPermissions implements kernfs.Inode.CheckPermissions.
func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
mode := i.Mode()
uid, gid := i.getOwner(mode)
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 3f0d78461..0866cea2b 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -51,6 +51,7 @@ func taskFDExists(ctx context.Context, t *kernel.Task, fd int32) bool {
return true
}
+// +stateify savable
type fdDir struct {
locks vfs.FileLocks
@@ -62,7 +63,7 @@ type fdDir struct {
produceSymlink bool
}
-// IterDirents implements kernfs.inodeDynamicLookup.
+// IterDirents implements kernfs.inodeDynamicLookup.IterDirents.
func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
var fds []int32
i.task.WithMuLocked(func(t *kernel.Task) {
@@ -86,14 +87,19 @@ func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, off
Name: strconv.FormatUint(uint64(fd), 10),
Type: typ,
Ino: i.fs.NextIno(),
- NextOff: offset + 1,
+ NextOff: int64(fd) + 3,
}
if err := cb.Handle(dirent); err != nil {
- return offset, err
+ // Getdents should iterate correctly despite mutation
+ // of fds, so we return the next fd to serialize plus
+ // 2 (which accounts for the "." and ".." tracked by
+ // kernfs) as the offset.
+ return int64(fd) + 2, err
}
- offset++
}
- return offset, nil
+ // We serialized them all. Next offset should be higher than last
+ // serialized fd.
+ return int64(fds[len(fds)-1]) + 3, nil
}
// fdDirInode represents the inode for /proc/[pid]/fd directory.
@@ -130,8 +136,8 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry {
return dentry
}
-// Lookup implements kernfs.inodeDynamicLookup.
-func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+// Lookup implements kernfs.inodeDynamicLookup.Lookup.
+func (i *fdDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
fdInt, err := strconv.ParseInt(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
@@ -140,13 +146,12 @@ func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro
if !taskFDExists(ctx, i.task, fd) {
return nil, syserror.ENOENT
}
- taskDentry := i.fs.newFDSymlink(i.task, fd, i.fs.NextIno())
- return taskDentry.VFSDentry(), nil
+ return i.fs.newFDSymlink(i.task, fd, i.fs.NextIno()), nil
}
-// Open implements kernfs.Inode.
-func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
+// Open implements kernfs.Inode.Open.
+func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndZero,
})
if err != nil {
@@ -155,7 +160,7 @@ func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.
return fd.VFSFileDescription(), nil
}
-// CheckPermissions implements kernfs.Inode.
+// CheckPermissions implements kernfs.Inode.CheckPermissions.
//
// This is to match Linux, which uses a special permission handler to guarantee
// that a process can still access /proc/self/fd after it has executed
@@ -177,7 +182,7 @@ func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentia
return err
}
-// DecRef implements kernfs.Inode.
+// DecRef implements kernfs.Inode.DecRef.
func (i *fdDirInode) DecRef(context.Context) {
i.fdDirInodeRefs.DecRef(i.Destroy)
}
@@ -209,7 +214,7 @@ func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) *ker
return d
}
-func (s *fdSymlink) Readlink(ctx context.Context) (string, error) {
+func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
file, _ := getTaskFD(s.task, s.fd)
if file == nil {
return "", syserror.ENOENT
@@ -264,8 +269,8 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry {
return dentry
}
-// Lookup implements kernfs.inodeDynamicLookup.
-func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+// Lookup implements kernfs.inodeDynamicLookup.Lookup.
+func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
fdInt, err := strconv.ParseInt(name, 10, 32)
if err != nil {
return nil, syserror.ENOENT
@@ -278,13 +283,12 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry,
task: i.task,
fd: fd,
}
- dentry := i.fs.newTaskOwnedFile(i.task, i.fs.NextIno(), 0444, data)
- return dentry.VFSDentry(), nil
+ return i.fs.newTaskOwnedFile(i.task, i.fs.NextIno(), 0444, data), nil
}
-// Open implements kernfs.Inode.
-func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
+// Open implements kernfs.Inode.Open.
+func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndZero,
})
if err != nil {
@@ -293,7 +297,7 @@ func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *
return fd.VFSFileDescription(), nil
}
-// DecRef implements kernfs.Inode.
+// DecRef implements kernfs.Inode.DecRef.
func (i *fdInfoDirInode) DecRef(context.Context) {
i.fdInfoDirInodeRefs.DecRef(i.Destroy)
}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 356036b9b..3fbf081a6 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -543,7 +543,7 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
var vss, rss, data uint64
s.task.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
- fds = fdTable.Size()
+ fds = fdTable.CurrentMaxFDs()
}
if mm := t.MemoryManager(); mm != nil {
vss = mm.VirtualMemorySize()
@@ -667,20 +667,24 @@ func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentr
return d
}
-// Readlink implements kernfs.Inode.
-func (s *exeSymlink) Readlink(ctx context.Context) (string, error) {
- if !kernel.ContextCanTrace(ctx, s.task, false) {
- return "", syserror.EACCES
- }
-
- // Pull out the executable for /proc/[pid]/exe.
- exec, err := s.executable()
+// Readlink implements kernfs.Inode.Readlink.
+func (s *exeSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
+ exec, _, err := s.Getlink(ctx, nil)
if err != nil {
return "", err
}
defer exec.DecRef(ctx)
- return exec.PathnameWithDeleted(ctx), nil
+ root := vfs.RootFromContext(ctx)
+ if !root.Ok() {
+ // It could have raced with process deletion.
+ return "", syserror.ESRCH
+ }
+ defer root.DecRef(ctx)
+
+ vfsObj := exec.Mount().Filesystem().VirtualFilesystem()
+ name, _ := vfsObj.PathnameWithDeleted(ctx, root, exec)
+ return name, nil
}
// Getlink implements kernfs.Inode.Getlink.
@@ -688,23 +692,12 @@ func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDent
if !kernel.ContextCanTrace(ctx, s.task, false) {
return vfs.VirtualDentry{}, "", syserror.EACCES
}
-
- exec, err := s.executable()
- if err != nil {
- return vfs.VirtualDentry{}, "", err
- }
- defer exec.DecRef(ctx)
-
- vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry()
- vd.IncRef()
- return vd, "", nil
-}
-
-func (s *exeSymlink) executable() (file fsbridge.File, err error) {
if err := checkTaskState(s.task); err != nil {
- return nil, err
+ return vfs.VirtualDentry{}, "", err
}
+ var err error
+ var exec fsbridge.File
s.task.WithMuLocked(func(t *kernel.Task) {
mm := t.MemoryManager()
if mm == nil {
@@ -715,12 +708,78 @@ func (s *exeSymlink) executable() (file fsbridge.File, err error) {
// The MemoryManager may be destroyed, in which case
// MemoryManager.destroy will simply set the executable to nil
// (with locks held).
- file = mm.Executable()
- if file == nil {
+ exec = mm.Executable()
+ if exec == nil {
err = syserror.ESRCH
}
})
- return
+ if err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+ defer exec.DecRef(ctx)
+
+ vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry()
+ vd.IncRef()
+ return vd, "", nil
+}
+
+// cwdSymlink is an symlink for the /proc/[pid]/cwd file.
+//
+// +stateify savable
+type cwdSymlink struct {
+ implStatFS
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ task *kernel.Task
+}
+
+var _ kernfs.Inode = (*cwdSymlink)(nil)
+
+func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry {
+ inode := &cwdSymlink{task: task}
+ inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+// Readlink implements kernfs.Inode.Readlink.
+func (s *cwdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
+ cwd, _, err := s.Getlink(ctx, nil)
+ if err != nil {
+ return "", err
+ }
+ defer cwd.DecRef(ctx)
+
+ root := vfs.RootFromContext(ctx)
+ if !root.Ok() {
+ // It could have raced with process deletion.
+ return "", syserror.ESRCH
+ }
+ defer root.DecRef(ctx)
+
+ vfsObj := cwd.Mount().Filesystem().VirtualFilesystem()
+ name, _ := vfsObj.PathnameWithDeleted(ctx, root, cwd)
+ return name, nil
+}
+
+// Getlink implements kernfs.Inode.Getlink.
+func (s *cwdSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ if !kernel.ContextCanTrace(ctx, s.task, false) {
+ return vfs.VirtualDentry{}, "", syserror.EACCES
+ }
+ if err := checkTaskState(s.task); err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+ cwd := s.task.FSContext().WorkingDirectoryVFS2()
+ if !cwd.Ok() {
+ // It could have raced with process deletion.
+ return vfs.VirtualDentry{}, "", syserror.ESRCH
+ }
+ return cwd, "", nil
}
// mountInfoData is used to implement /proc/[pid]/mountinfo.
@@ -785,6 +844,7 @@ func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
+// +stateify savable
type namespaceSymlink struct {
kernfs.StaticSymlink
@@ -807,15 +867,15 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri
return d
}
-// Readlink implements Inode.
-func (s *namespaceSymlink) Readlink(ctx context.Context) (string, error) {
+// Readlink implements kernfs.Inode.Readlink.
+func (s *namespaceSymlink) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
if err := checkTaskState(s.task); err != nil {
return "", err
}
- return s.StaticSymlink.Readlink(ctx)
+ return s.StaticSymlink.Readlink(ctx, mnt)
}
-// Getlink implements Inode.Getlink.
+// Getlink implements kernfs.Inode.Getlink.
func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) {
if err := checkTaskState(s.task); err != nil {
return vfs.VirtualDentry{}, "", err
@@ -832,6 +892,8 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir
// namespaceInode is a synthetic inode created to represent a namespace in
// /proc/[pid]/ns/*.
+//
+// +stateify savable
type namespaceInode struct {
implStatFS
kernfs.InodeAttrs
@@ -852,12 +914,12 @@ func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32
i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
}
-// Open implements Inode.Open.
-func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+// Open implements kernfs.Inode.Open.
+func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &namespaceFD{inode: i}
i.IncRef()
fd.LockFD.Init(&i.locks)
- if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return &fd.vfsfd, nil
@@ -865,6 +927,8 @@ func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *
// namespace FD is a synthetic file that represents a namespace in
// /proc/[pid]/ns/*.
+//
+// +stateify savable
type namespaceFD struct {
vfs.FileDescriptionDefaultImpl
vfs.LockFD
@@ -875,20 +939,20 @@ type namespaceFD struct {
var _ vfs.FileDescriptionImpl = (*namespaceFD)(nil)
-// Stat implements FileDescriptionImpl.
+// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *namespaceFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
return fd.inode.Stat(ctx, vfs, opts)
}
-// SetStat implements FileDescriptionImpl.
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
creds := auth.CredentialsFromContext(ctx)
return fd.inode.SetStat(ctx, vfs, creds, opts)
}
-// Release implements FileDescriptionImpl.
+// Release implements vfs.FileDescriptionImpl.Release.
func (fd *namespaceFD) Release(ctx context.Context) {
fd.inode.DecRef(ctx)
}
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index 4e69782c7..e7f748655 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -616,6 +616,7 @@ type netSnmpData struct {
var _ dynamicInode = (*netSnmpData)(nil)
+// +stateify savable
type snmpLine struct {
prefix string
header string
@@ -660,7 +661,7 @@ func sprintSlice(s []uint64) string {
return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice.
}
-// Generate implements vfs.DynamicBytesSource.
+// Generate implements vfs.DynamicBytesSource.Generate.
func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error {
types := []interface{}{
&inet.StatSNMPIP{},
@@ -709,7 +710,7 @@ type netRouteData struct {
var _ dynamicInode = (*netRouteData)(nil)
-// Generate implements vfs.DynamicBytesSource.
+// Generate implements vfs.DynamicBytesSource.Generate.
// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show.
func (d *netRouteData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "%-127s\n", "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT")
@@ -773,7 +774,7 @@ type netStatData struct {
var _ dynamicInode = (*netStatData)(nil)
-// Generate implements vfs.DynamicBytesSource.
+// Generate implements vfs.DynamicBytesSource.Generate.
// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show.
func (d *netStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
buf.WriteString("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed " +
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 3ea00ab87..d8f5dd509 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -52,8 +52,8 @@ type tasksInode struct {
// '/proc/self' and '/proc/thread-self' have custom directory offsets in
// Linux. So handle them outside of OrderedChildren.
- selfSymlink *vfs.Dentry
- threadSelfSymlink *vfs.Dentry
+ selfSymlink *kernfs.Dentry
+ threadSelfSymlink *kernfs.Dentry
// cgroupControllers is a map of controller name to directory in the
// cgroup hierarchy. These controllers are immutable and will be listed
@@ -81,8 +81,8 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace
inode := &tasksInode{
pidns: pidns,
fs: fs,
- selfSymlink: fs.newSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(),
- threadSelfSymlink: fs.newThreadSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(),
+ selfSymlink: fs.newSelfSymlink(root, fs.NextIno(), pidns),
+ threadSelfSymlink: fs.newThreadSelfSymlink(root, fs.NextIno(), pidns),
cgroupControllers: cgroupControllers,
}
inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
@@ -98,8 +98,8 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace
return inode, dentry
}
-// Lookup implements kernfs.inodeDynamicLookup.
-func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+// Lookup implements kernfs.inodeDynamicLookup.Lookup.
+func (i *tasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) {
// Try to lookup a corresponding task.
tid, err := strconv.ParseUint(name, 10, 64)
if err != nil {
@@ -118,11 +118,10 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro
return nil, syserror.ENOENT
}
- taskDentry := i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers)
- return taskDentry.VFSDentry(), nil
+ return i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers), nil
}
-// IterDirents implements kernfs.inodeDynamicLookup.
+// IterDirents implements kernfs.inodeDynamicLookup.IterDirents.
func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
// fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
const FIRST_PROCESS_ENTRY = 256
@@ -200,9 +199,9 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback
return maxTaskID, nil
}
-// Open implements kernfs.Inode.
-func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
+// Open implements kernfs.Inode.Open.
+func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndZero,
})
if err != nil {
@@ -229,7 +228,7 @@ func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.St
return stat, nil
}
-// DecRef implements kernfs.Inode.
+// DecRef implements kernfs.Inode.DecRef.
func (i *tasksInode) DecRef(context.Context) {
i.tasksInodeRefs.DecRef(i.Destroy)
}
@@ -237,6 +236,8 @@ func (i *tasksInode) DecRef(context.Context) {
// staticFileSetStat implements a special static file that allows inode
// attributes to be set. This is to support /proc files that are readonly, but
// allow attributes to be set.
+//
+// +stateify savable
type staticFileSetStat struct {
dynamicBytesFileSetAttr
vfs.StaticData
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 8c41729e4..f268c59b0 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// +stateify savable
type selfSymlink struct {
implStatFS
kernfs.InodeAttrs
@@ -51,7 +52,7 @@ func (fs *filesystem) newSelfSymlink(creds *auth.Credentials, ino uint64, pidns
return d
}
-func (s *selfSymlink) Readlink(ctx context.Context) (string, error) {
+func (s *selfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
t := kernel.TaskFromContext(ctx)
if t == nil {
// Who is reading this link?
@@ -64,16 +65,17 @@ func (s *selfSymlink) Readlink(ctx context.Context) (string, error) {
return strconv.FormatUint(uint64(tgid), 10), nil
}
-func (s *selfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) {
- target, err := s.Readlink(ctx)
+func (s *selfSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ target, err := s.Readlink(ctx, mnt)
return vfs.VirtualDentry{}, target, err
}
-// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
func (*selfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
+// +stateify savable
type threadSelfSymlink struct {
implStatFS
kernfs.InodeAttrs
@@ -94,7 +96,7 @@ func (fs *filesystem) newThreadSelfSymlink(creds *auth.Credentials, ino uint64,
return d
}
-func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) {
+func (s *threadSelfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
t := kernel.TaskFromContext(ctx)
if t == nil {
// Who is reading this link?
@@ -108,12 +110,12 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) {
return fmt.Sprintf("%d/task/%d", tgid, tid), nil
}
-func (s *threadSelfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) {
- target, err := s.Readlink(ctx)
+func (s *threadSelfSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ target, err := s.Readlink(ctx, mnt)
return vfs.VirtualDentry{}, target, err
}
-// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
@@ -121,16 +123,20 @@ func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Creden
// dynamicBytesFileSetAttr implements a special file that allows inode
// attributes to be set. This is to support /proc files that are readonly, but
// allow attributes to be set.
+//
+// +stateify savable
type dynamicBytesFileSetAttr struct {
kernfs.DynamicBytesFile
}
-// SetStat implements Inode.SetStat.
+// SetStat implements kernfs.Inode.SetStat.
func (d *dynamicBytesFileSetAttr) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
return d.DynamicBytesFile.InodeAttrs.SetStat(ctx, fs, creds, opts)
}
// cpuStats contains the breakdown of CPU time for /proc/stat.
+//
+// +stateify savable
type cpuStats struct {
// user is time spent in userspace tasks with non-positive niceness.
user uint64
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 038a194c7..3312b0418 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -27,9 +27,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/usermem"
)
+// +stateify savable
type tcpMemDir int
const (
@@ -67,6 +69,7 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke
"tcp_rmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
"tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}),
"tcp_wmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
+ "ip_forward": fs.newDentry(root, fs.NextIno(), 0444, &ipForwarding{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack, most of these files are configuration related. We use the
@@ -174,7 +177,7 @@ type tcpSackData struct {
var _ vfs.WritableDynamicBytesSource = (*tcpSackData)(nil)
-// Generate implements vfs.DynamicBytesSource.
+// Generate implements vfs.DynamicBytesSource.Generate.
func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error {
if d.enabled == nil {
sack, err := d.stack.TCPSACKEnabled()
@@ -232,7 +235,7 @@ type tcpRecoveryData struct {
var _ vfs.WritableDynamicBytesSource = (*tcpRecoveryData)(nil)
-// Generate implements vfs.DynamicBytesSource.
+// Generate implements vfs.DynamicBytesSource.Generate.
func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error {
recovery, err := d.stack.TCPRecovery()
if err != nil {
@@ -284,7 +287,7 @@ type tcpMemData struct {
var _ vfs.WritableDynamicBytesSource = (*tcpMemData)(nil)
-// Generate implements vfs.DynamicBytesSource.
+// Generate implements vfs.DynamicBytesSource.Generate.
func (d *tcpMemData) Generate(ctx context.Context, buf *bytes.Buffer) error {
d.mu.Lock()
defer d.mu.Unlock()
@@ -354,3 +357,63 @@ func (d *tcpMemData) writeSizeLocked(size inet.TCPBufferSize) error {
panic(fmt.Sprintf("unknown tcpMemFile type: %v", d.dir))
}
}
+
+// ipForwarding implements vfs.WritableDynamicBytesSource for
+// /proc/sys/net/ipv4/ip_forwarding.
+//
+// +stateify savable
+type ipForwarding struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack `state:"wait"`
+ enabled *bool
+}
+
+var _ vfs.WritableDynamicBytesSource = (*ipForwarding)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if ipf.enabled == nil {
+ enabled := ipf.stack.Forwarding(ipv4.ProtocolNumber)
+ ipf.enabled = &enabled
+ }
+
+ val := "0\n"
+ if *ipf.enabled {
+ // Technically, this is not quite compatible with Linux. Linux stores these
+ // as an integer, so if you write "2" into tcp_sack, you should get 2 back.
+ // Tough luck.
+ val = "1\n"
+ }
+ buf.WriteString(val)
+
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ // No need to handle partial writes thus far.
+ return 0, syserror.EINVAL
+ }
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit input size so as not to impact performance if input size is large.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+ if ipf.enabled == nil {
+ ipf.enabled = new(bool)
+ }
+ *ipf.enabled = v != 0
+ if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil {
+ return 0, err
+ }
+ return n, nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
index be54897bb..6cee22823 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
@@ -20,8 +20,10 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func newIPv6TestStack() *inet.TestStack {
@@ -76,3 +78,72 @@ func TestIfinet6(t *testing.T) {
t.Errorf("Got n.contents() = %v, want = %v", got, want)
}
}
+
+// TestIPForwarding tests the implementation of
+// /proc/sys/net/ipv4/ip_forwarding
+func TestConfigureIPForwarding(t *testing.T) {
+ ctx := context.Background()
+ s := inet.NewTestStack()
+
+ var cases = []struct {
+ comment string
+ initial bool
+ str string
+ final bool
+ }{
+ {
+ comment: `Forwarding is disabled; write 1 and enable forwarding`,
+ initial: false,
+ str: "1",
+ final: true,
+ },
+ {
+ comment: `Forwarding is disabled; write 0 and disable forwarding`,
+ initial: false,
+ str: "0",
+ final: false,
+ },
+ {
+ comment: `Forwarding is enabled; write 1 and enable forwarding`,
+ initial: true,
+ str: "1",
+ final: true,
+ },
+ {
+ comment: `Forwarding is enabled; write 0 and disable forwarding`,
+ initial: true,
+ str: "0",
+ final: false,
+ },
+ {
+ comment: `Forwarding is disabled; write 2404 and enable forwarding`,
+ initial: false,
+ str: "2404",
+ final: true,
+ },
+ {
+ comment: `Forwarding is enabled; write 2404 and enable forwarding`,
+ initial: true,
+ str: "2404",
+ final: true,
+ },
+ }
+ for _, c := range cases {
+ t.Run(c.comment, func(t *testing.T) {
+ s.IPForwarding = c.initial
+
+ file := &ipForwarding{stack: s, enabled: &c.initial}
+
+ // Write the values.
+ src := usermem.BytesIOSequence([]byte(c.str))
+ if n, err := file.Write(ctx, src, 0); n != int64(len(c.str)) || err != nil {
+ t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str))
+ }
+
+ // Read the values from the stack and check them.
+ if got, want := s.IPForwarding, c.final; got != want {
+ t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index d82b3d2f3..6975af5a7 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -67,6 +67,7 @@ var (
taskStaticFiles = map[string]testutil.DirentType{
"auxv": linux.DT_REG,
"cgroup": linux.DT_REG,
+ "cwd": linux.DT_LNK,
"cmdline": linux.DT_REG,
"comm": linux.DT_REG,
"environ": linux.DT_REG,
@@ -104,7 +105,7 @@ func setup(t *testing.T) *testutil.System {
AllowUserMount: true,
})
- mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{})
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.MountOptions{})
if err != nil {
t.Fatalf("NewMountNamespace(): %v", err)
}
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index 6297e1df4..bf11b425a 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -26,7 +26,9 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// SignalFileDescription implements FileDescriptionImpl for signal fds.
+// SignalFileDescription implements vfs.FileDescriptionImpl for signal fds.
+//
+// +stateify savable
type SignalFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -43,7 +45,7 @@ type SignalFileDescription struct {
target *kernel.Task
// mu protects mask.
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
// mask is the signal mask. Protected by mu.
mask linux.SignalSet
@@ -83,7 +85,7 @@ func (sfd *SignalFileDescription) SetMask(mask linux.SignalSet) {
sfd.mask = mask
}
-// Read implements FileDescriptionImpl.Read.
+// Read implements vfs.FileDescriptionImpl.Read.
func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
// Attempt to dequeue relevant signals.
info, err := sfd.target.Sigtimedwait(sfd.Mask(), 0)
@@ -132,5 +134,5 @@ func (sfd *SignalFileDescription) EventUnregister(entry *waiter.Entry) {
sfd.target.SignalUnregister(entry)
}
-// Release implements FileDescriptionImpl.Release()
+// Release implements vfs.FileDescriptionImpl.Release.
func (sfd *SignalFileDescription) Release(context.Context) {}
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
index 94a998568..29e5371d6 100644
--- a/pkg/sentry/fsimpl/sockfs/sockfs.go
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -28,14 +28,16 @@ import (
)
// filesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
type filesystemType struct{}
-// GetFilesystem implements FilesystemType.GetFilesystem.
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fsType filesystemType) GetFilesystem(_ context.Context, vfsObj *vfs.VirtualFilesystem, _ *auth.Credentials, _ string, _ vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
panic("sockfs.filesystemType.GetFilesystem should never be called")
}
-// Name implements FilesystemType.Name.
+// Name implements vfs.FilesystemType.Name.
//
// Note that registering sockfs is unnecessary, except for the fact that it
// will not show up under /proc/filesystems as a result. This is a very minor
@@ -44,6 +46,7 @@ func (filesystemType) Name() string {
return "sockfs"
}
+// +stateify savable
type filesystem struct {
kernfs.Filesystem
@@ -80,6 +83,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
}
// inode implements kernfs.Inode.
+//
+// +stateify savable
type inode struct {
kernfs.InodeAttrs
kernfs.InodeNoopRefCount
@@ -88,7 +93,7 @@ type inode struct {
}
// Open implements kernfs.Inode.Open.
-func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
return nil, syserror.ENXIO
}
diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go
index 73f3d3309..b75d70ae6 100644
--- a/pkg/sentry/fsimpl/sys/kcov.go
+++ b/pkg/sentry/fsimpl/sys/kcov.go
@@ -36,6 +36,8 @@ func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials)
}
// kcovInode implements kernfs.Inode.
+//
+// +stateify savable
type kcovInode struct {
kernfs.InodeAttrs
kernfs.InodeNoopRefCount
@@ -44,7 +46,7 @@ type kcovInode struct {
implStatFS
}
-func (i *kcovInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+func (i *kcovInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
k := kernel.KernelFromContext(ctx)
if k == nil {
panic("KernelFromContext returned nil")
@@ -54,7 +56,7 @@ func (i *kcovInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.D
kcov: k.NewKcov(),
}
- if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{
DenyPRead: true,
DenyPWrite: true,
}); err != nil {
@@ -63,6 +65,7 @@ func (i *kcovInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.D
return &fd.vfsfd, nil
}
+// +stateify savable
type kcovFD struct {
vfs.FileDescriptionDefaultImpl
vfs.NoLockFD
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 39952d2d0..1568c581f 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -34,9 +34,13 @@ const Name = "sysfs"
const defaultSysDirMode = linux.FileMode(0755)
// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
type FilesystemType struct{}
// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
type filesystem struct {
kernfs.Filesystem
@@ -117,6 +121,8 @@ func (fs *filesystem) Release(ctx context.Context) {
}
// dir implements kernfs.Inode.
+//
+// +stateify savable
type dir struct {
dirRefs
kernfs.InodeAttrs
@@ -148,8 +154,8 @@ func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.Set
}
// Open implements kernfs.Inode.Open.
-func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{
+func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndStaticEntries,
})
if err != nil {
@@ -169,6 +175,8 @@ func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, err
}
// cpuFile implements kernfs.Inode.
+//
+// +stateify savable
type cpuFile struct {
implStatFS
kernfs.DynamicBytesFile
@@ -190,6 +198,7 @@ func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode li
return d
}
+// +stateify savable
type implStatFS struct{}
// StatFS implements kernfs.Inode.StatFS.
diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go
index 9fd38b295..0a0d914cc 100644
--- a/pkg/sentry/fsimpl/sys/sys_test.go
+++ b/pkg/sentry/fsimpl/sys/sys_test.go
@@ -38,7 +38,7 @@ func newTestSystem(t *testing.T) *testutil.System {
AllowUserMount: true,
})
- mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.GetFilesystemOptions{})
+ mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.MountOptions{})
if err != nil {
t.Fatalf("Failed to create new mount namespace: %v", err)
}
diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go
index 86beaa0a8..8853c8ad2 100644
--- a/pkg/sentry/fsimpl/timerfd/timerfd.go
+++ b/pkg/sentry/fsimpl/timerfd/timerfd.go
@@ -26,8 +26,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// TimerFileDescription implements FileDescriptionImpl for timer fds. It also
+// TimerFileDescription implements vfs.FileDescriptionImpl for timer fds. It also
// implements ktime.TimerListener.
+//
+// +stateify savable
type TimerFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -62,7 +64,7 @@ func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, clock ktime.Clock,
return &tfd.vfsfd, nil
}
-// Read implements FileDescriptionImpl.Read.
+// Read implements vfs.FileDescriptionImpl.Read.
func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
const sizeofUint64 = 8
if dst.NumBytes() < sizeofUint64 {
@@ -128,7 +130,7 @@ func (tfd *TimerFileDescription) ResumeTimer() {
tfd.timer.Resume()
}
-// Release implements FileDescriptionImpl.Release()
+// Release implements vfs.FileDescriptionImpl.Release.
func (tfd *TimerFileDescription) Release(context.Context) {
tfd.timer.Destroy()
}
diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
index e5a4218e8..5209a17af 100644
--- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -182,7 +182,7 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) {
vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
@@ -376,7 +376,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) {
vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go
index ac54d420d..9129d35b7 100644
--- a/pkg/sentry/fsimpl/tmpfs/device_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/device_file.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
+// +stateify savable
type deviceFile struct {
inode inode
kind vfs.DeviceKind
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index 070c75e68..e90669cf0 100644
--- a/pkg/sentry/fsimpl/tmpfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// +stateify savable
type directory struct {
// Since directories can't be hard-linked, each directory can only be
// associated with a single dentry, which we can store in the directory
@@ -44,7 +45,7 @@ type directory struct {
// (with inode == nil) that represent the iteration position of
// directoryFDs. childList is used to support directoryFD.IterDirents()
// efficiently. childList is protected by iterMu.
- iterMu sync.Mutex
+ iterMu sync.Mutex `state:"nosave"`
childList dentryList
}
@@ -86,6 +87,7 @@ func (dir *directory) mayDelete(creds *auth.Credentials, child *dentry) error {
return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&dir.inode.mode)), auth.KUID(atomic.LoadUint32(&child.inode.uid)))
}
+// +stateify savable
type directoryFD struct {
fileDescription
vfs.DirectoryFileDescriptionDefaultImpl
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index e0de04e05..e39cd305b 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -673,11 +673,11 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
fs.mu.RUnlock()
return err
}
- if err := d.inode.setStat(ctx, rp.Credentials(), &opts); err != nil {
- fs.mu.RUnlock()
+ err = d.inode.setStat(ctx, rp.Credentials(), &opts)
+ fs.mu.RUnlock()
+ if err != nil {
return err
}
- fs.mu.RUnlock()
if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
@@ -770,7 +770,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return nil
}
-// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
@@ -792,59 +792,59 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
}
}
-// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt.
+func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
d, err := resolveLocked(ctx, rp)
if err != nil {
return nil, err
}
- return d.inode.listxattr(size)
+ return d.inode.listXattr(size)
}
-// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
+func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
d, err := resolveLocked(ctx, rp)
if err != nil {
return "", err
}
- return d.inode.getxattr(rp.Credentials(), &opts)
+ return d.inode.getXattr(rp.Credentials(), &opts)
}
-// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
-func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt.
+func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error {
fs.mu.RLock()
d, err := resolveLocked(ctx, rp)
if err != nil {
fs.mu.RUnlock()
return err
}
- if err := d.inode.setxattr(rp.Credentials(), &opts); err != nil {
- fs.mu.RUnlock()
+ err = d.inode.setXattr(rp.Credentials(), &opts)
+ fs.mu.RUnlock()
+ if err != nil {
return err
}
- fs.mu.RUnlock()
d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
}
-// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
-func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt.
+func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
fs.mu.RLock()
d, err := resolveLocked(ctx, rp)
if err != nil {
fs.mu.RUnlock()
return err
}
- if err := d.inode.removexattr(rp.Credentials(), name); err != nil {
- fs.mu.RUnlock()
+ err = d.inode.removeXattr(rp.Credentials(), name)
+ fs.mu.RUnlock()
+ if err != nil {
return err
}
- fs.mu.RUnlock()
d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
return nil
@@ -865,8 +865,16 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
}
if d.parent == nil {
if d.name != "" {
- // This must be an anonymous memfd file.
+ // This file must have been created by
+ // newUnlinkedRegularFileDescription(). In Linux,
+ // mm/shmem.c:__shmem_file_setup() =>
+ // fs/file_table.c:alloc_file_pseudo() sets the created
+ // dentry's dentry_operations to anon_ops, for which d_dname ==
+ // simple_dname. fs/d_path.c:simple_dname() defines the
+ // dentry's pathname to be its name, prefixed with "/" and
+ // suffixed with " (deleted)".
b.PrependComponent("/" + d.name)
+ b.AppendString(" (deleted)")
return vfs.PrependPathSyntheticError{}
}
return vfs.PrependPathAtNonMountRootError{}
diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index 5b0471ff4..d772db9e9 100644
--- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// +stateify savable
type namedPipe struct {
inode inode
diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index ec2701d8b..be29a2363 100644
--- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -158,7 +158,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{})
if err != nil {
t.Fatalf("failed to create tmpfs root mount: %v", err)
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 0710b65db..a199eb33d 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -36,12 +36,18 @@ import (
)
// regularFile is a regular (=S_IFREG) tmpfs file.
+//
+// +stateify savable
type regularFile struct {
inode inode
// memFile is a platform.File used to allocate pages to this regularFile.
memFile *pgalloc.MemoryFile
+ // memoryUsageKind is the memory accounting category under which pages backing
+ // this regularFile's contents are accounted.
+ memoryUsageKind usage.MemoryKind
+
// mapsMu protects mappings.
mapsMu sync.Mutex `state:"nosave"`
@@ -62,7 +68,7 @@ type regularFile struct {
writableMappingPages uint64
// dataMu protects the fields below.
- dataMu sync.RWMutex
+ dataMu sync.RWMutex `state:"nosave"`
// data maps offsets into the file to offsets into memFile that store
// the file's data.
@@ -86,14 +92,75 @@ type regularFile struct {
func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
file := &regularFile{
- memFile: fs.memFile,
- seals: linux.F_SEAL_SEAL,
+ memFile: fs.memFile,
+ memoryUsageKind: usage.Tmpfs,
+ seals: linux.F_SEAL_SEAL,
}
file.inode.init(file, fs, kuid, kgid, linux.S_IFREG|mode)
file.inode.nlink = 1 // from parent directory
return &file.inode
}
+// newUnlinkedRegularFileDescription creates a regular file on the tmpfs
+// filesystem represented by mount and returns an FD representing that file.
+// The new file is not reachable by path traversal from any other file.
+//
+// newUnlinkedRegularFileDescription is analogous to Linux's
+// mm/shmem.c:__shmem_file_setup().
+//
+// Preconditions: mount must be a tmpfs mount.
+func newUnlinkedRegularFileDescription(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, name string) (*regularFileFD, error) {
+ fs, ok := mount.Filesystem().Impl().(*filesystem)
+ if !ok {
+ panic("tmpfs.newUnlinkedRegularFileDescription() called with non-tmpfs mount")
+ }
+
+ inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777)
+ d := fs.newDentry(inode)
+ defer d.DecRef(ctx)
+ d.name = name
+
+ fd := &regularFileFD{}
+ fd.Init(&inode.locks)
+ flags := uint32(linux.O_RDWR)
+ if err := fd.vfsfd.Init(fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return fd, nil
+}
+
+// NewZeroFile creates a new regular file and file description as for
+// mmap(MAP_SHARED | MAP_ANONYMOUS). The file has the given size and is
+// initially (implicitly) filled with zeroes.
+//
+// Preconditions: mount must be a tmpfs mount.
+func NewZeroFile(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, size uint64) (*vfs.FileDescription, error) {
+ // Compare mm/shmem.c:shmem_zero_setup().
+ fd, err := newUnlinkedRegularFileDescription(ctx, creds, mount, "dev/zero")
+ if err != nil {
+ return nil, err
+ }
+ rf := fd.inode().impl.(*regularFile)
+ rf.memoryUsageKind = usage.Anonymous
+ rf.size = size
+ return &fd.vfsfd, err
+}
+
+// NewMemfd creates a new regular file and file description as for
+// memfd_create.
+//
+// Preconditions: mount must be a tmpfs mount.
+func NewMemfd(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, allowSeals bool, name string) (*vfs.FileDescription, error) {
+ fd, err := newUnlinkedRegularFileDescription(ctx, creds, mount, name)
+ if err != nil {
+ return nil, err
+ }
+ if allowSeals {
+ fd.inode().impl.(*regularFile).seals = 0
+ }
+ return &fd.vfsfd, nil
+}
+
// truncate grows or shrinks the file to the given size. It returns true if the
// file size was updated.
func (rf *regularFile) truncate(newSize uint64) (bool, error) {
@@ -226,7 +293,7 @@ func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.
optional.End = pgend
}
- cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ cerr := rf.data.Fill(ctx, required, optional, rf.memFile, rf.memoryUsageKind, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
// Newly-allocated pages are zeroed, so we don't need to do anything.
return dsts.NumBytes(), nil
})
@@ -260,13 +327,14 @@ func (*regularFile) InvalidateUnsavable(context.Context) error {
return nil
}
+// +stateify savable
type regularFileFD struct {
fileDescription
// off is the file offset. off is accessed using atomic memory operations.
// offMu serializes operations that may mutate off.
off int64
- offMu sync.Mutex
+ offMu sync.Mutex `state:"nosave"`
}
// Release implements vfs.FileDescriptionImpl.Release.
@@ -575,7 +643,7 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
case gap.Ok():
// Allocate memory for the write.
gapMR := gap.Range().Intersect(pgMR)
- fr, err := rw.file.memFile.Allocate(gapMR.Length(), usage.Tmpfs)
+ fr, err := rw.file.memFile.Allocate(gapMR.Length(), rw.file.memoryUsageKind)
if err != nil {
retErr = err
goto exitLoop
diff --git a/pkg/sentry/fsimpl/tmpfs/socket_file.go b/pkg/sentry/fsimpl/tmpfs/socket_file.go
index 3ed650474..5699d5975 100644
--- a/pkg/sentry/fsimpl/tmpfs/socket_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/socket_file.go
@@ -21,6 +21,8 @@ import (
)
// socketFile is a socket (=S_IFSOCK) tmpfs file.
+//
+// +stateify savable
type socketFile struct {
inode inode
ep transport.BoundEndpoint
diff --git a/pkg/sentry/fsimpl/tmpfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go
index b0de5fabe..a102a2ee2 100644
--- a/pkg/sentry/fsimpl/tmpfs/symlink.go
+++ b/pkg/sentry/fsimpl/tmpfs/symlink.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
+// +stateify savable
type symlink struct {
inode inode
target string // immutable
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index c4cec4130..cefec8fde 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -51,9 +51,13 @@ import (
const Name = "tmpfs"
// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
type FilesystemType struct{}
// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
type filesystem struct {
vfsfs vfs.Filesystem
@@ -67,7 +71,7 @@ type filesystem struct {
devMinor uint32
// mu serializes changes to the Dentry tree.
- mu sync.RWMutex
+ mu sync.RWMutex `state:"nosave"`
nextInoMinusOne uint64 // accessed using atomic memory operations
}
@@ -78,6 +82,8 @@ func (FilesystemType) Name() string {
}
// FilesystemOpts is used to pass configuration data to tmpfs.
+//
+// +stateify savable
type FilesystemOpts struct {
// RootFileType is the FileType of the filesystem root. Valid values
// are: S_IFDIR, S_IFREG, and S_IFLNK. Defaults to S_IFDIR.
@@ -221,6 +227,8 @@ var globalStatfs = linux.Statfs{
}
// dentry implements vfs.DentryImpl.
+//
+// +stateify savable
type dentry struct {
vfsd vfs.Dentry
@@ -300,6 +308,8 @@ func (d *dentry) Watches() *vfs.Watches {
func (d *dentry) OnZeroWatches(context.Context) {}
// inode represents a filesystem object.
+//
+// +stateify savable
type inode struct {
// fs is the owning filesystem. fs is immutable.
fs *filesystem
@@ -316,12 +326,12 @@ type inode struct {
// Inode metadata. Writing multiple fields atomically requires holding
// mu, othewise atomic operations can be used.
- mu sync.Mutex
- mode uint32 // file type and mode
- nlink uint32 // protected by filesystem.mu instead of inode.mu
- uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
- gid uint32 // auth.KGID, but ...
- ino uint64 // immutable
+ mu sync.Mutex `state:"nosave"`
+ mode uint32 // file type and mode
+ nlink uint32 // protected by filesystem.mu instead of inode.mu
+ uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
+ gid uint32 // auth.KGID, but ...
+ ino uint64 // immutable
// Linux's tmpfs has no concept of btime.
atime int64 // nanoseconds
@@ -626,74 +636,50 @@ func (i *inode) touchCMtimeLocked() {
atomic.StoreInt64(&i.ctime, now)
}
-func (i *inode) listxattr(size uint64) ([]string, error) {
- return i.xattrs.Listxattr(size)
+func (i *inode) listXattr(size uint64) ([]string, error) {
+ return i.xattrs.ListXattr(size)
}
-func (i *inode) getxattr(creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) {
+func (i *inode) getXattr(creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) {
if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil {
return "", err
}
- return i.xattrs.Getxattr(opts)
+ return i.xattrs.GetXattr(opts)
}
-func (i *inode) setxattr(creds *auth.Credentials, opts *vfs.SetxattrOptions) error {
+func (i *inode) setXattr(creds *auth.Credentials, opts *vfs.SetXattrOptions) error {
if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil {
return err
}
- return i.xattrs.Setxattr(opts)
+ return i.xattrs.SetXattr(opts)
}
-func (i *inode) removexattr(creds *auth.Credentials, name string) error {
+func (i *inode) removeXattr(creds *auth.Credentials, name string) error {
if err := i.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil {
return err
}
- return i.xattrs.Removexattr(name)
+ return i.xattrs.RemoveXattr(name)
}
func (i *inode) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error {
- switch {
- case ats&vfs.MayRead == vfs.MayRead:
- if err := i.checkPermissions(creds, vfs.MayRead); err != nil {
- return err
- }
- case ats&vfs.MayWrite == vfs.MayWrite:
- if err := i.checkPermissions(creds, vfs.MayWrite); err != nil {
- return err
- }
- default:
- panic(fmt.Sprintf("checkXattrPermissions called with impossible AccessTypes: %v", ats))
+ // We currently only support extended attributes in the user.* and
+ // trusted.* namespaces. See b/148380782.
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) && !strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) {
+ return syserror.EOPNOTSUPP
}
-
- switch {
- case strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX):
- // The trusted.* namespace can only be accessed by privileged
- // users.
- if creds.HasCapability(linux.CAP_SYS_ADMIN) {
- return nil
- }
- if ats&vfs.MayWrite == vfs.MayWrite {
- return syserror.EPERM
- }
- return syserror.ENODATA
- case strings.HasPrefix(name, linux.XATTR_USER_PREFIX):
- // Extended attributes in the user.* namespace are only
- // supported for regular files and directories.
- filetype := linux.S_IFMT & atomic.LoadUint32(&i.mode)
- if filetype == linux.S_IFREG || filetype == linux.S_IFDIR {
- return nil
- }
- if ats&vfs.MayWrite == vfs.MayWrite {
- return syserror.EPERM
- }
- return syserror.ENODATA
-
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ kuid := auth.KUID(atomic.LoadUint32(&i.uid))
+ kgid := auth.KGID(atomic.LoadUint32(&i.gid))
+ if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil {
+ return err
}
- return syserror.EOPNOTSUPP
+ return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name)
}
// fileDescription is embedded by tmpfs implementations of
// vfs.FileDescriptionImpl.
+//
+// +stateify savable
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -738,20 +724,20 @@ func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
return globalStatfs, nil
}
-// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
-func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
- return fd.inode().listxattr(size)
+// ListXattr implements vfs.FileDescriptionImpl.ListXattr.
+func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) {
+ return fd.inode().listXattr(size)
}
-// Getxattr implements vfs.FileDescriptionImpl.Getxattr.
-func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) {
- return fd.inode().getxattr(auth.CredentialsFromContext(ctx), &opts)
+// GetXattr implements vfs.FileDescriptionImpl.GetXattr.
+func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) {
+ return fd.inode().getXattr(auth.CredentialsFromContext(ctx), &opts)
}
-// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
-func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
+// SetXattr implements vfs.FileDescriptionImpl.SetXattr.
+func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error {
d := fd.dentry()
- if err := d.inode.setxattr(auth.CredentialsFromContext(ctx), &opts); err != nil {
+ if err := d.inode.setXattr(auth.CredentialsFromContext(ctx), &opts); err != nil {
return err
}
@@ -760,10 +746,10 @@ func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOption
return nil
}
-// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
-func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
+// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr.
+func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error {
d := fd.dentry()
- if err := d.inode.removexattr(auth.CredentialsFromContext(ctx), name); err != nil {
+ if err := d.inode.removeXattr(auth.CredentialsFromContext(ctx), name); err != nil {
return err
}
@@ -772,37 +758,6 @@ func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
return nil
}
-// NewMemfd creates a new tmpfs regular file and file description that can back
-// an anonymous fd created by memfd_create.
-func NewMemfd(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, allowSeals bool, name string) (*vfs.FileDescription, error) {
- fs, ok := mount.Filesystem().Impl().(*filesystem)
- if !ok {
- panic("NewMemfd() called with non-tmpfs mount")
- }
-
- // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with
- // S_IRWXUGO.
- inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777)
- rf := inode.impl.(*regularFile)
- if allowSeals {
- rf.seals = 0
- }
-
- d := fs.newDentry(inode)
- defer d.DecRef(ctx)
- d.name = name
-
- // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with
- // FMODE_READ | FMODE_WRITE.
- var fd regularFileFD
- fd.Init(&inode.locks)
- flags := uint32(linux.O_RDWR)
- if err := fd.vfsfd.Init(&fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
- return nil, err
- }
- return &fd.vfsfd, nil
-}
-
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
index 6f3e3ae6f..99c8e3c0f 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
@@ -41,7 +41,7 @@ func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentr
vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{})
if err != nil {
return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
}
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index 326c4ed90..0ca750281 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
licenses(["notice"])
@@ -13,12 +13,35 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/marshal/primitive",
"//pkg/merkletree",
+ "//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "verity_test",
+ srcs = [
+ "verity_test.go",
+ ],
+ library = ":verity",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/contexttest",
+ "//pkg/sentry/vfs",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 0e17dbddc..a560b0797 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -19,6 +19,7 @@ import (
"fmt"
"io"
"strconv"
+ "strings"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -179,23 +180,19 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// corresponding Merkle tree file.
// This is the offset of the root hash for child in its parent's Merkle
// tree file.
- off, err := vfsObj.GetxattrAt(ctx, fs.creds, &vfs.PathOperation{
+ off, err := vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{
Root: child.lowerMerkleVD,
Start: child.lowerMerkleVD,
- }, &vfs.GetxattrOptions{
+ }, &vfs.GetXattrOptions{
Name: merkleOffsetInParentXattr,
- // Offset is a 32 bit integer.
- Size: sizeOfInt32,
+ Size: sizeOfStringInt32,
})
// The Merkle tree file for the child should have been created and
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- if noCrashOnVerificationFailure {
- return nil, err
- }
- panic(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -204,10 +201,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// unexpected modifications to the file system.
offset, err := strconv.Atoi(off)
if err != nil {
- if noCrashOnVerificationFailure {
- return nil, syserror.EINVAL
- }
- panic(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
}
// Open parent Merkle tree file to read and verify child's root hash.
@@ -221,10 +215,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// The parent Merkle tree file should have been created. If it's
// missing, it indicates an unexpected modification to the file system.
if err == syserror.ENOENT {
- if noCrashOnVerificationFailure {
- return nil, err
- }
- panic(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
}
if err != nil {
return nil, err
@@ -233,19 +224,16 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// dataSize is the size of raw data for the Merkle tree. For a file,
// dataSize is the size of the whole file. For a directory, dataSize is
// the size of all its children's root hashes.
- dataSize, err := parentMerkleFD.Getxattr(ctx, &vfs.GetxattrOptions{
+ dataSize, err := parentMerkleFD.GetXattr(ctx, &vfs.GetXattrOptions{
Name: merkleSizeXattr,
- Size: sizeOfInt32,
+ Size: sizeOfStringInt32,
})
// The Merkle tree file for the child should have been created and
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- if noCrashOnVerificationFailure {
- return nil, err
- }
- panic(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -255,10 +243,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// unexpected modifications to the file system.
parentSize, err := strconv.Atoi(dataSize)
if err != nil {
- if noCrashOnVerificationFailure {
- return nil, syserror.EINVAL
- }
- panic(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
fdReader := vfs.FileReadWriteSeeker{
@@ -270,11 +255,8 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contain the root hash of the children in the parent Merkle tree when
// Verify returns with success.
var buf bytes.Buffer
- if err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF {
- if noCrashOnVerificationFailure {
- return nil, syserror.EIO
- }
- panic(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
+ if _, err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF {
+ return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
// Cache child root hash when it's verified the first time.
@@ -370,10 +352,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// corresponding Merkle tree is found. This indicates an
// unexpected modification to the file system that
// removed/renamed the child.
- if noCrashOnVerificationFailure {
- return nil, childErr
- }
- panic(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name))
} else if childErr == nil && childMerkleErr == syserror.ENOENT {
// If in allowRuntimeEnable mode, and the Merkle tree file is
// not created yet, we create an empty Merkle tree file, so that
@@ -392,6 +371,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
Path: fspath.Parse(childMerkleFilename),
}, &vfs.OpenOptions{
Flags: linux.O_RDWR | linux.O_CREAT,
+ Mode: 0644,
})
if err != nil {
return nil, err
@@ -409,11 +389,16 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// If runtime enable is not allowed. This indicates an
// unexpected modification to the file system that
// removed/renamed the Merkle tree file.
- if noCrashOnVerificationFailure {
- return nil, childMerkleErr
- }
- panic(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name))
}
+ } else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT {
+ // Both the child and the corresponding Merkle tree are missing.
+ // This could be an unexpected modification or due to incorrect
+ // parameter.
+ // TODO(b/167752508): Investigate possible ways to differentiate
+ // cases that both files are deleted from cases that they never
+ // exist in the file system.
+ return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name))
}
mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID)
@@ -572,8 +557,183 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// OpenAt implements vfs.FilesystemImpl.OpenAt.
func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- //TODO(b/159261227): Implement OpenAt.
- return nil, nil
+ // Verity fs is read-only.
+ if opts.Flags&(linux.O_WRONLY|linux.O_CREAT) != 0 {
+ return nil, syserror.EROFS
+ }
+
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+
+ start := rp.Start().Impl().(*dentry)
+ if rp.Done() {
+ return start.openLocked(ctx, rp, &opts)
+ }
+
+afterTrailingSymlink:
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check for search permission in the parent directory.
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+
+ // Open existing child or follow symlink.
+ parent.dirMu.Lock()
+ child, err := fs.stepLocked(ctx, rp, parent, false /*mayFollowSymlinks*/, &ds)
+ parent.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ if child.isSymlink() && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ start = parent
+ goto afterTrailingSymlink
+ }
+ return child.openLocked(ctx, rp, &opts)
+}
+
+// Preconditions: fs.renameMu must be locked.
+func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // Users should not open the Merkle tree files. Those are for verity fs
+ // use only.
+ if strings.Contains(d.name, merklePrefix) {
+ return nil, syserror.EPERM
+ }
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+
+ // Verity fs is read-only.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EROFS
+ }
+
+ // Get the path to the target file. This is only used to provide path
+ // information in failure case.
+ path, err := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.lowerVD)
+ if err != nil {
+ return nil, err
+ }
+
+ // Open the file in the underlying file system.
+ lowerFD, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ }, opts)
+
+ // The file should exist, as we succeeded in finding its dentry. If it's
+ // missing, it indicates an unexpected modification to the file system.
+ if err != nil {
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path))
+ }
+ return nil, err
+ }
+
+ // lowerFD needs to be cleaned up if any error occurs. IncRef will be
+ // called if a verity FD is successfully created.
+ defer lowerFD.DecRef(ctx)
+
+ // Open the Merkle tree file corresponding to the current file/directory
+ // to be used later for verifying Read/Walk.
+ merkleReader, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerMerkleVD,
+ Start: d.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+
+ // The Merkle tree file should exist, as we succeeded in finding its
+ // dentry. If it's missing, it indicates an unexpected modification to
+ // the file system.
+ if err != nil {
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path))
+ }
+ return nil, err
+ }
+
+ // merkleReader needs to be cleaned up if any error occurs. IncRef will
+ // be called if a verity FD is successfully created.
+ defer merkleReader.DecRef(ctx)
+
+ lowerFlags := lowerFD.StatusFlags()
+ lowerFDOpts := lowerFD.Options()
+ var merkleWriter *vfs.FileDescription
+ var parentMerkleWriter *vfs.FileDescription
+
+ // Only open the Merkle tree files for write if in allowRuntimeEnable
+ // mode.
+ if d.fs.allowRuntimeEnable {
+ merkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerMerkleVD,
+ Start: d.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY | linux.O_APPEND,
+ })
+ if err != nil {
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path))
+ }
+ return nil, err
+ }
+ // merkleWriter is cleaned up if any error occurs. IncRef will
+ // be called if a verity FD is created successfully.
+ defer merkleWriter.DecRef(ctx)
+
+ if d.parent != nil {
+ parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.parent.lowerMerkleVD,
+ Start: d.parent.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY | linux.O_APPEND,
+ })
+ if err != nil {
+ if err == syserror.ENOENT {
+ parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ }
+ return nil, err
+ }
+ // parentMerkleWriter is cleaned up if any error occurs. IncRef
+ // will be called if a verity FD is created successfully.
+ defer parentMerkleWriter.DecRef(ctx)
+ }
+ }
+
+ fd := &fileDescription{
+ d: d,
+ lowerFD: lowerFD,
+ merkleReader: merkleReader,
+ merkleWriter: merkleWriter,
+ parentMerkleWriter: parentMerkleWriter,
+ isDir: d.isDir(),
+ }
+
+ if err := fd.vfsfd.Init(fd, lowerFlags, rp.Mount(), &d.vfsd, &lowerFDOpts); err != nil {
+ return nil, err
+ }
+ lowerFD.IncRef()
+ merkleReader.IncRef()
+ if merkleWriter != nil {
+ merkleWriter.IncRef()
+ }
+ if parentMerkleWriter != nil {
+ parentMerkleWriter.IncRef()
+ }
+ return &fd.vfsfd, err
}
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
@@ -649,7 +809,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return syserror.EROFS
}
-// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
var ds *[]*dentry
fs.renameMu.RLock()
@@ -660,8 +820,8 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
return nil, syserror.ECONNREFUSED
}
-// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt.
+func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
@@ -670,14 +830,14 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
return nil, err
}
lowerVD := d.lowerVD
- return fs.vfsfs.VirtualFilesystem().ListxattrAt(ctx, d.fs.creds, &vfs.PathOperation{
+ return fs.vfsfs.VirtualFilesystem().ListXattrAt(ctx, d.fs.creds, &vfs.PathOperation{
Root: lowerVD,
Start: lowerVD,
}, size)
}
-// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
+func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
@@ -686,20 +846,20 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
return "", err
}
lowerVD := d.lowerVD
- return fs.vfsfs.VirtualFilesystem().GetxattrAt(ctx, d.fs.creds, &vfs.PathOperation{
+ return fs.vfsfs.VirtualFilesystem().GetXattrAt(ctx, d.fs.creds, &vfs.PathOperation{
Root: lowerVD,
Start: lowerVD,
}, &opts)
}
-// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
-func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt.
+func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error {
// Verity file system is read-only.
return syserror.EROFS
}
-// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
-func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt.
+func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
// Verity file system is read-only.
return syserror.EROFS
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index eedb5f484..fc5eabbca 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -22,16 +22,23 @@
package verity
import (
+ "fmt"
+ "strconv"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/merkletree"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Name is the default filesystem name.
@@ -50,8 +57,9 @@ const merkleOffsetInParentXattr = "user.merkle.offset"
// whole file. For a directory, it's the size of all its children's root hashes.
const merkleSizeXattr = "user.merkle.size"
-// sizeOfInt32 is the size in bytes for a 32 bit integer in extended attributes.
-const sizeOfInt32 = 4
+// sizeOfStringInt32 is the size for a 32 bit integer stored as string in
+// extended attributes. The maximum value of a 32 bit integer is 10 digits.
+const sizeOfStringInt32 = 10
// noCrashOnVerificationFailure indicates whether the sandbox should panic
// whenever verification fails. If true, an error is returned instead of
@@ -66,9 +74,13 @@ var noCrashOnVerificationFailure bool
var verityMu sync.RWMutex
// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
type FilesystemType struct{}
// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
type filesystem struct {
vfsfs vfs.Filesystem
@@ -93,11 +105,13 @@ type filesystem struct {
// renameMu synchronizes renaming with non-renaming operations in order
// to ensure consistent lock ordering between dentry.dirMu in different
// dentries.
- renameMu sync.RWMutex
+ renameMu sync.RWMutex `state:"nosave"`
}
// InternalFilesystemOptions may be passed as
// vfs.GetFilesystemOptions.InternalData to FilesystemType.GetFilesystem.
+//
+// +stateify savable
type InternalFilesystemOptions struct {
// RootMerkleFileName is the name of the verity root Merkle tree file.
RootMerkleFileName string
@@ -128,6 +142,16 @@ func (FilesystemType) Name() string {
return Name
}
+// alertIntegrityViolation alerts a violation of integrity, which usually means
+// unexpected modification to the file system is detected. In
+// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic.
+func alertIntegrityViolation(err error, msg string) error {
+ if noCrashOnVerificationFailure {
+ return err
+ }
+ panic(msg)
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
iopts, ok := opts.InternalData.(InternalFilesystemOptions)
@@ -141,6 +165,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// verity, and should not be exposed or connected.
mopts := &vfs.MountOptions{
GetFilesystemOptions: iopts.LowerGetFSOptions,
+ InternalMount: true,
}
mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mopts)
if err != nil {
@@ -197,15 +222,12 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
} else if err != nil {
- // Failed to get dentry for the root Merkle file. This indicates
- // an attack that removed/renamed the root Merkle file, or it's
- // never generated.
- if noCrashOnVerificationFailure {
- fs.vfsfs.DecRef(ctx)
- d.DecRef(ctx)
- return nil, nil, err
- }
- panic("Failed to find root Merkle file")
+ // Failed to get dentry for the root Merkle file. This
+ // indicates an unexpected modification that removed/renamed
+ // the root Merkle file, or it's never generated.
+ fs.vfsfs.DecRef(ctx)
+ d.DecRef(ctx)
+ return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file")
}
d.lowerMerkleVD = lowerMerkleVD
@@ -243,6 +265,8 @@ func (fs *filesystem) Release(ctx context.Context) {
}
// dentry implements vfs.DentryImpl.
+//
+// +stateify savable
type dentry struct {
vfsd vfs.Dentry
@@ -269,7 +293,7 @@ type dentry struct {
// and dirents (if not nil) is a cache of dirents as returned by
// directoryFDs representing this directory. children is protected by
// dirMu.
- dirMu sync.Mutex
+ dirMu sync.Mutex `state:"nosave"`
children map[string]*dentry
// lowerVD is the VirtualDentry in the underlying file system.
@@ -413,6 +437,8 @@ func (d *dentry) readlink(ctx context.Context) (string, error) {
// FileDescription is a wrapper of the underlying lowerFD, with support to build
// Merkle trees through the Linux fs-verity API to verify contents read from
// lowerFD.
+//
+// +stateify savable
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -471,12 +497,247 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return syserror.EPERM
}
+// generateMerkle generates a Merkle tree file for fd. If fd points to a file
+// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The root
+// hash of the generated Merkle tree and the data size is returned.
+// If fd points to a regular file, the data is the content of the file. If fd
+// points to a directory, the data is all root hahes of its children, written
+// to the Merkle tree file.
+func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) {
+ fdReader := vfs.FileReadWriteSeeker{
+ FD: fd.lowerFD,
+ Ctx: ctx,
+ }
+ merkleReader := vfs.FileReadWriteSeeker{
+ FD: fd.merkleReader,
+ Ctx: ctx,
+ }
+ merkleWriter := vfs.FileReadWriteSeeker{
+ FD: fd.merkleWriter,
+ Ctx: ctx,
+ }
+ var rootHash []byte
+ var dataSize uint64
+
+ switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT {
+ case linux.S_IFREG:
+ // For a regular file, generate a Merkle tree based on its
+ // content.
+ var err error
+ stat, err := fd.lowerFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ return nil, 0, err
+ }
+ dataSize = stat.Size
+
+ rootHash, err = merkletree.Generate(&fdReader, int64(dataSize), &merkleReader, &merkleWriter, false /* dataAndTreeInSameFile */)
+ if err != nil {
+ return nil, 0, err
+ }
+ case linux.S_IFDIR:
+ // For a directory, generate a Merkle tree based on the root
+ // hashes of its children that has already been written to the
+ // Merkle tree file.
+ merkleStat, err := fd.merkleReader.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ return nil, 0, err
+ }
+ dataSize = merkleStat.Size
+
+ rootHash, err = merkletree.Generate(&merkleReader, int64(dataSize), &merkleReader, &merkleWriter, true /* dataAndTreeInSameFile */)
+ if err != nil {
+ return nil, 0, err
+ }
+ default:
+ // TODO(b/167728857): Investigate whether and how we should
+ // enable other types of file.
+ return nil, 0, syserror.EINVAL
+ }
+ return rootHash, dataSize, nil
+}
+
+// enableVerity enables verity features on fd by generating a Merkle tree file
+// and stores its root hash in its parent directory's Merkle tree.
+func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (uintptr, error) {
+ if !fd.d.fs.allowRuntimeEnable {
+ return 0, syserror.EPERM
+ }
+
+ // Lock to prevent other threads performing enable or access the file
+ // while it's being enabled.
+ verityMu.Lock()
+ defer verityMu.Unlock()
+
+ // In allowRuntimeEnable mode, the underlying fd and read/write fd for
+ // the Merkle tree file should have all been initialized. For any file
+ // or directory other than the root, the parent Merkle tree file should
+ // have also been initialized.
+ if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) {
+ return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds")
+ }
+
+ rootHash, dataSize, err := fd.generateMerkle(ctx)
+ if err != nil {
+ return 0, err
+ }
+
+ if fd.parentMerkleWriter != nil {
+ stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ return 0, err
+ }
+
+ // Write the root hash of fd to the parent directory's Merkle
+ // tree file, as it should be part of the parent Merkle tree
+ // data. parentMerkleWriter is open with O_APPEND, so it
+ // should write directly to the end of the file.
+ if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil {
+ return 0, err
+ }
+
+ // Record the offset of the root hash of fd in parent directory's
+ // Merkle tree file.
+ if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{
+ Name: merkleOffsetInParentXattr,
+ Value: strconv.Itoa(int(stat.Size)),
+ }); err != nil {
+ return 0, err
+ }
+ }
+
+ // Record the size of the data being hashed for fd.
+ if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{
+ Name: merkleSizeXattr,
+ Value: strconv.Itoa(int(dataSize)),
+ }); err != nil {
+ return 0, err
+ }
+ fd.d.rootHash = append(fd.d.rootHash, rootHash...)
+ return 0, nil
+}
+
+// measureVerity returns the root hash of fd, saved in args[2].
+func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ var metadata linux.DigestMetadata
+
+ // If allowRuntimeEnable is true, an empty fd.d.rootHash indicates that
+ // verity is not enabled for the file. If allowRuntimeEnable is false,
+ // this is an integrity violation because all files should have verity
+ // enabled, in which case fd.d.rootHash should be set.
+ if len(fd.d.rootHash) == 0 {
+ if fd.d.fs.allowRuntimeEnable {
+ return 0, syserror.ENODATA
+ }
+ return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no root hash found")
+ }
+
+ // The first part of VerityDigest is the metadata.
+ if _, err := metadata.CopyIn(t, verityDigest); err != nil {
+ return 0, err
+ }
+ if metadata.DigestSize < uint16(len(fd.d.rootHash)) {
+ return 0, syserror.EOVERFLOW
+ }
+
+ // Populate the output digest size, since DigestSize is both input and
+ // output.
+ metadata.DigestSize = uint16(len(fd.d.rootHash))
+
+ // First copy the metadata.
+ if _, err := metadata.CopyOut(t, verityDigest); err != nil {
+ return 0, err
+ }
+
+ // Now copy the root hash bytes to the memory after metadata.
+ _, err := t.CopyOutBytes(usermem.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.rootHash)
+ return 0, err
+}
+
+func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flags usermem.Addr) (uintptr, error) {
+ f := int32(0)
+
+ // All enabled files should store a root hash. This flag is not settable
+ // via FS_IOC_SETFLAGS.
+ if len(fd.d.rootHash) != 0 {
+ f |= linux.FS_VERITY_FL
+ }
+
+ t := kernel.TaskFromContext(ctx)
+ _, err := primitive.CopyInt32Out(t, flags, f)
+ return 0, err
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FS_IOC_ENABLE_VERITY:
+ return fd.enableVerity(ctx, uio)
+ case linux.FS_IOC_MEASURE_VERITY:
+ return fd.measureVerity(ctx, uio, args[2].Pointer())
+ case linux.FS_IOC_GETFLAGS:
+ return fd.verityFlags(ctx, uio, args[2].Pointer())
+ default:
+ // TODO(b/169682228): Investigate which ioctl commands should
+ // be allowed.
+ return 0, syserror.ENOSYS
+ }
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ // No need to verify if the file is not enabled yet in
+ // allowRuntimeEnable mode.
+ if fd.d.fs.allowRuntimeEnable && len(fd.d.rootHash) == 0 {
+ return fd.lowerFD.PRead(ctx, dst, offset, opts)
+ }
+
+ // dataSize is the size of the whole file.
+ dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{
+ Name: merkleSizeXattr,
+ Size: sizeOfStringInt32,
+ })
+
+ // The Merkle tree file for the child should have been created and
+ // contains the expected xattrs. If the xattr does not exist, it
+ // indicates unexpected modifications to the file system.
+ if err == syserror.ENODATA {
+ return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ // The dataSize xattr should be an integer. If it's not, it indicates
+ // unexpected modifications to the file system.
+ size, err := strconv.Atoi(dataSize)
+ if err != nil {
+ return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
+ }
+
+ dataReader := vfs.FileReadWriteSeeker{
+ FD: fd.lowerFD,
+ Ctx: ctx,
+ }
+
+ merkleReader := vfs.FileReadWriteSeeker{
+ FD: fd.merkleReader,
+ Ctx: ctx,
+ }
+
+ n, err := merkletree.Verify(dst.Writer(ctx), &dataReader, &merkleReader, int64(size), offset, dst.NumBytes(), fd.d.rootHash, false /* dataAndTreeInSameFile */)
+ if err != nil {
+ return 0, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Verification failed: %v", err))
+ }
+ return n, err
+}
+
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
- return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+ return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block)
}
// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
- return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+ return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence)
}
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
new file mode 100644
index 000000000..9b5641092
--- /dev/null
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -0,0 +1,349 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package verity
+
+import (
+ "fmt"
+ "io"
+ "math/rand"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// rootMerkleFilename is the name of the root Merkle tree file.
+const rootMerkleFilename = "root.verity"
+
+// maxDataSize is the maximum data size written to the file for test.
+const maxDataSize = 100000
+
+// newVerityRoot creates a new verity mount, and returns the root. The
+// underlying file system is tmpfs. If the error is not nil, then cleanup
+// should be called when the root is no longer needed.
+func newVerityRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) {
+ rand.Seed(time.Now().UnixNano())
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err)
+ }
+
+ vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: InternalFilesystemOptions{
+ RootMerkleFileName: rootMerkleFilename,
+ LowerName: "tmpfs",
+ AllowRuntimeEnable: true,
+ NoCrashOnVerificationFailure: true,
+ },
+ },
+ })
+ if err != nil {
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create verity root mount: %v", err)
+ }
+ root := mntns.Root()
+ return vfsObj, root, func() {
+ root.DecRef(ctx)
+ mntns.DecRef(ctx)
+ }, nil
+}
+
+// newFileFD creates a new file in the verity mount, and returns the FD. The FD
+// points to a file that has random data generated.
+func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) {
+ creds := auth.CredentialsFromContext(ctx)
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+
+ // Create the file in the underlying file system.
+ lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(filePath),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: linux.ModeRegular | mode,
+ })
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Generate random data to be written to the file.
+ dataSize := rand.Intn(maxDataSize) + 1
+ data := make([]byte, dataSize)
+ rand.Read(data)
+
+ // Write directly to the underlying FD, since verity FD is read-only.
+ n, err := lowerFD.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ return nil, 0, err
+ }
+
+ if n != int64(len(data)) {
+ return nil, 0, fmt.Errorf("lowerFD.Write got write length %d, want %d", n, len(data))
+ }
+
+ lowerFD.DecRef(ctx)
+
+ // Now open the verity file descriptor.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filePath),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular | mode,
+ })
+ return fd, dataSize, err
+}
+
+// corruptRandomBit randomly flips a bit in the file represented by fd.
+func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
+ // Flip a random bit in the underlying file.
+ randomPos := int64(rand.Intn(size))
+ byteToModify := make([]byte, 1)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil {
+ return fmt.Errorf("lowerFD.PRead failed: %v", err)
+ }
+ byteToModify[0] ^= 1
+ if _, err := fd.PWrite(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.WriteOptions{}); err != nil {
+ return fmt.Errorf("lowerFD.PWrite failed: %v", err)
+ }
+ return nil
+}
+
+// TestOpen ensures that when a file is created, the corresponding Merkle tree
+// file and the root Merkle tree file exist.
+func TestOpen(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, cleanup, err := newVerityRoot(ctx)
+ if err != nil {
+ t.Fatalf("Failed to create new verity root: %v", err)
+ }
+ defer cleanup()
+
+ filename := "verity-test-file"
+ if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
+ t.Fatalf("Failed to create new file fd: %v", err)
+ }
+
+ // Ensure that the corresponding Merkle tree file is created.
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("Failed to open Merkle tree file %s: %v", merklePrefix+filename, err)
+ }
+
+ // Ensure the root merkle tree file is created.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + rootMerkleFilename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("Failed to open root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
+ }
+}
+
+// TestUntouchedFileSucceeds ensures that read from an untouched verity file
+// succeeds after enabling verity for it.
+func TestReadUntouchedFileSucceeds(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, cleanup, err := newVerityRoot(ctx)
+ if err != nil {
+ t.Fatalf("Failed to create new verity root: %v", err)
+ }
+ defer cleanup()
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("Failed to create new file fd: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl failed: %v", err)
+ }
+
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead failed: %v", err)
+ }
+
+ if n != int64(size) {
+ t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ }
+}
+
+// TestReopenUntouchedFileSucceeds ensures that reopen an untouched verity file
+// succeeds after enabling verity for it.
+func TestReopenUntouchedFileSucceeds(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, cleanup, err := newVerityRoot(ctx)
+ if err != nil {
+ t.Fatalf("Failed to create new verity root: %v", err)
+ }
+ defer cleanup()
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("Failed to create new file fd: %v", err)
+ }
+
+ // Enable verity on the file and confirms a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl failed: %v", err)
+ }
+
+ // Ensure reopening the verity enabled file succeeds.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err != nil {
+ t.Errorf("reopen enabled file failed: %v", err)
+ }
+}
+
+// TestModifiedFileFails ensures that read from a modified verity file fails.
+func TestModifiedFileFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, cleanup, err := newVerityRoot(ctx)
+ if err != nil {
+ t.Fatalf("Failed to create new verity root: %v", err)
+ }
+ defer cleanup()
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("Failed to create new file fd: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl failed: %v", err)
+ }
+
+ // Open a new lowerFD that's read/writable.
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD
+
+ lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("Open lowerFD failed: %v", err)
+ }
+
+ if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("%v", err)
+ }
+
+ // Confirm that read from the modified file fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.PRead succeeded with modified file")
+ }
+}
+
+// TestModifiedMerkleFails ensures that read from a verity file fails if the
+// corresponding Merkle tree file is modified.
+func TestModifiedMerkleFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, cleanup, err := newVerityRoot(ctx)
+ if err != nil {
+ t.Fatalf("Failed to create new verity root: %v", err)
+ }
+ defer cleanup()
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("Failed to create new file fd: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl failed: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
+
+ lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerMerkleVD,
+ Start: lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("Open lowerMerkleFD failed: %v", err)
+ }
+
+ // Flip a random bit in the Merkle tree file.
+ stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("Failed to get lowerMerkleFD stat: %v", err)
+ }
+ merkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil {
+ t.Fatalf("%v", err)
+ }
+
+ // Confirm that read from a file with modified Merkle tree fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ fmt.Println(buf)
+ t.Fatalf("fd.PRead succeeded with modified Merkle file")
+ }
+}
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index 07bf39fed..5bba9de0b 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -15,6 +15,7 @@ go_library(
],
deps = [
"//pkg/context",
+ "//pkg/tcpip",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index c0b4831d1..fbe6d6aa6 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -15,7 +15,10 @@
// Package inet defines semantics for IP stacks.
package inet
-import "gvisor.dev/gvisor/pkg/tcpip/stack"
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
// Stack represents a TCP/IP stack.
type Stack interface {
@@ -80,6 +83,12 @@ type Stack interface {
// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
// for restoring a stack after a save.
RestoreCleanupEndpoints([]stack.TransportEndpoint)
+
+ // Forwarding returns if packet forwarding between NICs is enabled.
+ Forwarding(protocol tcpip.NetworkProtocolNumber) bool
+
+ // SetForwarding enables or disables packet forwarding between NICs.
+ SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error
}
// Interface contains information about a network interface.
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 9771f01fc..1779cc6f3 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -14,7 +14,10 @@
package inet
-import "gvisor.dev/gvisor/pkg/tcpip/stack"
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
// TestStack is a dummy implementation of Stack for tests.
type TestStack struct {
@@ -26,6 +29,7 @@ type TestStack struct {
TCPSendBufSize TCPBufferSize
TCPSACKFlag bool
Recovery TCPLossRecovery
+ IPForwarding bool
}
// NewTestStack returns a TestStack with no network interfaces. The value of
@@ -128,3 +132,14 @@ func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint {
// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
+
+// Forwarding implements inet.Stack.Forwarding.
+func (s *TestStack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
+ return s.IPForwarding
+}
+
+// SetForwarding implements inet.Stack.SetForwarding.
+func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
+ s.IPForwarding = enable
+ return nil
+}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index d436daab4..083071b5e 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -69,8 +69,8 @@ go_template_instance(
prefix = "socket",
template = "//pkg/ilist:generic_list",
types = {
- "Element": "*SocketEntry",
- "Linker": "*SocketEntry",
+ "Element": "*SocketRecordVFS1",
+ "Linker": "*SocketRecordVFS1",
},
)
@@ -197,6 +197,7 @@ go_library(
"gvisor.dev/gvisor/pkg/sentry/device",
"gvisor.dev/gvisor/pkg/tcpip",
],
+ marshal = True,
visibility = ["//:sandbox"],
deps = [
":uncaught_signal_go_proto",
@@ -212,6 +213,8 @@ go_library(
"//pkg/eventchannel",
"//pkg/fspath",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/refs",
"//pkg/refs_vfs2",
@@ -261,7 +264,6 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 2bc49483a..869e49ebc 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -57,6 +57,7 @@ go_library(
"id_map_set.go",
"user_namespace.go",
],
+ marshal = True,
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/sentry/kernel/auth/id.go b/pkg/sentry/kernel/auth/id.go
index 0a58ba17c..4c32ee703 100644
--- a/pkg/sentry/kernel/auth/id.go
+++ b/pkg/sentry/kernel/auth/id.go
@@ -19,9 +19,13 @@ import (
)
// UID is a user ID in an unspecified user namespace.
+//
+// +marshal
type UID uint32
// GID is a group ID in an unspecified user namespace.
+//
+// +marshal slice:GIDSlice
type GID uint32
// In the root user namespace, user/group IDs have a 1-to-1 relationship with
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 5773244ac..0ec7344cd 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -111,8 +111,11 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
ctx := context.Background()
f.init() // Initialize table.
+ f.used = 0
for fd, d := range m {
- f.setAll(fd, d.file, d.fileVFS2, d.flags)
+ if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil {
+ panic("VFS1 or VFS2 files set")
+ }
// Note that we do _not_ need to acquire a extra table reference here. The
// table reference will already be accounted for in the file, so we drop the
@@ -127,7 +130,7 @@ func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
}
// drop drops the table reference.
-func (f *FDTable) drop(file *fs.File) {
+func (f *FDTable) drop(ctx context.Context, file *fs.File) {
// Release locks.
file.Dirent.Inode.LockCtx.Posix.UnlockRegion(f, lock.LockRange{0, lock.LockEOF})
@@ -145,14 +148,13 @@ func (f *FDTable) drop(file *fs.File) {
d.InotifyEvent(ev, 0)
// Drop the table reference.
- file.DecRef(context.Background())
+ file.DecRef(ctx)
}
// dropVFS2 drops the table reference.
-func (f *FDTable) dropVFS2(file *vfs.FileDescription) {
+func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) {
// Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the
// entire file.
- ctx := context.Background()
err := file.UnlockPOSIX(ctx, f, 0, 0, linux.SEEK_SET)
if err != nil && err != syserror.ENOLCK {
panic(fmt.Sprintf("UnlockPOSIX failed: %v", err))
@@ -187,12 +189,6 @@ func (f *FDTable) DecRef(ctx context.Context) {
})
}
-// Size returns the number of file descriptor slots currently allocated.
-func (f *FDTable) Size() int {
- size := atomic.LoadInt32(&f.used)
- return int(size)
-}
-
// forEach iterates over all non-nil files in sorted order.
//
// It is the caller's responsibility to acquire an appropriate lock.
@@ -279,7 +275,6 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
}
f.mu.Lock()
- defer f.mu.Unlock()
// From f.next to find available fd.
if fd < f.next {
@@ -289,15 +284,25 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
// Install all entries.
for i := fd; i < end && len(fds) < len(files); i++ {
if d, _, _ := f.get(i); d == nil {
- f.set(i, files[len(fds)], flags) // Set the descriptor.
- fds = append(fds, i) // Record the file descriptor.
+ // Set the descriptor.
+ f.set(ctx, i, files[len(fds)], flags)
+ fds = append(fds, i) // Record the file descriptor.
}
}
// Failure? Unwind existing FDs.
if len(fds) < len(files) {
for _, i := range fds {
- f.set(i, nil, FDFlags{}) // Zap entry.
+ f.set(ctx, i, nil, FDFlags{})
+ }
+ f.mu.Unlock()
+
+ // Drop the reference taken by the call to f.set() that
+ // originally installed the file. Don't call f.drop()
+ // (generating inotify events, etc.) since the file should
+ // appear to have never been inserted into f.
+ for _, file := range files[:len(fds)] {
+ file.DecRef(ctx)
}
return nil, syscall.EMFILE
}
@@ -307,6 +312,7 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
f.next = fds[len(fds)-1] + 1
}
+ f.mu.Unlock()
return fds, nil
}
@@ -334,7 +340,6 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
}
f.mu.Lock()
- defer f.mu.Unlock()
// From f.next to find available fd.
if fd < f.next {
@@ -344,15 +349,25 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
// Install all entries.
for i := fd; i < end && len(fds) < len(files); i++ {
if d, _, _ := f.getVFS2(i); d == nil {
- f.setVFS2(i, files[len(fds)], flags) // Set the descriptor.
- fds = append(fds, i) // Record the file descriptor.
+ // Set the descriptor.
+ f.setVFS2(ctx, i, files[len(fds)], flags)
+ fds = append(fds, i) // Record the file descriptor.
}
}
// Failure? Unwind existing FDs.
if len(fds) < len(files) {
for _, i := range fds {
- f.setVFS2(i, nil, FDFlags{}) // Zap entry.
+ f.setVFS2(ctx, i, nil, FDFlags{})
+ }
+ f.mu.Unlock()
+
+ // Drop the reference taken by the call to f.setVFS2() that
+ // originally installed the file. Don't call f.dropVFS2()
+ // (generating inotify events, etc.) since the file should
+ // appear to have never been inserted into f.
+ for _, file := range files[:len(fds)] {
+ file.DecRef(ctx)
}
return nil, syscall.EMFILE
}
@@ -362,6 +377,7 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
f.next = fds[len(fds)-1] + 1
}
+ f.mu.Unlock()
return fds, nil
}
@@ -397,7 +413,7 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc
}
for fd < end {
if d, _, _ := f.getVFS2(fd); d == nil {
- f.setVFS2(fd, file, flags)
+ f.setVFS2(ctx, fd, file, flags)
if fd == f.next {
// Update next search start position.
f.next = fd + 1
@@ -413,40 +429,55 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc
// reference for that FD, the ref count for that existing reference is
// decremented.
func (f *FDTable) NewFDAt(ctx context.Context, fd int32, file *fs.File, flags FDFlags) error {
- return f.newFDAt(ctx, fd, file, nil, flags)
+ df, _, err := f.newFDAt(ctx, fd, file, nil, flags)
+ if err != nil {
+ return err
+ }
+ if df != nil {
+ f.drop(ctx, df)
+ }
+ return nil
}
// NewFDAtVFS2 sets the file reference for the given FD. If there is an active
// reference for that FD, the ref count for that existing reference is
// decremented.
func (f *FDTable) NewFDAtVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) error {
- return f.newFDAt(ctx, fd, nil, file, flags)
+ _, dfVFS2, err := f.newFDAt(ctx, fd, nil, file, flags)
+ if err != nil {
+ return err
+ }
+ if dfVFS2 != nil {
+ f.dropVFS2(ctx, dfVFS2)
+ }
+ return nil
}
-func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) error {
+func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) (*fs.File, *vfs.FileDescription, error) {
if fd < 0 {
// Don't accept negative FDs.
- return syscall.EBADF
+ return nil, nil, syscall.EBADF
}
// Check the limit for the provided file.
if limitSet := limits.FromContext(ctx); limitSet != nil {
if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur {
- return syscall.EMFILE
+ return nil, nil, syscall.EMFILE
}
}
// Install the entry.
f.mu.Lock()
defer f.mu.Unlock()
- f.setAll(fd, file, fileVFS2, flags)
- return nil
+
+ df, dfVFS2 := f.setAll(ctx, fd, file, fileVFS2, flags)
+ return df, dfVFS2, nil
}
// SetFlags sets the flags for the given file descriptor.
//
// True is returned iff flags were changed.
-func (f *FDTable) SetFlags(fd int32, flags FDFlags) error {
+func (f *FDTable) SetFlags(ctx context.Context, fd int32, flags FDFlags) error {
if fd < 0 {
// Don't accept negative FDs.
return syscall.EBADF
@@ -462,14 +493,14 @@ func (f *FDTable) SetFlags(fd int32, flags FDFlags) error {
}
// Update the flags.
- f.set(fd, file, flags)
+ f.set(ctx, fd, file, flags)
return nil
}
// SetFlagsVFS2 sets the flags for the given file descriptor.
//
// True is returned iff flags were changed.
-func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error {
+func (f *FDTable) SetFlagsVFS2(ctx context.Context, fd int32, flags FDFlags) error {
if fd < 0 {
// Don't accept negative FDs.
return syscall.EBADF
@@ -485,7 +516,7 @@ func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error {
}
// Update the flags.
- f.setVFS2(fd, file, flags)
+ f.setVFS2(ctx, fd, file, flags)
return nil
}
@@ -551,30 +582,6 @@ func (f *FDTable) GetFDs(ctx context.Context) []int32 {
return fds
}
-// GetRefs returns a stable slice of references to all files and bumps the
-// reference count on each. The caller must use DecRef on each reference when
-// they're done using the slice.
-func (f *FDTable) GetRefs(ctx context.Context) []*fs.File {
- files := make([]*fs.File, 0, f.Size())
- f.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
- file.IncRef() // Acquire a reference for caller.
- files = append(files, file)
- })
- return files
-}
-
-// GetRefsVFS2 returns a stable slice of references to all files and bumps the
-// reference count on each. The caller must use DecRef on each reference when
-// they're done using the slice.
-func (f *FDTable) GetRefsVFS2(ctx context.Context) []*vfs.FileDescription {
- files := make([]*vfs.FileDescription, 0, f.Size())
- f.forEach(ctx, func(_ int32, _ *fs.File, file *vfs.FileDescription, _ FDFlags) {
- file.IncRef() // Acquire a reference for caller.
- files = append(files, file)
- })
- return files
-}
-
// Fork returns an independent FDTable.
func (f *FDTable) Fork(ctx context.Context) *FDTable {
clone := f.k.NewFDTable()
@@ -582,11 +589,8 @@ func (f *FDTable) Fork(ctx context.Context) *FDTable {
f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
// The set function here will acquire an appropriate table
// reference for the clone. We don't need anything else.
- switch {
- case file != nil:
- clone.set(fd, file, flags)
- case fileVFS2 != nil:
- clone.setVFS2(fd, fileVFS2, flags)
+ if df, dfVFS2 := clone.setAll(ctx, fd, file, fileVFS2, flags); df != nil || dfVFS2 != nil {
+ panic("VFS1 or VFS2 files set")
}
})
return clone
@@ -595,13 +599,12 @@ func (f *FDTable) Fork(ctx context.Context) *FDTable {
// Remove removes an FD from and returns a non-file iff successful.
//
// N.B. Callers are required to use DecRef when they are done.
-func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) {
+func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDescription) {
if fd < 0 {
return nil, nil
}
f.mu.Lock()
- defer f.mu.Unlock()
// Update current available position.
if fd < f.next {
@@ -617,24 +620,51 @@ func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) {
case orig2 != nil:
orig2.IncRef()
}
+
if orig != nil || orig2 != nil {
- f.setAll(fd, nil, nil, FDFlags{}) // Zap entry.
+ orig, orig2 = f.setAll(ctx, fd, nil, nil, FDFlags{}) // Zap entry.
+ }
+ f.mu.Unlock()
+
+ if orig != nil {
+ f.drop(ctx, orig)
}
+ if orig2 != nil {
+ f.dropVFS2(ctx, orig2)
+ }
+
return orig, orig2
}
// RemoveIf removes all FDs where cond is true.
func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) {
- f.mu.Lock()
- defer f.mu.Unlock()
+ // TODO(gvisor.dev/issue/1624): Remove fs.File slice.
+ var files []*fs.File
+ var filesVFS2 []*vfs.FileDescription
+ f.mu.Lock()
f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
if cond(file, fileVFS2, flags) {
- f.set(fd, nil, FDFlags{}) // Clear from table.
+ df, dfVFS2 := f.setAll(ctx, fd, nil, nil, FDFlags{}) // Clear from table.
+ if df != nil {
+ files = append(files, df)
+ }
+ if dfVFS2 != nil {
+ filesVFS2 = append(filesVFS2, dfVFS2)
+ }
// Update current available position.
if fd < f.next {
f.next = fd
}
}
})
+ f.mu.Unlock()
+
+ for _, file := range files {
+ f.drop(ctx, file)
+ }
+
+ for _, file := range filesVFS2 {
+ f.dropVFS2(ctx, file)
+ }
}
diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go
index e3f30ba2a..bf5460083 100644
--- a/pkg/sentry/kernel/fd_table_test.go
+++ b/pkg/sentry/kernel/fd_table_test.go
@@ -72,7 +72,7 @@ func TestFDTableMany(t *testing.T) {
}
i := int32(2)
- fdTable.Remove(i)
+ fdTable.Remove(ctx, i)
if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i {
t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err)
}
@@ -93,7 +93,7 @@ func TestFDTableOverLimit(t *testing.T) {
t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err)
} else {
for _, fd := range fds {
- fdTable.Remove(fd)
+ fdTable.Remove(ctx, fd)
}
}
@@ -150,13 +150,13 @@ func TestFDTable(t *testing.T) {
t.Fatalf("fdTable.Get(2): got a %v, wanted nil", ref)
}
- ref, _ := fdTable.Remove(1)
+ ref, _ := fdTable.Remove(ctx, 1)
if ref == nil {
t.Fatalf("fdTable.Remove(1) for an existing FD: failed, want success")
}
ref.DecRef(ctx)
- if ref, _ := fdTable.Remove(1); ref != nil {
+ if ref, _ := fdTable.Remove(ctx, 1); ref != nil {
t.Fatalf("r.Remove(1) for a removed FD: got success, want failure")
}
})
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index 6b8feb107..da79e6627 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"unsafe"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -78,33 +79,37 @@ func (f *FDTable) getAll(fd int32) (*fs.File, *vfs.FileDescription, FDFlags, boo
return d.file, d.fileVFS2, d.flags, true
}
-// set sets an entry.
-//
-// This handles accounting changes, as well as acquiring and releasing the
-// reference needed by the table iff the file is different.
+// CurrentMaxFDs returns the number of file descriptors that may be stored in f
+// without reallocation.
+func (f *FDTable) CurrentMaxFDs() int {
+ slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
+ return len(slice)
+}
+
+// set sets an entry for VFS1, refer to setAll().
//
// Precondition: mu must be held.
-func (f *FDTable) set(fd int32, file *fs.File, flags FDFlags) {
- f.setAll(fd, file, nil, flags)
+func (f *FDTable) set(ctx context.Context, fd int32, file *fs.File, flags FDFlags) *fs.File {
+ dropFile, _ := f.setAll(ctx, fd, file, nil, flags)
+ return dropFile
}
-// setVFS2 sets an entry.
-//
-// This handles accounting changes, as well as acquiring and releasing the
-// reference needed by the table iff the file is different.
+// setVFS2 sets an entry for VFS2, refer to setAll().
//
// Precondition: mu must be held.
-func (f *FDTable) setVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) {
- f.setAll(fd, nil, file, flags)
+func (f *FDTable) setVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) *vfs.FileDescription {
+ _, dropFile := f.setAll(ctx, fd, nil, file, flags)
+ return dropFile
}
-// setAll sets an entry.
-//
-// This handles accounting changes, as well as acquiring and releasing the
-// reference needed by the table iff the file is different.
+// setAll sets the file description referred to by fd to file/fileVFS2. If
+// file/fileVFS2 are non-nil, it takes a reference on them. If setAll replaces
+// an existing file description, it returns it with the FDTable's reference
+// transferred to the caller, which must call f.drop/dropVFS2() on the returned
+// file after unlocking f.mu.
//
// Precondition: mu must be held.
-func (f *FDTable) setAll(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) (*fs.File, *vfs.FileDescription) {
if file != nil && fileVFS2 != nil {
panic("VFS1 and VFS2 files set")
}
@@ -147,25 +152,25 @@ func (f *FDTable) setAll(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription,
}
}
- // Drop the table reference.
+ // Adjust used.
+ switch {
+ case orig == nil && desc != nil:
+ atomic.AddInt32(&f.used, 1)
+ case orig != nil && desc == nil:
+ atomic.AddInt32(&f.used, -1)
+ }
+
if orig != nil {
switch {
case orig.file != nil:
if desc == nil || desc.file != orig.file {
- f.drop(orig.file)
+ return orig.file, nil
}
case orig.fileVFS2 != nil:
if desc == nil || desc.fileVFS2 != orig.fileVFS2 {
- f.dropVFS2(orig.fileVFS2)
+ return nil, orig.fileVFS2
}
}
}
-
- // Adjust used.
- switch {
- case orig == nil && desc != nil:
- atomic.AddInt32(&f.used, 1)
- case orig != nil && desc == nil:
- atomic.AddInt32(&f.used, -1)
- }
+ return nil, nil
}
diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go
index aad63aa99..d3e76ca7b 100644
--- a/pkg/sentry/kernel/kcov.go
+++ b/pkg/sentry/kernel/kcov.go
@@ -89,6 +89,10 @@ func (kcov *Kcov) TaskWork(t *Task) {
kcov.mu.Lock()
defer kcov.mu.Unlock()
+ if kcov.mode != linux.KCOV_TRACE_PC {
+ return
+ }
+
rw := &kcovReadWriter{
mf: kcov.mfp.MemoryFile(),
fr: kcov.mappable.FileRange(),
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 402aa1718..d6c21adb7 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -220,13 +220,18 @@ type Kernel struct {
// danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
- // sockets is the list of all network sockets the system. Protected by
- // extMu.
+ // sockets is the list of all network sockets in the system.
+ // Protected by extMu.
+ // TODO(gvisor.dev/issue/1624): Only used by VFS1.
sockets socketList
- // nextSocketEntry is the next entry number to use in sockets. Protected
+ // socketsVFS2 records all network sockets in the system. Protected by
+ // extMu.
+ socketsVFS2 map[*vfs.FileDescription]*SocketRecord
+
+ // nextSocketRecord is the next entry number to use in sockets. Protected
// by extMu.
- nextSocketEntry uint64
+ nextSocketRecord uint64
// deviceRegistry is used to save/restore device.SimpleDevices.
deviceRegistry struct{} `state:".(*device.Registry)"`
@@ -414,6 +419,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
return fmt.Errorf("failed to create sockfs mount: %v", err)
}
k.socketMount = socketMount
+
+ k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord)
}
return nil
@@ -507,6 +514,10 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
// flushMountSourceRefs flushes the MountSources for all mounted filesystems
// and open FDs.
func (k *Kernel) flushMountSourceRefs(ctx context.Context) error {
+ if VFS2Enabled {
+ return nil // Not relevant.
+ }
+
// Flush all mount sources for currently mounted filesystems in each task.
flushed := make(map[*fs.MountNamespace]struct{})
k.tasks.mu.RLock()
@@ -533,11 +544,6 @@ func (k *Kernel) flushMountSourceRefs(ctx context.Context) error {
//
// Precondition: Must be called with the kernel paused.
func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.FileDescription) error) (err error) {
- // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
- if VFS2Enabled {
- return nil
- }
-
ts.mu.RLock()
defer ts.mu.RUnlock()
for t := range ts.Root.tids {
@@ -556,6 +562,10 @@ func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.Fi
func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
// TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ if VFS2Enabled {
+ return nil
+ }
+
return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error {
if flags := file.Flags(); !flags.Write {
return nil
@@ -888,17 +898,18 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
opener fsbridge.Lookup
fsContext *FSContext
mntns *fs.MountNamespace
+ mntnsVFS2 *vfs.MountNamespace
)
if VFS2Enabled {
- mntnsVFS2 := args.MountNamespaceVFS2
+ mntnsVFS2 = args.MountNamespaceVFS2
if mntnsVFS2 == nil {
// MountNamespaceVFS2 adds a reference to the namespace, which is
// transferred to the new process.
mntnsVFS2 = k.globalInit.Leader().MountNamespaceVFS2()
}
// Get the root directory from the MountNamespace.
- root := args.MountNamespaceVFS2.Root()
+ root := mntnsVFS2.Root()
// The call to newFSContext below will take a reference on root, so we
// don't need to hold this one.
defer root.DecRef(ctx)
@@ -1008,7 +1019,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
UTSNamespace: args.UTSNamespace,
IPCNamespace: args.IPCNamespace,
AbstractSocketNamespace: args.AbstractSocketNamespace,
- MountNamespaceVFS2: args.MountNamespaceVFS2,
+ MountNamespaceVFS2: mntnsVFS2,
ContainerID: args.ContainerID,
}
t, err := k.tasks.NewTask(config)
@@ -1508,20 +1519,27 @@ func (k *Kernel) SupervisorContext() context.Context {
}
}
-// SocketEntry represents a socket recorded in Kernel.sockets. It implements
+// SocketRecord represents a socket recorded in Kernel.socketsVFS2.
+//
+// +stateify savable
+type SocketRecord struct {
+ k *Kernel
+ Sock *refs.WeakRef // TODO(gvisor.dev/issue/1624): Only used by VFS1.
+ SockVFS2 *vfs.FileDescription // Only used by VFS2.
+ ID uint64 // Socket table entry number.
+}
+
+// SocketRecordVFS1 represents a socket recorded in Kernel.sockets. It implements
// refs.WeakRefUser for sockets stored in the socket table.
//
// +stateify savable
-type SocketEntry struct {
+type SocketRecordVFS1 struct {
socketEntry
- k *Kernel
- Sock *refs.WeakRef
- SockVFS2 *vfs.FileDescription
- ID uint64 // Socket table entry number.
+ SocketRecord
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
-func (s *SocketEntry) WeakRefGone(context.Context) {
+func (s *SocketRecordVFS1) WeakRefGone(context.Context) {
s.k.extMu.Lock()
s.k.sockets.Remove(s)
s.k.extMu.Unlock()
@@ -1532,9 +1550,14 @@ func (s *SocketEntry) WeakRefGone(context.Context) {
// Precondition: Caller must hold a reference to sock.
func (k *Kernel) RecordSocket(sock *fs.File) {
k.extMu.Lock()
- id := k.nextSocketEntry
- k.nextSocketEntry++
- s := &SocketEntry{k: k, ID: id}
+ id := k.nextSocketRecord
+ k.nextSocketRecord++
+ s := &SocketRecordVFS1{
+ SocketRecord: SocketRecord{
+ k: k,
+ ID: id,
+ },
+ }
s.Sock = refs.NewWeakRef(sock, s)
k.sockets.PushBack(s)
k.extMu.Unlock()
@@ -1546,29 +1569,45 @@ func (k *Kernel) RecordSocket(sock *fs.File) {
// Precondition: Caller must hold a reference to sock.
//
// Note that the socket table will not hold a reference on the
-// vfs.FileDescription, because we do not support weak refs on VFS2 files.
+// vfs.FileDescription.
func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) {
k.extMu.Lock()
- id := k.nextSocketEntry
- k.nextSocketEntry++
- s := &SocketEntry{
+ if _, ok := k.socketsVFS2[sock]; ok {
+ panic(fmt.Sprintf("Socket %p added twice", sock))
+ }
+ id := k.nextSocketRecord
+ k.nextSocketRecord++
+ s := &SocketRecord{
k: k,
ID: id,
SockVFS2: sock,
}
- k.sockets.PushBack(s)
+ k.socketsVFS2[sock] = s
+ k.extMu.Unlock()
+}
+
+// DeleteSocketVFS2 removes a VFS2 socket from the system-wide socket table.
+func (k *Kernel) DeleteSocketVFS2(sock *vfs.FileDescription) {
+ k.extMu.Lock()
+ delete(k.socketsVFS2, sock)
k.extMu.Unlock()
}
// ListSockets returns a snapshot of all sockets.
//
-// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef()
+// Callers of ListSockets() in VFS2 should use SocketRecord.SockVFS2.TryIncRef()
// to get a reference on a socket in the table.
-func (k *Kernel) ListSockets() []*SocketEntry {
+func (k *Kernel) ListSockets() []*SocketRecord {
k.extMu.Lock()
- var socks []*SocketEntry
- for s := k.sockets.Front(); s != nil; s = s.Next() {
- socks = append(socks, s)
+ var socks []*SocketRecord
+ if VFS2Enabled {
+ for _, s := range k.socketsVFS2 {
+ socks = append(socks, s)
+ }
+ } else {
+ for s := k.sockets.Front(); s != nil; s = s.Next() {
+ socks = append(socks, &s.SocketRecord)
+ }
}
k.extMu.Unlock()
return socks
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 449643118..99134e634 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -21,6 +21,7 @@ go_library(
"//pkg/amutex",
"//pkg/buffer",
"//pkg/context",
+ "//pkg/marshal/primitive",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index c410c96aa..67beb0ad6 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -17,6 +17,7 @@ package pipe
import (
"fmt"
+ "io"
"sync/atomic"
"syscall"
@@ -215,7 +216,7 @@ func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) {
if p.view.Size() == 0 {
if !p.HasWriters() {
// There are no writers, return EOF.
- return 0, nil
+ return 0, io.EOF
}
return 0, syserror.ErrWouldBlock
}
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index 6d58b682f..f665920cb 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -145,9 +146,14 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume
v = math.MaxInt32 // Silently truncate.
}
// Copy result to userspace.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ iocc := primitive.IOCopyContext{
+ IO: io,
+ Ctx: ctx,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ }
+ _, err := primitive.CopyInt32Out(&iocc, args[2].Pointer(), int32(v))
return 0, err
default:
return 0, syscall.ENOTTY
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index f223d59e1..f61039f5b 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -67,6 +67,11 @@ func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlag
return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks)
}
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (*VFSPipe) Allocate(context.Context, uint64, uint64, uint64) error {
+ return syserror.ESPIPE
+}
+
// Open opens the pipe represented by vp.
func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
vp.mu.Lock()
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 50df179c3..1145faf13 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
@@ -999,18 +1000,15 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
// at the address specified by the data parameter, and the return value
// is the error flag." - ptrace(2)
word := t.Arch().Native(0)
- if _, err := usermem.CopyObjectIn(t, target.MemoryManager(), addr, word, usermem.IOOpts{
- IgnorePermissions: true,
- }); err != nil {
+ if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil {
return err
}
- _, err := t.CopyOut(data, word)
+ _, err := word.CopyOut(t, data)
return err
case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA:
- _, err := usermem.CopyObjectOut(t, target.MemoryManager(), addr, t.Arch().Native(uintptr(data)), usermem.IOOpts{
- IgnorePermissions: true,
- })
+ word := t.Arch().Native(uintptr(data))
+ _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr)
return err
case linux.PTRACE_GETREGSET:
@@ -1078,12 +1076,12 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
if target.ptraceSiginfo == nil {
return syserror.EINVAL
}
- _, err := t.CopyOut(data, target.ptraceSiginfo)
+ _, err := target.ptraceSiginfo.CopyOut(t, data)
return err
case linux.PTRACE_SETSIGINFO:
var info arch.SignalInfo
- if _, err := t.CopyIn(data, &info); err != nil {
+ if _, err := info.CopyIn(t, data); err != nil {
return err
}
t.tg.pidns.owner.mu.RLock()
@@ -1098,7 +1096,8 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
if addr != linux.SignalSetSize {
return syserror.EINVAL
}
- _, err := t.CopyOut(data, target.SignalMask())
+ mask := target.SignalMask()
+ _, err := mask.CopyOut(t, data)
return err
case linux.PTRACE_SETSIGMASK:
@@ -1106,7 +1105,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
return syserror.EINVAL
}
var mask linux.SignalSet
- if _, err := t.CopyIn(data, &mask); err != nil {
+ if _, err := mask.CopyIn(t, data); err != nil {
return err
}
// The target's task goroutine is stopped, so this is safe:
@@ -1121,7 +1120,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
case linux.PTRACE_GETEVENTMSG:
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
- _, err := t.CopyOut(usermem.Addr(data), target.ptraceEventMsg)
+ _, err := primitive.CopyUint64Out(t, usermem.Addr(data), target.ptraceEventMsg)
return err
// PEEKSIGINFO is unimplemented but seems to have no users anywhere.
diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go
index cef1276ec..609ad3941 100644
--- a/pkg/sentry/kernel/ptrace_amd64.go
+++ b/pkg/sentry/kernel/ptrace_amd64.go
@@ -30,7 +30,7 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro
if err != nil {
return err
}
- _, err = t.CopyOut(data, n)
+ _, err = n.CopyOut(t, data)
return err
case linux.PTRACE_POKEUSR: // aka PTRACE_POKEUSER
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
index 413111faf..332bdb8e8 100644
--- a/pkg/sentry/kernel/syscalls.go
+++ b/pkg/sentry/kernel/syscalls.go
@@ -348,6 +348,16 @@ func (s *SyscallTable) LookupName(sysno uintptr) string {
return fmt.Sprintf("sys_%d", sysno) // Unlikely.
}
+// LookupNo looks up a syscall number by name.
+func (s *SyscallTable) LookupNo(name string) (uintptr, error) {
+ for i, syscall := range s.Table {
+ if syscall.Name == name {
+ return uintptr(i), nil
+ }
+ }
+ return 0, fmt.Errorf("syscall %q not found", name)
+}
+
// LookupEmulate looks up an emulation syscall number.
func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) {
sysno, ok := s.Emulate[addr]
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 9d7a9128f..fce1064a7 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -341,12 +341,12 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
nt.SetClearTID(opts.ChildTID)
}
if opts.ChildSetTID {
- // Can't use Task.CopyOut, which assumes AddressSpaceActive.
- usermem.CopyObjectOut(t, nt.MemoryManager(), opts.ChildTID, nt.ThreadID(), usermem.IOOpts{})
+ ctid := nt.ThreadID()
+ ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID)
}
ntid := t.tg.pidns.IDOfTask(nt)
if opts.ParentSetTID {
- t.CopyOut(opts.ParentTID, ntid)
+ ntid.CopyOut(t, opts.ParentTID)
}
kind := ptraceCloneKindClone
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 9fa528384..d1136461a 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -126,7 +126,11 @@ func (t *Task) SyscallTable() *SyscallTable {
// Preconditions: The caller must be running on the task goroutine, or t.mu
// must be locked.
func (t *Task) Stack() *arch.Stack {
- return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())}
+ return &arch.Stack{
+ Arch: t.Arch(),
+ IO: t.MemoryManager(),
+ Bottom: usermem.Addr(t.Arch().Stack()),
+ }
}
// LoadTaskImage loads a specified file into a new TaskContext.
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index b76f7f503..b400a8b41 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -248,7 +248,8 @@ func (*runExitMain) execute(t *Task) taskRunState {
signaled := t.tg.exiting && t.tg.exitStatus.Signaled()
t.tg.signalHandlers.mu.Unlock()
if !signaled {
- if _, err := t.CopyOut(t.cleartid, ThreadID(0)); err == nil {
+ zero := ThreadID(0)
+ if _, err := zero.CopyOut(t, t.cleartid); err == nil {
t.Futex().Wake(t, t.cleartid, false, ^uint32(0), 1)
}
// If the CopyOut fails, there's nothing we can do.
diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go
index 4b535c949..c80391475 100644
--- a/pkg/sentry/kernel/task_futex.go
+++ b/pkg/sentry/kernel/task_futex.go
@@ -16,6 +16,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -87,7 +88,7 @@ func (t *Task) exitRobustList() {
return
}
- next := rl.List
+ next := primitive.Uint64(rl.List)
done := 0
var pendingLockAddr usermem.Addr
if rl.ListOpPending != 0 {
@@ -99,12 +100,12 @@ func (t *Task) exitRobustList() {
// We traverse to the next element of the list before we
// actually wake anything. This prevents the race where waking
// this futex causes a modification of the list.
- thisLockAddr := usermem.Addr(next + rl.FutexOffset)
+ thisLockAddr := usermem.Addr(uint64(next) + rl.FutexOffset)
// Try to decode the next element in the list before waking the
// current futex. But don't check the error until after we've
// woken the current futex. Linux does it in this order too
- _, nextErr := t.CopyIn(usermem.Addr(next), &next)
+ _, nextErr := next.CopyIn(t, usermem.Addr(next))
// Wakeup the current futex if it's not pending.
if thisLockAddr != pendingLockAddr {
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index aa3a573c0..8dc3fec90 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -141,7 +141,7 @@ func (*runApp) handleCPUIDInstruction(t *Task) error {
region := trace.StartRegion(t.traceContext, cpuidRegion)
expected := arch.CPUIDInstruction[:]
found := make([]byte, len(expected))
- _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
+ _, err := t.CopyInBytes(usermem.Addr(t.Arch().IP()), found)
if err == nil && bytes.Equal(expected, found) {
// Skip the cpuid instruction.
t.Arch().CPUIDEmulate(t)
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index feaa38596..ebdb83061 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -259,7 +259,11 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct)
// Set up the signal handler. If we have a saved signal mask, the signal
// handler should run with the current mask, but sigreturn should restore
// the saved one.
- st := &arch.Stack{t.Arch(), mm, sp}
+ st := &arch.Stack{
+ Arch: t.Arch(),
+ IO: mm,
+ Bottom: sp,
+ }
mask := t.signalMask
if t.haveSavedSignalMask {
mask = t.savedSignalMask
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index 2dbf86547..0141459e7 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -287,7 +288,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState {
// Grab the caller up front, to make sure there's a sensible stack.
caller := t.Arch().Native(uintptr(0))
- if _, err := t.CopyIn(usermem.Addr(t.Arch().Stack()), caller); err != nil {
+ if _, err := caller.CopyIn(t, usermem.Addr(t.Arch().Stack())); err != nil {
t.Debugf("vsyscall %d: error reading return address from stack: %v", sysno, err)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
@@ -323,7 +324,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState {
type runVsyscallAfterPtraceEventSeccomp struct {
addr usermem.Addr
sysno uintptr
- caller interface{}
+ caller marshal.Marshallable
}
func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState {
@@ -346,7 +347,7 @@ func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState {
return t.doVsyscallInvoke(sysno, t.Arch().SyscallArgs(), r.caller)
}
-func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller interface{}) taskRunState {
+func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller marshal.Marshallable) taskRunState {
rval, ctrl, err := t.executeSyscall(sysno, args)
if ctrl != nil {
t.Debugf("vsyscall %d, caller %x: syscall control: %v", sysno, t.Arch().Value(caller), ctrl)
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index 0cb86e390..ce134bf54 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -18,6 +18,7 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -43,17 +44,6 @@ func (t *Task) Deactivate() {
}
}
-// CopyIn copies a fixed-size value or slice of fixed-size values in from the
-// task's memory. The copy will fail with syscall.EFAULT if it traverses user
-// memory that is unmapped or not readable by the user.
-//
-// This Task's AddressSpace must be active.
-func (t *Task) CopyIn(addr usermem.Addr, dst interface{}) (int, error) {
- return usermem.CopyObjectIn(t, t.MemoryManager(), addr, dst, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-}
-
// CopyInBytes is a fast version of CopyIn if the caller can serialize the
// data without reflection and pass in a byte slice.
//
@@ -64,17 +54,6 @@ func (t *Task) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
})
}
-// CopyOut copies a fixed-size value or slice of fixed-size values out to the
-// task's memory. The copy will fail with syscall.EFAULT if it traverses user
-// memory that is unmapped or not writeable by the user.
-//
-// This Task's AddressSpace must be active.
-func (t *Task) CopyOut(addr usermem.Addr, src interface{}) (int, error) {
- return usermem.CopyObjectOut(t, t.MemoryManager(), addr, src, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-}
-
// CopyOutBytes is a fast version of CopyOut if the caller can serialize the
// data without reflection and pass in a byte slice.
//
@@ -114,7 +93,7 @@ func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([
var v []string
for {
argAddr := t.Arch().Native(0)
- if _, err := t.CopyIn(addr, argAddr); err != nil {
+ if _, err := argAddr.CopyIn(t, addr); err != nil {
return v, err
}
if t.Arch().Value(argAddr) == 0 {
@@ -302,29 +281,29 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp
}, nil
}
-// CopyContextWithOpts wraps a task to allow copying memory to and from the
-// task memory with user specified usermem.IOOpts.
-type CopyContextWithOpts struct {
+// copyContext implements marshal.CopyContext. It wraps a task to allow copying
+// memory to and from the task memory with custom usermem.IOOpts.
+type copyContext struct {
*Task
opts usermem.IOOpts
}
-// AsCopyContextWithOpts wraps the task and returns it as CopyContextWithOpts.
-func (t *Task) AsCopyContextWithOpts(opts usermem.IOOpts) *CopyContextWithOpts {
- return &CopyContextWithOpts{t, opts}
+// AsCopyContext wraps the task and returns it as CopyContext.
+func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext {
+ return &copyContext{t, opts}
}
// CopyInString copies a string in from the task's memory.
-func (t *CopyContextWithOpts) CopyInString(addr usermem.Addr, maxLen int) (string, error) {
+func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) {
return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts)
}
// CopyInBytes copies task memory into dst from an IO context.
-func (t *CopyContextWithOpts) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
return t.MemoryManager().CopyIn(t, addr, dst, t.opts)
}
// CopyOutBytes copies src into task memoryfrom an IO context.
-func (t *CopyContextWithOpts) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
return t.MemoryManager().CopyOut(t, addr, src, t.opts)
}
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 872e1a82d..5ae5906e8 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -36,6 +36,8 @@ import (
const TasksLimit = (1 << 16)
// ThreadID is a generic thread identifier.
+//
+// +marshal
type ThreadID int32
// String returns a decimal representation of the ThreadID.
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 15c88aa7c..c69b62db9 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -122,7 +122,7 @@ func allocStack(ctx context.Context, m *mm.MemoryManager, a arch.Context) (*arch
if err != nil {
return nil, err
}
- return &arch.Stack{a, m, ar.End}, nil
+ return &arch.Stack{Arch: a, IO: m, Bottom: ar.End}, nil
}
const (
@@ -247,20 +247,20 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
}
// Push the original filename to the stack, for AT_EXECFN.
- execfn, err := stack.Push(args.Filename)
- if err != nil {
+ if _, err := stack.PushNullTerminatedByteSlice([]byte(args.Filename)); err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux())
}
+ execfn := stack.Bottom
// Push 16 random bytes on the stack which AT_RANDOM will point to.
var b [16]byte
if _, err := rand.Read(b[:]); err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to read random bytes: %v", err), syserr.FromError(err).ToLinux())
}
- random, err := stack.Push(b)
- if err != nil {
+ if _, err = stack.PushNullTerminatedByteSlice(b[:]); err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push random bytes: %v", err), syserr.FromError(err).ToLinux())
}
+ random := stack.Bottom
c := auth.CredentialsFromContext(ctx)
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 3e85964e4..8c9f11cce 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -242,7 +242,7 @@ type MemoryManager struct {
// +stateify savable
type vma struct {
// mappable is the virtual memory object mapped by this vma. If mappable is
- // nil, the vma represents a private anonymous mapping.
+ // nil, the vma represents an anonymous mapping.
mappable memmap.Mappable
// off is the offset into mappable at which this vma begins. If mappable is
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index fdc308542..acac3d357 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -51,7 +51,8 @@ func TestUsageASUpdates(t *testing.T) {
defer mm.DecUsers(ctx)
addr, err := mm.MMap(ctx, memmap.MMapOpts{
- Length: 2 * usermem.PageSize,
+ Length: 2 * usermem.PageSize,
+ Private: true,
})
if err != nil {
t.Fatalf("MMap got err %v want nil", err)
diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go
index f4c93baeb..2dbe5b751 100644
--- a/pkg/sentry/mm/special_mappable.go
+++ b/pkg/sentry/mm/special_mappable.go
@@ -136,9 +136,12 @@ func (m *SpecialMappable) Length() uint64 {
// NewSharedAnonMappable returns a SpecialMappable that implements the
// semantics of mmap(MAP_SHARED|MAP_ANONYMOUS) and mappings of /dev/zero.
//
-// TODO(jamieliu): The use of SpecialMappable is a lazy code reuse hack. Linux
-// uses an ephemeral file created by mm/shmem.c:shmem_zero_setup(); we should
-// do the same to get non-zero device and inode IDs.
+// TODO(gvisor.dev/issue/1624): Linux uses an ephemeral file created by
+// mm/shmem.c:shmem_zero_setup(), and VFS2 does something analogous. VFS1 uses
+// a SpecialMappable instead, incorrectly getting device and inode IDs of zero
+// and causing memory for shared anonymous mappings to be allocated up-front
+// instead of on first touch; this is to avoid exacerbating the fs.MountSource
+// leak (b/143656263). Delete this function along with VFS1.
func NewSharedAnonMappable(length uint64, mfp pgalloc.MemoryFileProvider) (*SpecialMappable, error) {
if length == 0 {
return nil, syserror.EINVAL
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index 4c9a575e7..a2555ba1a 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -93,18 +92,6 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme
}
} else {
opts.Offset = 0
- if !opts.Private {
- if opts.MappingIdentity != nil {
- return 0, syserror.EINVAL
- }
- m, err := NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx))
- if err != nil {
- return 0, err
- }
- defer m.DecRef(ctx)
- opts.MappingIdentity = m
- opts.Mappable = m
- }
}
if opts.Addr.RoundDown() != opts.Addr {
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 3970dd81d..323837fb1 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -9,12 +9,12 @@ go_library(
"bluepill.go",
"bluepill_allocator.go",
"bluepill_amd64.go",
- "bluepill_amd64.s",
"bluepill_amd64_unsafe.go",
"bluepill_arm64.go",
"bluepill_arm64.s",
"bluepill_arm64_unsafe.go",
"bluepill_fault.go",
+ "bluepill_impl_amd64.s",
"bluepill_unsafe.go",
"context.go",
"filters_amd64.go",
@@ -81,3 +81,11 @@ go_test(
"//pkg/usermem",
],
)
+
+genrule(
+ name = "bluepill_impl_amd64",
+ srcs = ["bluepill_amd64.s"],
+ outs = ["bluepill_impl_amd64.s"],
+ cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
+)
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
index 2bc34a435..025ea93b5 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.s
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -19,11 +19,6 @@
// This is guaranteed to be zero.
#define VCPU_CPU 0x0
-// CPU_SELF is the self reference in ring0's percpu.
-//
-// This is guaranteed to be zero.
-#define CPU_SELF 0x0
-
// Context offsets.
//
// Only limited use of the context is done in the assembly stub below, most is
@@ -44,7 +39,7 @@ begin:
LEAQ VCPU_CPU(AX), BX
BYTE CLI;
check_vcpu:
- MOVQ CPU_SELF(GS), CX
+ MOVQ ENTRY_CPU_SELF(GS), CX
CMPQ BX, CX
JE right_vCPU
wrong_vcpu:
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index ae813e24e..d46946402 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -156,15 +156,7 @@ func (*KVM) MaxUserAddress() usermem.Addr {
func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) {
// Allocate page tables and install system mappings.
pageTables := pagetables.New(newAllocator())
- applyPhysicalRegions(func(pr physicalRegion) bool {
- // Map the kernel in the upper half.
- pageTables.Map(
- usermem.Addr(ring0.KernelStartAddress|pr.virtual),
- pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess},
- pr.physical)
- return true // Keep iterating.
- })
+ k.machine.mapUpperHalf(pageTables)
// Return the new address space.
return &addressSpace{
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 372a4cbd7..75da253c5 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -155,7 +155,7 @@ func (m *machine) newVCPU() *vCPU {
fd: int(fd),
machine: m,
}
- c.CPU.Init(&m.kernel, c)
+ c.CPU.Init(&m.kernel, c.id, c)
m.vCPUsByID[c.id] = c
// Ensure the signal mask is correct.
@@ -183,9 +183,6 @@ func newMachine(vm int) (*machine, error) {
// Create the machine.
m := &machine{fd: vm}
m.available.L = &m.mu
- m.kernel.Init(ring0.KernelOpts{
- PageTables: pagetables.New(newAllocator()),
- })
// Pull the maximum vCPUs.
maxVCPUs, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS)
@@ -197,6 +194,9 @@ func newMachine(vm int) (*machine, error) {
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
m.vCPUsByTID = make(map[uint64]*vCPU)
m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
+ m.kernel.Init(ring0.KernelOpts{
+ PageTables: pagetables.New(newAllocator()),
+ }, m.maxVCPUs)
// Pull the maximum slots.
maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS)
@@ -219,15 +219,9 @@ func newMachine(vm int) (*machine, error) {
pagetables.MapOpts{AccessType: usermem.AnyAccess},
pr.physical)
- // And keep everything in the upper half.
- m.kernel.PageTables.Map(
- usermem.Addr(ring0.KernelStartAddress|pr.virtual),
- pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess},
- pr.physical)
-
return true // Keep iterating.
})
+ m.mapUpperHalf(m.kernel.PageTables)
var physicalRegionsReadOnly []physicalRegion
var physicalRegionsAvailable []physicalRegion
@@ -365,6 +359,11 @@ func (m *machine) Destroy() {
// Get gets an available vCPU.
//
// This will return with the OS thread locked.
+//
+// It is guaranteed that if any OS thread TID is in guest, m.vCPUs[TID] points
+// to the vCPU in which the OS thread TID is running. So if Get() returns with
+// the corrent context in guest, the vCPU of it must be the same as what
+// Get() returns.
func (m *machine) Get() *vCPU {
m.mu.RLock()
runtime.LockOSThread()
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index acc823ba6..6849ab113 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -144,6 +144,7 @@ func (c *vCPU) initArchState() error {
// Set the entrypoint for the kernel.
kernelUserRegs.RIP = uint64(reflect.ValueOf(ring0.Start).Pointer())
kernelUserRegs.RAX = uint64(reflect.ValueOf(&c.CPU).Pointer())
+ kernelUserRegs.RSP = c.StackTop()
kernelUserRegs.RFLAGS = ring0.KernelFlagsSet
// Set the system registers.
@@ -345,3 +346,43 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
func availableRegionsForSetMem() (phyRegions []physicalRegion) {
return physicalRegions
}
+
+var execRegions []region
+
+func init() {
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" {
+ return
+ }
+
+ if vr.accessType.Execute {
+ execRegions = append(execRegions, vr.region)
+ }
+ })
+}
+
+func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
+ for _, r := range execRegions {
+ physical, length, ok := translateToPhysical(r.virtual)
+ if !ok || length < r.length {
+ panic("impossilbe translation")
+ }
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|r.virtual),
+ r.length,
+ pagetables.MapOpts{AccessType: usermem.Execute},
+ physical)
+ }
+ for start, end := range m.kernel.EntryRegions() {
+ regionLen := end - start
+ physical, length, ok := translateToPhysical(start)
+ if !ok || length < regionLen {
+ panic("impossible translation")
+ }
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|start),
+ regionLen,
+ pagetables.MapOpts{AccessType: usermem.ReadWrite},
+ physical)
+ }
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 9db171af9..2df762991 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -19,6 +19,7 @@ package kvm
import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -48,6 +49,18 @@ const (
poolPCIDs = 8
)
+func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
+ applyPhysicalRegions(func(pr physicalRegion) bool {
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|pr.virtual),
+ pr.length,
+ pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pr.physical)
+
+ return true // Keep iterating.
+ })
+}
+
// Get all read-only physicalRegions.
func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
var rdonlyRegions []region
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 905712076..537419657 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -79,7 +79,7 @@ func (c *vCPU) initArchState() error {
}
// tcr_el1
- data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS
+ data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS | _TCR_A1
reg.id = _KVM_ARM64_REGS_TCR_EL1
if err := c.setOneRegister(&reg); err != nil {
return err
@@ -103,7 +103,7 @@ func (c *vCPU) initArchState() error {
c.SetTtbr0Kvm(uintptr(data))
// ttbr1_el1
- data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0)
+ data = c.machine.kernel.PageTables.TTBR1_EL1(false, 1)
reg.id = _KVM_ARM64_REGS_TTBR1_EL1
if err := c.setOneRegister(&reg); err != nil {
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index e04165fbf..fc43cc3c0 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -30,7 +30,6 @@ go_library(
"//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
- "//pkg/sentry/hostcpu",
"//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
diff --git a/pkg/sentry/platform/ptrace/filters.go b/pkg/sentry/platform/ptrace/filters.go
index 1e07cfd0d..b0970e356 100644
--- a/pkg/sentry/platform/ptrace/filters.go
+++ b/pkg/sentry/platform/ptrace/filters.go
@@ -24,10 +24,9 @@ import (
// SyscallFilters returns syscalls made exclusively by the ptrace platform.
func (*PTrace) SyscallFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
- unix.SYS_GETCPU: {},
- unix.SYS_SCHED_SETAFFINITY: {},
- syscall.SYS_PTRACE: {},
- syscall.SYS_TGKILL: {},
- syscall.SYS_WAIT4: {},
+ unix.SYS_GETCPU: {},
+ syscall.SYS_PTRACE: {},
+ syscall.SYS_TGKILL: {},
+ syscall.SYS_WAIT4: {},
}
}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index e1d54d8a2..812ab80ef 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -518,11 +518,6 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
}
defer c.interrupt.Disable()
- // Ensure that the CPU set is bound appropriately; this makes the
- // emulation below several times faster, presumably by avoiding
- // interprocessor wakeups and by simplifying the schedule.
- t.bind()
-
// Set registers.
if err := t.setRegs(regs); err != nil {
panic(fmt.Sprintf("ptrace set regs (%+v) failed: %v", regs, err))
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 84b699f0d..020bbda79 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -201,7 +201,7 @@ func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFActi
seccomp.RuleSet{
Rules: seccomp.SyscallRules{
syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
- {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+ {seccomp.EqualTo(linux.ARCH_SET_CPUID), seccomp.EqualTo(0)},
},
},
Action: linux.SECCOMP_RET_ALLOW,
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index 2ce528601..8548853da 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -80,9 +80,9 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
Rules: seccomp.SyscallRules{
syscall.SYS_CLONE: []seccomp.Rule{
// Allow creation of new subprocesses (used by the master).
- {seccomp.AllowValue(syscall.CLONE_FILES | syscall.SIGKILL)},
+ {seccomp.EqualTo(syscall.CLONE_FILES | syscall.SIGKILL)},
// Allow creation of new threads within a single address space (used by addresss spaces).
- {seccomp.AllowValue(
+ {seccomp.EqualTo(
syscall.CLONE_FILES |
syscall.CLONE_FS |
syscall.CLONE_SIGHAND |
@@ -97,14 +97,14 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// For the stub prctl dance (all).
syscall.SYS_PRCTL: []seccomp.Rule{
- {seccomp.AllowValue(syscall.PR_SET_PDEATHSIG), seccomp.AllowValue(syscall.SIGKILL)},
+ {seccomp.EqualTo(syscall.PR_SET_PDEATHSIG), seccomp.EqualTo(syscall.SIGKILL)},
},
syscall.SYS_GETPPID: {},
// For the stub to stop itself (all).
syscall.SYS_GETPID: {},
syscall.SYS_KILL: []seccomp.Rule{
- {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SIGSTOP)},
+ {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SIGSTOP)},
},
// Injected to support the address space operations.
@@ -115,7 +115,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
})
}
rules = appendArchSeccompRules(rules, defaultAction)
- instrs, err := seccomp.BuildProgram(rules, defaultAction)
+ instrs, err := seccomp.BuildProgram(rules, defaultAction, defaultAction)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
index 245b20722..533e45497 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
@@ -18,29 +18,12 @@
package ptrace
import (
- "sync/atomic"
"syscall"
"unsafe"
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/hostcpu"
- "gvisor.dev/gvisor/pkg/sync"
)
-// maskPool contains reusable CPU masks for setting affinity. Unfortunately,
-// runtime.NumCPU doesn't actually record the number of CPUs on the system, it
-// just records the number of CPUs available in the scheduler affinity set at
-// startup. This may a) change over time and b) gives a number far lower than
-// the maximum indexable CPU. To prevent lots of allocation in the hot path, we
-// use a pool to store large masks that we can reuse during bind.
-var maskPool = sync.Pool{
- New: func() interface{} {
- const maxCPUs = 1024 // Not a hard limit; see below.
- return make([]uintptr, maxCPUs/64)
- },
-}
-
// unmaskAllSignals unmasks all signals on the current thread.
//
//go:nosplit
@@ -49,47 +32,3 @@ func unmaskAllSignals() syscall.Errno {
_, _, errno := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_SETMASK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0)
return errno
}
-
-// setCPU sets the CPU affinity.
-func (t *thread) setCPU(cpu uint32) error {
- mask := maskPool.Get().([]uintptr)
- n := int(cpu / 64)
- v := uintptr(1 << uintptr(cpu%64))
- if n >= len(mask) {
- // See maskPool note above. We've actually exceeded the number
- // of available cores. Grow the mask and return it.
- mask = make([]uintptr, n+1)
- }
- mask[n] |= v
- if _, _, errno := syscall.RawSyscall(
- unix.SYS_SCHED_SETAFFINITY,
- uintptr(t.tid),
- uintptr(len(mask)*8),
- uintptr(unsafe.Pointer(&mask[0]))); errno != 0 {
- return errno
- }
- mask[n] &^= v
- maskPool.Put(mask)
- return nil
-}
-
-// bind attempts to ensure that the thread is on the same CPU as the current
-// thread. This provides no guarantees as it is fundamentally a racy operation:
-// CPU sets may change and we may be rescheduled in the middle of this
-// operation. As a result, no failures are reported.
-//
-// Precondition: the current runtime thread should be locked.
-func (t *thread) bind() {
- currentCPU := hostcpu.GetCPU()
-
- if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU {
- // Set the affinity on the thread and save the CPU for next
- // round; we don't expect CPUs to bounce around too frequently.
- //
- // (It's worth noting that we could move CPUs between this point
- // and when the tracee finishes executing. But that would be
- // roughly the status quo anyways -- we're just maximizing our
- // chances of colocation, not guaranteeing it.)
- t.setCPU(currentCPU)
- }
-}
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
index 9c6c2cf5c..f617519fa 100644
--- a/pkg/sentry/platform/ring0/defs_amd64.go
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -76,15 +76,42 @@ type KernelOpts struct {
type KernelArchState struct {
KernelOpts
+ // cpuEntries is array of kernelEntry for all cpus
+ cpuEntries []kernelEntry
+
// globalIDT is our set of interrupt gates.
- globalIDT idt64
+ globalIDT *idt64
}
-// CPUArchState contains CPU-specific arch state.
-type CPUArchState struct {
+// kernelEntry contains minimal CPU-specific arch state
+// that can be mapped at the upper of the address space.
+// Malicious APP might steal info from it via CPU bugs.
+type kernelEntry struct {
// stack is the stack used for interrupts on this CPU.
stack [256]byte
+ // scratch space for temporary usage.
+ scratch0 uint64
+ scratch1 uint64
+
+ // stackTop is the top of the stack.
+ stackTop uint64
+
+ // cpuSelf is back reference to CPU.
+ cpuSelf *CPU
+
+ // kernelCR3 is the cr3 used for sentry kernel.
+ kernelCR3 uintptr
+
+ // gdt is the CPU's descriptor table.
+ gdt descriptorTable
+
+ // tss is the CPU's task state.
+ tss TaskState64
+}
+
+// CPUArchState contains CPU-specific arch state.
+type CPUArchState struct {
// errorCode is the error code from the last exception.
errorCode uintptr
@@ -97,11 +124,7 @@ type CPUArchState struct {
// exception.
errorType uintptr
- // gdt is the CPU's descriptor table.
- gdt descriptorTable
-
- // tss is the CPU's task state.
- tss TaskState64
+ *kernelEntry
}
// ErrorCode returns the last error code.
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go
index 0e2ab716c..508236e46 100644
--- a/pkg/sentry/platform/ring0/defs_arm64.go
+++ b/pkg/sentry/platform/ring0/defs_arm64.go
@@ -77,6 +77,9 @@ type CPUArchState struct {
// lazyVFP is the value of cpacr_el1.
lazyVFP uintptr
+
+ // appASID is the asid value of guest application.
+ appASID uintptr
}
// ErrorCode returns the last error code.
diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go
index 7fa43c2f5..d87b1fd00 100644
--- a/pkg/sentry/platform/ring0/entry_amd64.go
+++ b/pkg/sentry/platform/ring0/entry_amd64.go
@@ -36,12 +36,15 @@ func sysenter()
// This must be called prior to sysret/iret.
func swapgs()
+// jumpToKernel jumps to the kernel version of the current RIP.
+func jumpToKernel()
+
// sysret returns to userspace from a system call.
//
// The return code is the vector that interrupted execution.
//
// See stubs.go for a note regarding the frame size of this function.
-func sysret(*CPU, *arch.Registers) Vector
+func sysret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector
// "iret is the cadillac of CPL switching."
//
@@ -50,7 +53,7 @@ func sysret(*CPU, *arch.Registers) Vector
// iret is nearly identical to sysret, except an iret is used to fully restore
// all user state. This must be called in cases where all registers need to be
// restored.
-func iret(*CPU, *arch.Registers) Vector
+func iret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector
// exception is the generic exception entry.
//
diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_amd64.s
index 02df38331..82689a194 100644
--- a/pkg/sentry/platform/ring0/entry_amd64.s
+++ b/pkg/sentry/platform/ring0/entry_amd64.s
@@ -63,6 +63,15 @@
MOVQ offset+PTRACE_RSI(reg), SI; \
MOVQ offset+PTRACE_RDI(reg), DI;
+// WRITE_CR3() writes the given CR3 value.
+//
+// The code corresponds to:
+//
+// mov %rax, %cr3
+//
+#define WRITE_CR3() \
+ BYTE $0x0f; BYTE $0x22; BYTE $0xd8;
+
// SWAP_GS swaps the kernel GS (CPU).
#define SWAP_GS() \
BYTE $0x0F; BYTE $0x01; BYTE $0xf8;
@@ -75,15 +84,9 @@
#define SYSRET64() \
BYTE $0x48; BYTE $0x0f; BYTE $0x07;
-// LOAD_KERNEL_ADDRESS loads a kernel address.
-#define LOAD_KERNEL_ADDRESS(from, to) \
- MOVQ from, to; \
- ORQ ·KernelStartAddress(SB), to;
-
// LOAD_KERNEL_STACK loads the kernel stack.
-#define LOAD_KERNEL_STACK(from) \
- LOAD_KERNEL_ADDRESS(CPU_SELF(from), SP); \
- LEAQ CPU_STACK_TOP(SP), SP;
+#define LOAD_KERNEL_STACK(entry) \
+ MOVQ ENTRY_STACK_TOP(entry), SP;
// See kernel.go.
TEXT ·Halt(SB),NOSPLIT,$0
@@ -95,58 +98,93 @@ TEXT ·swapgs(SB),NOSPLIT,$0
SWAP_GS()
RET
+// jumpToKernel changes execution to the kernel address space.
+//
+// This works by changing the return value to the kernel version.
+TEXT ·jumpToKernel(SB),NOSPLIT,$0
+ MOVQ 0(SP), AX
+ ORQ ·KernelStartAddress(SB), AX // Future return value.
+ MOVQ AX, 0(SP)
+ RET
+
// See entry_amd64.go.
TEXT ·sysret(SB),NOSPLIT,$0-24
- // Save original state.
- LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
- LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ CALL ·jumpToKernel(SB)
+ // Save original state and stack. sysenter() or exception()
+ // from APP(gr3) will switch to this stack, set the return
+ // value (vector: 32(SP)) and then do RET, which will also
+ // automatically return to the lower half.
+ MOVQ cpu+0(FP), BX
+ MOVQ regs+8(FP), AX
+ MOVQ userCR3+16(FP), CX
MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
+ // save SP AX userCR3 on the kernel stack.
+ MOVQ CPU_ENTRY(BX), BX
+ LOAD_KERNEL_STACK(BX)
+ PUSHQ PTRACE_RSP(AX)
+ PUSHQ PTRACE_RAX(AX)
+ PUSHQ CX
+
// Restore user register state.
REGISTERS_LOAD(AX, 0)
MOVQ PTRACE_RIP(AX), CX // Needed for SYSRET.
MOVQ PTRACE_FLAGS(AX), R11 // Needed for SYSRET.
- MOVQ PTRACE_RSP(AX), SP // Restore the stack directly.
- MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+
+ // restore userCR3, AX, SP.
+ POPQ AX // Get userCR3.
+ WRITE_CR3() // Switch to userCR3.
+ POPQ AX // Restore AX.
+ POPQ SP // Restore SP.
SYSRET64()
// See entry_amd64.go.
TEXT ·iret(SB),NOSPLIT,$0-24
- // Save original state.
- LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
- LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ CALL ·jumpToKernel(SB)
+ // Save original state and stack. sysenter() or exception()
+ // from APP(gr3) will switch to this stack, set the return
+ // value (vector: 32(SP)) and then do RET, which will also
+ // automatically return to the lower half.
+ MOVQ cpu+0(FP), BX
+ MOVQ regs+8(FP), AX
+ MOVQ userCR3+16(FP), CX
MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
// Build an IRET frame & restore state.
+ MOVQ CPU_ENTRY(BX), BX
LOAD_KERNEL_STACK(BX)
- MOVQ PTRACE_SS(AX), BX; PUSHQ BX
- MOVQ PTRACE_RSP(AX), CX; PUSHQ CX
- MOVQ PTRACE_FLAGS(AX), DX; PUSHQ DX
- MOVQ PTRACE_CS(AX), DI; PUSHQ DI
- MOVQ PTRACE_RIP(AX), SI; PUSHQ SI
- REGISTERS_LOAD(AX, 0) // Restore most registers.
- MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+ PUSHQ PTRACE_SS(AX)
+ PUSHQ PTRACE_RSP(AX)
+ PUSHQ PTRACE_FLAGS(AX)
+ PUSHQ PTRACE_CS(AX)
+ PUSHQ PTRACE_RIP(AX)
+ PUSHQ PTRACE_RAX(AX) // Save AX on kernel stack.
+ PUSHQ CX // Save userCR3 on kernel stack.
+ REGISTERS_LOAD(AX, 0) // Restore most registers.
+ POPQ AX // Get userCR3.
+ WRITE_CR3() // Switch to userCR3.
+ POPQ AX // Restore AX.
IRET()
// See entry_amd64.go.
TEXT ·resume(SB),NOSPLIT,$0
// See iret, above.
- MOVQ CPU_REGISTERS+PTRACE_SS(GS), BX; PUSHQ BX
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), CX; PUSHQ CX
- MOVQ CPU_REGISTERS+PTRACE_FLAGS(GS), DX; PUSHQ DX
- MOVQ CPU_REGISTERS+PTRACE_CS(GS), DI; PUSHQ DI
- MOVQ CPU_REGISTERS+PTRACE_RIP(GS), SI; PUSHQ SI
- REGISTERS_LOAD(GS, CPU_REGISTERS)
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), AX
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ PUSHQ CPU_REGISTERS+PTRACE_SS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_RSP(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_FLAGS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_CS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_RIP(AX)
+ REGISTERS_LOAD(AX, CPU_REGISTERS)
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX
IRET()
// See entry_amd64.go.
TEXT ·Start(SB),NOSPLIT,$0
- LOAD_KERNEL_STACK(AX) // Set the stack.
PUSHQ $0x0 // Previous frame pointer.
MOVQ SP, BP // Set frame pointer.
PUSHQ AX // First argument (CPU).
@@ -164,44 +202,54 @@ TEXT ·sysenter(SB),NOSPLIT,$0
user:
SWAP_GS()
- XCHGQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Swap stacks.
- XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs).
+ MOVQ AX, ENTRY_SCRATCH0(GS) // Save user AX on scratch.
+ MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX.
+ WRITE_CR3() // Switch to kernel cr3.
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs.
REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value.
- MOVQ BX, PTRACE_RAX(AX) // Save everything else.
- MOVQ BX, PTRACE_ORIGRAX(AX)
MOVQ CX, PTRACE_RIP(AX)
MOVQ R11, PTRACE_FLAGS(AX)
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), BX; MOVQ BX, PTRACE_RSP(AX)
- MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
- MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
+ MOVQ SP, PTRACE_RSP(AX)
+ MOVQ ENTRY_SCRATCH0(GS), CX // Load saved user AX value.
+ MOVQ CX, PTRACE_RAX(AX) // Save everything else.
+ MOVQ CX, PTRACE_ORIGRAX(AX)
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Get stacks.
+ MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code.
+ MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user.
// Return to the kernel, where the frame is:
//
- // vector (sp+24)
+ // vector (sp+32)
+ // userCR3 (sp+24)
// regs (sp+16)
// cpu (sp+8)
// vcpu.Switch (sp+0)
//
- MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
- MOVQ $Syscall, 24(SP) // Output vector.
+ MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer.
+ MOVQ $Syscall, 32(SP) // Output vector.
RET
kernel:
// We can't restore the original stack, but we can access the registers
// in the CPU state directly. No need for temporary juggling.
- MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
- MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
- REGISTERS_SAVE(GS, CPU_REGISTERS)
- MOVQ CX, CPU_REGISTERS+PTRACE_RIP(GS)
- MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(GS)
- MOVQ SP, CPU_REGISTERS+PTRACE_RSP(GS)
- MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
- MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+ MOVQ AX, ENTRY_SCRATCH0(GS)
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ REGISTERS_SAVE(AX, CPU_REGISTERS)
+ MOVQ CX, CPU_REGISTERS+PTRACE_RIP(AX)
+ MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(AX)
+ MOVQ SP, CPU_REGISTERS+PTRACE_RSP(AX)
+ MOVQ ENTRY_SCRATCH0(GS), BX
+ MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX)
+ MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX)
+ MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code.
+ MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
// Call the syscall trampoline.
LOAD_KERNEL_STACK(GS)
- MOVQ CPU_SELF(GS), AX // Load vCPU.
PUSHQ AX // First argument (vCPU).
CALL ·kernelSyscall(SB) // Call the trampoline.
POPQ AX // Pop vCPU.
@@ -237,9 +285,14 @@ user:
SWAP_GS()
ADDQ $-8, SP // Adjust for flags.
MOVQ $_KERNEL_FLAGS, 0(SP); BYTE $0x9d; // Reset flags (POPFQ).
- XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for user regs.
+ PUSHQ AX // Save user AX on stack.
+ MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX.
+ WRITE_CR3() // Switch to kernel cr3.
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs.
REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Restore original AX.
+ POPQ BX // Restore original AX.
MOVQ BX, PTRACE_RAX(AX) // Save it.
MOVQ BX, PTRACE_ORIGRAX(AX)
MOVQ 16(SP), BX; MOVQ BX, PTRACE_RIP(AX)
@@ -249,34 +302,36 @@ user:
MOVQ 48(SP), SI; MOVQ SI, PTRACE_SS(AX)
// Copy out and return.
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
MOVQ 0(SP), BX // Load vector.
MOVQ 8(SP), CX // Load error code.
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Original stack (kernel version).
- MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
- MOVQ CX, CPU_ERROR_CODE(GS) // Set error code.
- MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
- MOVQ BX, 24(SP) // Output vector.
+ MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Original stack (kernel version).
+ MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer.
+ MOVQ CX, CPU_ERROR_CODE(AX) // Set error code.
+ MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user.
+ MOVQ BX, 32(SP) // Output vector.
RET
kernel:
// As per above, we can save directly.
- MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
- MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
- REGISTERS_SAVE(GS, CPU_REGISTERS)
- MOVQ 16(SP), AX; MOVQ AX, CPU_REGISTERS+PTRACE_RIP(GS)
- MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(GS)
- MOVQ 40(SP), CX; MOVQ CX, CPU_REGISTERS+PTRACE_RSP(GS)
+ PUSHQ AX
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ REGISTERS_SAVE(AX, CPU_REGISTERS)
+ POPQ BX
+ MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX)
+ MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX)
+ MOVQ 16(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RIP(AX)
+ MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(AX)
+ MOVQ 40(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RSP(AX)
// Set the error code and adjust the stack.
- MOVQ 8(SP), AX // Load the error code.
- MOVQ AX, CPU_ERROR_CODE(GS) // Copy out to the CPU.
- MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+ MOVQ 8(SP), BX // Load the error code.
+ MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU.
+ MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
MOVQ 0(SP), BX // BX contains the vector.
- ADDQ $48, SP // Drop the exception frame.
// Call the exception trampoline.
LOAD_KERNEL_STACK(GS)
- MOVQ CPU_SELF(GS), AX // Load vCPU.
PUSHQ BX // Second argument (vector).
PUSHQ AX // First argument (vCPU).
CALL ·kernelException(SB) // Call the trampoline.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 9d29b7168..5f63cbd45 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -27,7 +27,9 @@
// ERET returns using the ELR and SPSR for the current exception level.
#define ERET() \
- WORD $0xd69f03e0
+ WORD $0xd69f03e0; \
+ DSB $7; \
+ ISB $15;
// RSV_REG is a register that holds el1 information temporarily.
#define RSV_REG R18_PLATFORM
@@ -300,17 +302,23 @@
// SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application.
#define SWITCH_TO_APP_PAGETABLE(from) \
- MOVD CPU_TTBR0_APP(from), RSV_REG; \
- WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
+ MRS TTBR1_EL1, R0; \
+ MOVD CPU_APP_ASID(from), R1; \
+ BFI $48, R1, $16, R0; \
+ MSR R0, TTBR1_EL1; \ // set the ASID in TTBR1_EL1 (since TCR.A1 is set)
ISB $15; \
- DSB $15;
+ MOVD CPU_TTBR0_APP(from), RSV_REG; \
+ MSR RSV_REG, TTBR0_EL1;
// SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable.
#define SWITCH_TO_KVM_PAGETABLE(from) \
- MOVD CPU_TTBR0_KVM(from), RSV_REG; \
- WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
+ MRS TTBR1_EL1, R0; \
+ MOVD $1, R1; \
+ BFI $48, R1, $16, R0; \
+ MSR R0, TTBR1_EL1; \
ISB $15; \
- DSB $15;
+ MOVD CPU_TTBR0_KVM(from), RSV_REG; \
+ MSR RSV_REG, TTBR0_EL1;
#define VFP_ENABLE \
MOVD $FPEN_ENABLE, R0; \
@@ -326,23 +334,20 @@
#define KERNEL_ENTRY_FROM_EL0 \
SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack.
STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \
- WORD $0xd538d092; \ //MRS TPIDR_EL1, R18, step2, switch user pagetable.
- SWITCH_TO_KVM_PAGETABLE(RSV_REG); \
- WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
- MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer.
- REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context.
+ WORD $0xd538d092; \ // MRS TPIDR_EL1, R18
+ MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step2, load app context pointer.
+ REGISTERS_SAVE(RSV_REG_APP, 0); \ // step3, save app context.
MOVD RSV_REG_APP, R20; \
LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \
ADD $16, RSP, RSP; \
MOVD RSV_REG, PTRACE_R18(R20); \
MOVD RSV_REG_APP, PTRACE_R9(R20); \
- MOVD R20, RSV_REG_APP; \
WORD $0xd5384003; \ // MRS SPSR_EL1, R3
- MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \
+ MOVD R3, PTRACE_PSTATE(R20); \
MRS ELR_EL1, R3; \
- MOVD R3, PTRACE_PC(RSV_REG_APP); \
+ MOVD R3, PTRACE_PC(R20); \
WORD $0xd5384103; \ // MRS SP_EL0, R3
- MOVD R3, PTRACE_SP(RSV_REG_APP);
+ MOVD R3, PTRACE_SP(R20);
// KERNEL_ENTRY_FROM_EL1 is the entry code of the vcpu from el1 to el1.
#define KERNEL_ENTRY_FROM_EL1 \
@@ -357,6 +362,13 @@
MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
+// storeAppASID writes the application's asid value.
+TEXT ·storeAppASID(SB),NOSPLIT,$0-8
+ MOVD asid+0(FP), R1
+ MRS TPIDR_EL1, RSV_REG
+ MOVD R1, CPU_APP_ASID(RSV_REG)
+ RET
+
// Halt halts execution.
TEXT ·Halt(SB),NOSPLIT,$0
// Clear bluepill.
@@ -414,7 +426,7 @@ TEXT ·Current(SB),NOSPLIT,$0-8
MOVD R8, ret+0(FP)
RET
-#define STACK_FRAME_SIZE 16
+#define STACK_FRAME_SIZE 32
// kernelExitToEl0 is the entrypoint for application in guest_el0.
// Prepare the vcpu environment for container application.
@@ -458,15 +470,16 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
SUB $STACK_FRAME_SIZE, RSP, RSP
STP (RSV_REG, RSV_REG_APP), 16*0(RSP)
+ STP (R0, R1), 16*1(RSP)
WORD $0xd538d092 //MRS TPIDR_EL1, R18
SWITCH_TO_APP_PAGETABLE(RSV_REG)
+ LDP 16*1(RSP), (R0, R1)
LDP 16*0(RSP), (RSV_REG, RSV_REG_APP)
ADD $STACK_FRAME_SIZE, RSP, RSP
- ISB $15
ERET()
// kernelExitToEl1 is the entrypoint for sentry in guest_el1.
@@ -482,6 +495,9 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1
MOVD R1, RSP
+ SWITCH_TO_KVM_PAGETABLE(RSV_REG)
+ MRS TPIDR_EL1, RSV_REG
+
REGISTERS_LOAD(RSV_REG, CPU_REGISTERS)
MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index 549f3d228..9742308d8 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -24,7 +24,10 @@ go_binary(
"defs_impl_arm64.go",
"main.go",
],
- visibility = ["//pkg/sentry/platform/ring0:__pkg__"],
+ visibility = [
+ "//pkg/sentry/platform/kvm:__pkg__",
+ "//pkg/sentry/platform/ring0:__pkg__",
+ ],
deps = [
"//pkg/cpuid",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
index 021693791..264be23d3 100644
--- a/pkg/sentry/platform/ring0/kernel.go
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -19,8 +19,8 @@ package ring0
// N.B. that constraints on KernelOpts must be satisfied.
//
//go:nosplit
-func (k *Kernel) Init(opts KernelOpts) {
- k.init(opts)
+func (k *Kernel) Init(opts KernelOpts, maxCPUs int) {
+ k.init(opts, maxCPUs)
}
// Halt halts execution.
@@ -49,6 +49,11 @@ func (defaultHooks) KernelException(Vector) {
// kernelSyscall is a trampoline.
//
+// When in amd64, it is called with %rip on the upper half, so it can
+// NOT access to any global data which is not mapped on upper and must
+// call to function pointers or interfaces to switch to the lower half
+// so that callee can access to global data.
+//
// +checkescape:hard,stack
//
//go:nosplit
@@ -58,6 +63,11 @@ func kernelSyscall(c *CPU) {
// kernelException is a trampoline.
//
+// When in amd64, it is called with %rip on the upper half, so it can
+// NOT access to any global data which is not mapped on upper and must
+// call to function pointers or interfaces to switch to the lower half
+// so that callee can access to global data.
+//
// +checkescape:hard,stack
//
//go:nosplit
@@ -68,10 +78,10 @@ func kernelException(c *CPU, vector Vector) {
// Init initializes a new CPU.
//
// Init allows embedding in other objects.
-func (c *CPU) Init(k *Kernel, hooks Hooks) {
- c.self = c // Set self reference.
- c.kernel = k // Set kernel reference.
- c.init() // Perform architectural init.
+func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) {
+ c.self = c // Set self reference.
+ c.kernel = k // Set kernel reference.
+ c.init(cpuID) // Perform architectural init.
// Require hooks.
if hooks != nil {
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
index d37981dbf..3a9dff4cc 100644
--- a/pkg/sentry/platform/ring0/kernel_amd64.go
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -18,13 +18,42 @@ package ring0
import (
"encoding/binary"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/usermem"
)
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts) {
+func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
// Save the root page tables.
k.PageTables = opts.PageTables
+ entrySize := reflect.TypeOf(kernelEntry{}).Size()
+ var (
+ entries []kernelEntry
+ padding = 1
+ )
+ for {
+ entries = make([]kernelEntry, maxCPUs+padding-1)
+ totalSize := entrySize * uintptr(maxCPUs+padding-1)
+ addr := reflect.ValueOf(&entries[0]).Pointer()
+ if addr&(usermem.PageSize-1) == 0 && totalSize >= usermem.PageSize {
+ // The runtime forces power-of-2 alignment for allocations, and we are therefore
+ // safe once the first address is aligned and the chunk is at least a full page.
+ break
+ }
+ padding = padding << 1
+ }
+ k.cpuEntries = entries
+
+ k.globalIDT = &idt64{}
+ if reflect.TypeOf(idt64{}).Size() != usermem.PageSize {
+ panic("Size of globalIDT should be PageSize")
+ }
+ if reflect.ValueOf(k.globalIDT).Pointer()&(usermem.PageSize-1) != 0 {
+ panic("Allocated globalIDT should be page aligned")
+ }
+
// Setup the IDT, which is uniform.
for v, handler := range handlers {
// Allow Breakpoint and Overflow to be called from all
@@ -39,8 +68,26 @@ func (k *Kernel) init(opts KernelOpts) {
}
}
+func (k *Kernel) EntryRegions() map[uintptr]uintptr {
+ regions := make(map[uintptr]uintptr)
+
+ addr := reflect.ValueOf(&k.cpuEntries[0]).Pointer()
+ size := reflect.TypeOf(kernelEntry{}).Size() * uintptr(len(k.cpuEntries))
+ end, _ := usermem.Addr(addr + size).RoundUp()
+ regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end)
+
+ addr = reflect.ValueOf(k.globalIDT).Pointer()
+ size = reflect.TypeOf(idt64{}).Size()
+ end, _ = usermem.Addr(addr + size).RoundUp()
+ regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end)
+
+ return regions
+}
+
// init initializes architecture-specific state.
-func (c *CPU) init() {
+func (c *CPU) init(cpuID int) {
+ c.kernelEntry = &c.kernel.cpuEntries[cpuID]
+ c.cpuSelf = c
// Null segment.
c.gdt[0].setNull()
@@ -65,6 +112,7 @@ func (c *CPU) init() {
// Set the kernel stack pointer in the TSS (virtual address).
stackAddr := c.StackTop()
+ c.stackTop = stackAddr
c.tss.rsp0Lo = uint32(stackAddr)
c.tss.rsp0Hi = uint32(stackAddr >> 32)
c.tss.ist1Lo = uint32(stackAddr)
@@ -183,7 +231,7 @@ func IsCanonical(addr uint64) bool {
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID)
- kernelCR3 := c.kernel.PageTables.CR3(true, switchOpts.KernelPCID)
+ c.kernelCR3 = uintptr(c.kernel.PageTables.CR3(true, switchOpts.KernelPCID))
// Sanitize registers.
regs := switchOpts.Registers
@@ -197,15 +245,11 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point.
- jumpToKernel() // Switch to upper half.
- writeCR3(uintptr(userCR3)) // Change to user address space.
if switchOpts.FullRestore {
- vector = iret(c, regs)
+ vector = iret(c, regs, uintptr(userCR3))
} else {
- vector = sysret(c, regs)
+ vector = sysret(c, regs, uintptr(userCR3))
}
- writeCR3(uintptr(kernelCR3)) // Return to kernel address space.
- jumpToUser() // Return to lower half.
SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point.
WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
return
@@ -219,7 +263,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
//go:nosplit
func start(c *CPU) {
// Save per-cpu & FS segment.
- WriteGS(kernelAddr(c))
+ WriteGS(kernelAddr(c.kernelEntry))
WriteFS(uintptr(c.registers.Fs_base))
// Initialize floating point.
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index d0afa1aaa..0ca98a7c7 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -25,13 +25,13 @@ func HaltAndResume()
func HaltEl1SvcAndResume()
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts) {
+func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
// Save the root page tables.
k.PageTables = opts.PageTables
}
// init initializes architecture-specific state.
-func (c *CPU) init() {
+func (c *CPU) init(cpuID int) {
// Set the kernel stack pointer(virtual address).
c.registers.Sp = uint64(c.StackTop())
@@ -53,6 +53,11 @@ func IsCanonical(addr uint64) bool {
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
+ storeAppASID(uintptr(switchOpts.UserASID))
+ if switchOpts.Flush {
+ FlushTlbAll()
+ }
+
regs := switchOpts.Registers
regs.Pstate &= ^uint64(PsrFlagsClear)
diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/sentry/platform/ring0/lib_amd64.go
index ca968a036..0ec5c3bc5 100644
--- a/pkg/sentry/platform/ring0/lib_amd64.go
+++ b/pkg/sentry/platform/ring0/lib_amd64.go
@@ -61,21 +61,9 @@ func wrgsbase(addr uintptr)
// wrgsmsr writes to the GS_BASE MSR.
func wrgsmsr(addr uintptr)
-// writeCR3 writes the CR3 value.
-func writeCR3(phys uintptr)
-
-// readCR3 reads the current CR3 value.
-func readCR3() uintptr
-
// readCR2 reads the current CR2 value.
func readCR2() uintptr
-// jumpToKernel jumps to the kernel version of the current RIP.
-func jumpToKernel()
-
-// jumpToUser jumps to the user version of the current RIP.
-func jumpToUser()
-
// fninit initializes the floating point unit.
func fninit()
diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/sentry/platform/ring0/lib_amd64.s
index 75d742750..2fe83568a 100644
--- a/pkg/sentry/platform/ring0/lib_amd64.s
+++ b/pkg/sentry/platform/ring0/lib_amd64.s
@@ -127,53 +127,6 @@ TEXT ·wrgsmsr(SB),NOSPLIT,$0-8
BYTE $0x0f; BYTE $0x30; // WRMSR
RET
-// jumpToUser changes execution to the user address.
-//
-// This works by changing the return value to the user version.
-TEXT ·jumpToUser(SB),NOSPLIT,$0
- MOVQ 0(SP), AX
- MOVQ ·KernelStartAddress(SB), BX
- NOTQ BX
- ANDQ BX, SP // Switch the stack.
- ANDQ BX, BP // Switch the frame pointer.
- ANDQ BX, AX // Future return value.
- MOVQ AX, 0(SP)
- RET
-
-// jumpToKernel changes execution to the kernel address space.
-//
-// This works by changing the return value to the kernel version.
-TEXT ·jumpToKernel(SB),NOSPLIT,$0
- MOVQ 0(SP), AX
- MOVQ ·KernelStartAddress(SB), BX
- ORQ BX, SP // Switch the stack.
- ORQ BX, BP // Switch the frame pointer.
- ORQ BX, AX // Future return value.
- MOVQ AX, 0(SP)
- RET
-
-// writeCR3 writes the given CR3 value.
-//
-// The code corresponds to:
-//
-// mov %rax, %cr3
-//
-TEXT ·writeCR3(SB),NOSPLIT,$0-8
- MOVQ cr3+0(FP), AX
- BYTE $0x0f; BYTE $0x22; BYTE $0xd8;
- RET
-
-// readCR3 reads the current CR3 value.
-//
-// The code corresponds to:
-//
-// mov %cr3, %rax
-//
-TEXT ·readCR3(SB),NOSPLIT,$0-8
- BYTE $0x0f; BYTE $0x20; BYTE $0xd8;
- MOVQ AX, ret+0(FP)
- RET
-
// readCR2 reads the current CR2 value.
//
// The code corresponds to:
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index 00e52c8af..2f1abcb0f 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -16,6 +16,15 @@
package ring0
+// storeAppASID writes the application's asid value.
+func storeAppASID(asid uintptr)
+
+// LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU.
+func LocalFlushTlbAll()
+
+// FlushTlbAll flush all tlb.
+func FlushTlbAll()
+
// CPACREL1 returns the value of the CPACR_EL1 register.
func CPACREL1() (value uintptr)
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
index 86bfbe46f..8aabf7d0e 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.s
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -15,6 +15,20 @@
#include "funcdata.h"
#include "textflag.h"
+TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0
+ DSB $6 // dsb(nshst)
+ WORD $0xd508871f // __tlbi(vmalle1)
+ DSB $7 // dsb(nsh)
+ ISB $15
+ RET
+
+TEXT ·FlushTlbAll(SB),NOSPLIT,$0
+ DSB $10 // dsb(ishst)
+ WORD $0xd508831f // __tlbi(vmalle1is)
+ DSB $11 // dsb(ish)
+ ISB $15
+ RET
+
TEXT ·GetTLS(SB),NOSPLIT,$0-8
MRS TPIDR_EL0, R1
MOVD R1, ret+0(FP)
diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go
index b8ab120a0..bcb73cb31 100644
--- a/pkg/sentry/platform/ring0/offsets_amd64.go
+++ b/pkg/sentry/platform/ring0/offsets_amd64.go
@@ -30,11 +30,18 @@ func Emit(w io.Writer) {
c := &CPU{}
fmt.Fprintf(w, "\n// CPU offsets.\n")
- fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer())
- fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack)))
fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer())
+
+ e := &kernelEntry{}
+ fmt.Fprintf(w, "\n// CPU entry offsets.\n")
+ fmt.Fprintf(w, "#define ENTRY_SCRATCH0 0x%02x\n", reflect.ValueOf(&e.scratch0).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_SCRATCH1 0x%02x\n", reflect.ValueOf(&e.scratch1).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_STACK_TOP 0x%02x\n", reflect.ValueOf(&e.stackTop).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_CPU_SELF 0x%02x\n", reflect.ValueOf(&e.cpuSelf).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_KERNEL_CR3 0x%02x\n", reflect.ValueOf(&e.kernelCR3).Pointer()-reflect.ValueOf(e).Pointer())
fmt.Fprintf(w, "\n// Bits.\n")
fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF)
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index f3de962f0..1d86b4bcf 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -41,6 +41,7 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_APP_ASID 0x%02x\n", reflect.ValueOf(&c.appASID).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "\n// Bits.\n")
fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go
index 9da0ea685..e99da0b35 100644
--- a/pkg/sentry/platform/ring0/x86.go
+++ b/pkg/sentry/platform/ring0/x86.go
@@ -104,7 +104,7 @@ const (
VirtualizationException
SecurityException = 0x1e
SyscallInt80 = 0x80
- _NR_INTERRUPTS = SyscallInt80 + 1
+ _NR_INTERRUPTS = 0x100
)
// System call vectors.
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index c0fd3425b..a3f775d15 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/context",
+ "//pkg/marshal",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
@@ -20,6 +21,5 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index e76e498de..b6ebe29d6 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -21,6 +21,8 @@ go_library(
"//pkg/context",
"//pkg/fdnotifier",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
@@ -37,11 +39,12 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 242e6bf76..7d3c4a01c 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -24,6 +24,8 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -36,8 +38,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const (
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 8a1d52ebf..163af329b 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -52,6 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{})
func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) {
mnt := t.Kernel().SocketMount()
d := sockfs.NewDentry(t.Credentials(), mnt)
+ defer d.DecRef(t)
s := &socketVFS2{
socketOpsCommon: socketOpsCommon{
@@ -77,6 +78,13 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in
return vfsfd, nil
}
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *socketVFS2) Release(ctx context.Context) {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.socketOpsCommon.Release(ctx)
+}
+
// Readiness implements waiter.Waitable.Readiness.
func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.socketOpsCommon.Readiness(mask)
@@ -97,11 +105,6 @@ func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal
return ioctl(ctx, s.fd, uio, args)
}
-// Allocate implements vfs.FileDescriptionImpl.Allocate.
-func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error {
- return syserror.ENODEV
-}
-
// PRead implements vfs.FileDescriptionImpl.PRead.
func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
return 0, syserror.ESPIPE
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index fda3dcb35..faa61160e 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -30,6 +30,9 @@ import (
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -59,6 +62,8 @@ type Stack struct {
tcpSACKEnabled bool
netDevFile *os.File
netSNMPFile *os.File
+ ipv4Forwarding bool
+ ipv6Forwarding bool
}
// NewStack returns an empty Stack containing no configuration.
@@ -118,6 +123,13 @@ func (s *Stack) Configure() error {
s.netSNMPFile = f
}
+ s.ipv6Forwarding = false
+ if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv6/conf/all/forwarding"); err == nil {
+ s.ipv6Forwarding = strings.TrimSpace(string(ipForwarding)) != "0"
+ } else {
+ log.Warningf("Failed to read if ipv6 forwarding is enabled, setting to false")
+ }
+
return nil
}
@@ -468,3 +480,21 @@ func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
+
+// Forwarding implements inet.Stack.Forwarding.
+func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
+ switch protocol {
+ case ipv4.ProtocolNumber:
+ return s.ipv4Forwarding
+ case ipv6.ProtocolNumber:
+ return s.ipv6Forwarding
+ default:
+ log.Warningf("Forwarding(%v) failed: unsupported protocol", protocol)
+ return false
+ }
+}
+
+// SetForwarding implements inet.Stack.SetForwarding.
+func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
+ return syserror.EACCES
+}
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 0336a32d8..549787955 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -19,6 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -37,7 +39,7 @@ type matchMaker interface {
// name is the matcher name as stored in the xt_entry_match struct.
name() string
- // marshal converts from an stack.Matcher to an ABI struct.
+ // marshal converts from a stack.Matcher to an ABI struct.
marshal(matcher stack.Matcher) []byte
// unmarshal converts from the ABI matcher struct to an
@@ -93,3 +95,71 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf
}
return matchMaker.unmarshal(buf, filter)
}
+
+// targetMaker knows how to (un)marshal a target. Once registered,
+// marshalTarget and unmarshalTarget can be used.
+type targetMaker interface {
+ // id uniquely identifies the target.
+ id() stack.TargetID
+
+ // marshal converts from a stack.Target to an ABI struct.
+ marshal(target stack.Target) []byte
+
+ // unmarshal converts from the ABI matcher struct to a stack.Target.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error)
+}
+
+// targetMakers maps the TargetID of supported targets to the targetMaker that
+// marshals and unmarshals it. It is immutable after package initialization.
+var targetMakers = map[stack.TargetID]targetMaker{}
+
+func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) {
+ tid := stack.TargetID{
+ Name: name,
+ NetworkProtocol: netProto,
+ Revision: rev,
+ }
+ if _, ok := targetMakers[tid]; !ok {
+ return 0, false
+ }
+
+ // Return the highest supported revision unless rev is higher.
+ for _, other := range targetMakers {
+ otherID := other.id()
+ if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev {
+ rev = uint8(otherID.Revision)
+ }
+ }
+ return rev, true
+}
+
+// registerTargetMaker should be called by target extensions to register them
+// with the netfilter package.
+func registerTargetMaker(tm targetMaker) {
+ if _, ok := targetMakers[tm.id()]; ok {
+ panic(fmt.Sprintf("multiple targets registered with name %q.", tm.id()))
+ }
+ targetMakers[tm.id()] = tm
+}
+
+func marshalTarget(target stack.Target) []byte {
+ targetMaker, ok := targetMakers[target.ID()]
+ if !ok {
+ panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID()))
+ }
+ return targetMaker.marshal(target)
+}
+
+func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) {
+ tid := stack.TargetID{
+ Name: target.Name.String(),
+ NetworkProtocol: filter.NetworkProtocol(),
+ Revision: target.Revision,
+ }
+ targetMaker, ok := targetMakers[tid]
+ if !ok {
+ nflog("unsupported target with name %q", target.Name.String())
+ return nil, syserr.ErrInvalidArgument
+ }
+ return targetMaker.unmarshal(buf, filter)
+}
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index e4c55a100..b560fae0d 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -181,18 +181,23 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- target, err := parseTarget(filter, optVal[:targetSize])
- if err != nil {
- nflog("failed to parse target: %v", err)
- return nil, syserr.ErrInvalidArgument
- }
- optVal = optVal[targetSize:]
- table.Rules = append(table.Rules, stack.Rule{
+ rule := stack.Rule{
Filter: filter,
- Target: target,
Matchers: matchers,
- })
+ }
+
+ {
+ target, err := parseTarget(filter, optVal[:targetSize], false /* ipv6 */)
+ if err != nil {
+ nflog("failed to parse target: %v", err)
+ return nil, err
+ }
+ rule.Target = target
+ }
+ optVal = optVal[targetSize:]
+
+ table.Rules = append(table.Rules, rule)
offsets[offset] = int(entryIdx)
offset += uint32(entry.NextOffset)
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 3b2c1becd..4253f7bf4 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -184,18 +184,23 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- target, err := parseTarget(filter, optVal[:targetSize])
- if err != nil {
- nflog("failed to parse target: %v", err)
- return nil, syserr.ErrInvalidArgument
- }
- optVal = optVal[targetSize:]
- table.Rules = append(table.Rules, stack.Rule{
+ rule := stack.Rule{
Filter: filter,
- Target: target,
Matchers: matchers,
- })
+ }
+
+ {
+ target, err := parseTarget(filter, optVal[:targetSize], true /* ipv6 */)
+ if err != nil {
+ nflog("failed to parse target: %v", err)
+ return nil, err
+ }
+ rule.Target = target
+ }
+ optVal = optVal[targetSize:]
+
+ table.Rules = append(table.Rules, rule)
offsets[offset] = int(entryIdx)
offset += uint32(entry.NextOffset)
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 3e1735079..904a12e38 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -195,7 +196,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// Check the user chains.
for ruleIdx, rule := range table.Rules {
- if _, ok := rule.Target.(stack.UserChainTarget); !ok {
+ if _, ok := rule.Target.(*stack.UserChainTarget); !ok {
continue
}
@@ -216,7 +217,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// Set each jump to point to the appropriate rule. Right now they hold byte
// offsets.
for ruleIdx, rule := range table.Rules {
- jump, ok := rule.Target.(JumpTarget)
+ jump, ok := rule.Target.(*JumpTarget)
if !ok {
continue
}
@@ -307,7 +308,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool {
return false
}
switch rule.Target.(type) {
- case stack.AcceptTarget, stack.DropTarget:
+ case *stack.AcceptTarget, *stack.DropTarget:
return true
default:
return false
@@ -318,7 +319,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool {
if !validUnderflow(rule, ipv6) {
return false
}
- _, ok := rule.Target.(stack.AcceptTarget)
+ _, ok := rule.Target.(*stack.AcceptTarget)
return ok
}
@@ -337,3 +338,20 @@ func hookFromLinux(hook int) stack.Hook {
}
panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
}
+
+// TargetRevision returns a linux.XTGetRevision for a given target. It sets
+// Revision to the highest supported value, unless the provided revision number
+// is larger.
+func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkProtocolNumber) (linux.XTGetRevision, *syserr.Error) {
+ // Read in the target name and version.
+ var rev linux.XTGetRevision
+ if _, err := rev.CopyIn(t, revPtr); err != nil {
+ return linux.XTGetRevision{}, syserr.FromError(err)
+ }
+ maxSupported, ok := targetRevision(rev.Name.String(), netProto, rev.Revision)
+ if !ok {
+ return linux.XTGetRevision{}, syserr.ErrProtocolNotSupported
+ }
+ rev.Revision = maxSupported
+ return rev, nil
+}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 87e41abd8..0e14447fe 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -15,255 +15,357 @@
package netfilter
import (
- "errors"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
-// errorTargetName is used to mark targets as error targets. Error targets
-// shouldn't be reached - an error has occurred if we fall through to one.
-const errorTargetName = "ERROR"
+func init() {
+ // Standard targets include ACCEPT, DROP, RETURN, and JUMP.
+ registerTargetMaker(&standardTargetMaker{
+ NetworkProtocol: header.IPv4ProtocolNumber,
+ })
+ registerTargetMaker(&standardTargetMaker{
+ NetworkProtocol: header.IPv6ProtocolNumber,
+ })
+
+ // Both user chains and actual errors are represented in iptables by
+ // error targets.
+ registerTargetMaker(&errorTargetMaker{
+ NetworkProtocol: header.IPv4ProtocolNumber,
+ })
+ registerTargetMaker(&errorTargetMaker{
+ NetworkProtocol: header.IPv6ProtocolNumber,
+ })
+
+ registerTargetMaker(&redirectTargetMaker{
+ NetworkProtocol: header.IPv4ProtocolNumber,
+ })
+ registerTargetMaker(&nfNATTargetMaker{
+ NetworkProtocol: header.IPv6ProtocolNumber,
+ })
+}
-// redirectTargetName is used to mark targets as redirect targets. Redirect
-// targets should be reached for only NAT and Mangle tables. These targets will
-// change the destination port/destination IP for packets.
-const redirectTargetName = "REDIRECT"
+type standardTargetMaker struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
-func marshalTarget(target stack.Target) []byte {
+func (sm *standardTargetMaker) id() stack.TargetID {
+ // Standard targets have the empty string as a name and no revisions.
+ return stack.TargetID{
+ NetworkProtocol: sm.NetworkProtocol,
+ }
+}
+func (*standardTargetMaker) marshal(target stack.Target) []byte {
+ // Translate verdicts the same way as the iptables tool.
+ var verdict int32
switch tg := target.(type) {
- case stack.AcceptTarget:
- return marshalStandardTarget(stack.RuleAccept)
- case stack.DropTarget:
- return marshalStandardTarget(stack.RuleDrop)
- case stack.ErrorTarget:
- return marshalErrorTarget(errorTargetName)
- case stack.UserChainTarget:
- return marshalErrorTarget(tg.Name)
- case stack.ReturnTarget:
- return marshalStandardTarget(stack.RuleReturn)
- case stack.RedirectTarget:
- return marshalRedirectTarget(tg)
- case JumpTarget:
- return marshalJumpTarget(tg)
+ case *stack.AcceptTarget:
+ verdict = -linux.NF_ACCEPT - 1
+ case *stack.DropTarget:
+ verdict = -linux.NF_DROP - 1
+ case *stack.ReturnTarget:
+ verdict = linux.NF_RETURN
+ case *JumpTarget:
+ verdict = int32(tg.Offset)
default:
panic(fmt.Errorf("unknown target of type %T", target))
}
-}
-
-func marshalStandardTarget(verdict stack.RuleVerdict) []byte {
- nflog("convert to binary: marshalling standard target")
// The target's name will be the empty string.
- target := linux.XTStandardTarget{
+ xt := linux.XTStandardTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTStandardTarget,
},
- Verdict: translateFromStandardVerdict(verdict),
+ Verdict: verdict,
}
ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
- return binary.Marshal(ret, usermem.ByteOrder, target)
+ return binary.Marshal(ret, usermem.ByteOrder, xt)
+}
+
+func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+ if len(buf) != linux.SizeOfXTStandardTarget {
+ nflog("buf has wrong size for standard target %d", len(buf))
+ return nil, syserr.ErrInvalidArgument
+ }
+ var standardTarget linux.XTStandardTarget
+ buf = buf[:linux.SizeOfXTStandardTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
+
+ if standardTarget.Verdict < 0 {
+ // A Verdict < 0 indicates a non-jump verdict.
+ return translateToStandardTarget(standardTarget.Verdict, filter.NetworkProtocol())
+ }
+ // A verdict >= 0 indicates a jump.
+ return &JumpTarget{
+ Offset: uint32(standardTarget.Verdict),
+ NetworkProtocol: filter.NetworkProtocol(),
+ }, nil
+}
+
+type errorTargetMaker struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (em *errorTargetMaker) id() stack.TargetID {
+ // Error targets have no revision.
+ return stack.TargetID{
+ Name: stack.ErrorTargetName,
+ NetworkProtocol: em.NetworkProtocol,
+ }
}
-func marshalErrorTarget(errorName string) []byte {
+func (*errorTargetMaker) marshal(target stack.Target) []byte {
+ var errorName string
+ switch tg := target.(type) {
+ case *stack.ErrorTarget:
+ errorName = stack.ErrorTargetName
+ case *stack.UserChainTarget:
+ errorName = tg.Name
+ default:
+ panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target))
+ }
+
// This is an error target named error
- target := linux.XTErrorTarget{
+ xt := linux.XTErrorTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTErrorTarget,
},
}
- copy(target.Name[:], errorName)
- copy(target.Target.Name[:], errorTargetName)
+ copy(xt.Name[:], errorName)
+ copy(xt.Target.Name[:], stack.ErrorTargetName)
ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
- return binary.Marshal(ret, usermem.ByteOrder, target)
+ return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func marshalRedirectTarget(rt stack.RedirectTarget) []byte {
+func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+ if len(buf) != linux.SizeOfXTErrorTarget {
+ nflog("buf has insufficient size for error target %d", len(buf))
+ return nil, syserr.ErrInvalidArgument
+ }
+ var errorTarget linux.XTErrorTarget
+ buf = buf[:linux.SizeOfXTErrorTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
+
+ // Error targets are used in 2 cases:
+ // * An actual error case. These rules have an error
+ // named stack.ErrorTargetName. The last entry of the table
+ // is usually an error case to catch any packets that
+ // somehow fall through every rule.
+ // * To mark the start of a user defined chain. These
+ // rules have an error with the name of the chain.
+ switch name := errorTarget.Name.String(); name {
+ case stack.ErrorTargetName:
+ return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil
+ default:
+ // User defined chain.
+ return &stack.UserChainTarget{
+ Name: name,
+ NetworkProtocol: filter.NetworkProtocol(),
+ }, nil
+ }
+}
+
+type redirectTargetMaker struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (rm *redirectTargetMaker) id() stack.TargetID {
+ return stack.TargetID{
+ Name: stack.RedirectTargetName,
+ NetworkProtocol: rm.NetworkProtocol,
+ }
+}
+
+func (*redirectTargetMaker) marshal(target stack.Target) []byte {
+ rt := target.(*stack.RedirectTarget)
// This is a redirect target named redirect
- target := linux.XTRedirectTarget{
+ xt := linux.XTRedirectTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTRedirectTarget,
},
}
- copy(target.Target.Name[:], redirectTargetName)
+ copy(xt.Target.Name[:], stack.RedirectTargetName)
ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
- target.NfRange.RangeSize = 1
- if rt.RangeProtoSpecified {
- target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ xt.NfRange.RangeSize = 1
+ xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ xt.NfRange.RangeIPV4.MinPort = htons(rt.Port)
+ xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort
+ return binary.Marshal(ret, usermem.ByteOrder, xt)
+}
+
+func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+ if len(buf) < linux.SizeOfXTRedirectTarget {
+ nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf))
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("redirectTargetMaker: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var redirectTarget linux.XTRedirectTarget
+ buf = buf[:linux.SizeOfXTRedirectTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+
+ // Copy linux.XTRedirectTarget to stack.RedirectTarget.
+ target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()}
+
+ // RangeSize should be 1.
+ nfRange := redirectTarget.NfRange
+ if nfRange.RangeSize != 1 {
+ nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Check if the flags are valid.
+ // Also check if we need to map ports or IP.
+ // For now, redirect target only supports destination port change.
+ // Port range and IP range are not supported yet.
+ if nfRange.RangeIPV4.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED {
+ nflog("redirectTargetMaker: invalid range flags %d", nfRange.RangeIPV4.Flags)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ nflog("redirectTargetMaker: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+ return nil, syserr.ErrInvalidArgument
}
- // Convert port from little endian to big endian.
- port := make([]byte, 2)
- binary.LittleEndian.PutUint16(port, rt.MinPort)
- target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port)
- binary.LittleEndian.PutUint16(port, rt.MaxPort)
- target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port)
- return binary.Marshal(ret, usermem.ByteOrder, target)
+ if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP {
+ nflog("redirectTargetMaker: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.Port = ntohs(nfRange.RangeIPV4.MinPort)
+
+ return &target, nil
}
-func marshalJumpTarget(jt JumpTarget) []byte {
- nflog("convert to binary: marshalling jump target")
+type nfNATTarget struct {
+ Target linux.XTEntryTarget
+ Range linux.NFNATRange
+}
- // The target's name will be the empty string.
- target := linux.XTStandardTarget{
+const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
+
+type nfNATTargetMaker struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (rm *nfNATTargetMaker) id() stack.TargetID {
+ return stack.TargetID{
+ Name: stack.RedirectTargetName,
+ NetworkProtocol: rm.NetworkProtocol,
+ }
+}
+
+func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
+ rt := target.(*stack.RedirectTarget)
+ nt := nfNATTarget{
Target: linux.XTEntryTarget{
- TargetSize: linux.SizeOfXTStandardTarget,
+ TargetSize: nfNATMarhsalledSize,
+ },
+ Range: linux.NFNATRange{
+ Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
},
- // Verdict is overloaded by the ABI. When positive, it holds
- // the jump offset from the start of the table.
- Verdict: int32(jt.Offset),
}
+ copy(nt.Target.Name[:], stack.RedirectTargetName)
+ copy(nt.Range.MinAddr[:], rt.Addr)
+ copy(nt.Range.MaxAddr[:], rt.Addr)
- ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
- return binary.Marshal(ret, usermem.ByteOrder, target)
+ nt.Range.MinProto = htons(rt.Port)
+ nt.Range.MaxProto = nt.Range.MinProto
+
+ ret := make([]byte, 0, nfNATMarhsalledSize)
+ return binary.Marshal(ret, usermem.ByteOrder, nt)
}
-// translateFromStandardVerdict translates verdicts the same way as the iptables
-// tool.
-func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 {
- switch verdict {
- case stack.RuleAccept:
- return -linux.NF_ACCEPT - 1
- case stack.RuleDrop:
- return -linux.NF_DROP - 1
- case stack.RuleReturn:
- return linux.NF_RETURN
- default:
- // TODO(gvisor.dev/issue/170): Support Jump.
- panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+ if size := nfNATMarhsalledSize; len(buf) < size {
+ nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
+ return nil, syserr.ErrInvalidArgument
}
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("nfNATTargetMaker: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var natRange linux.NFNATRange
+ buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize]
+ binary.Unmarshal(buf, usermem.ByteOrder, &natRange)
+
+ // We don't support port or address ranges.
+ if natRange.MinAddr != natRange.MaxAddr {
+ nflog("nfNATTargetMaker: MinAddr and MaxAddr are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+ if natRange.MinProto != natRange.MaxProto {
+ nflog("nfNATTargetMaker: MinProto and MaxProto are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/3549): Check for other flags.
+ // For now, redirect target only supports destination change.
+ if natRange.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED {
+ nflog("nfNATTargetMaker: invalid range flags %d", natRange.Flags)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target := stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ Addr: tcpip.Address(natRange.MinAddr[:]),
+ Port: ntohs(natRange.MinProto),
+ }
+
+ return &target, nil
}
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an stack.Verdict.
-func translateToStandardTarget(val int32) (stack.Target, error) {
+func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return stack.AcceptTarget{}, nil
+ return &stack.AcceptTarget{NetworkProtocol: netProto}, nil
case -linux.NF_DROP - 1:
- return stack.DropTarget{}, nil
+ return &stack.DropTarget{NetworkProtocol: netProto}, nil
case -linux.NF_QUEUE - 1:
- return nil, errors.New("unsupported iptables verdict QUEUE")
+ nflog("unsupported iptables verdict QUEUE")
+ return nil, syserr.ErrInvalidArgument
case linux.NF_RETURN:
- return stack.ReturnTarget{}, nil
+ return &stack.ReturnTarget{NetworkProtocol: netProto}, nil
default:
- return nil, fmt.Errorf("unknown iptables verdict %d", val)
+ nflog("unknown iptables verdict %d", val)
+ return nil, syserr.ErrInvalidArgument
}
}
// parseTarget parses a target from optVal. optVal should contain only the
// target.
-func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) {
+func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.Target, *syserr.Error) {
nflog("set entries: parsing target of size %d", len(optVal))
if len(optVal) < linux.SizeOfXTEntryTarget {
- return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal))
+ nflog("optVal has insufficient size for entry target %d", len(optVal))
+ return nil, syserr.ErrInvalidArgument
}
var target linux.XTEntryTarget
buf := optVal[:linux.SizeOfXTEntryTarget]
binary.Unmarshal(buf, usermem.ByteOrder, &target)
- switch target.Name.String() {
- case "":
- // Standard target.
- if len(optVal) != linux.SizeOfXTStandardTarget {
- return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal))
- }
- var standardTarget linux.XTStandardTarget
- buf = optVal[:linux.SizeOfXTStandardTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
-
- if standardTarget.Verdict < 0 {
- // A Verdict < 0 indicates a non-jump verdict.
- return translateToStandardTarget(standardTarget.Verdict)
- }
- // A verdict >= 0 indicates a jump.
- return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil
-
- case errorTargetName:
- // Error target.
- if len(optVal) != linux.SizeOfXTErrorTarget {
- return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal))
- }
- var errorTarget linux.XTErrorTarget
- buf = optVal[:linux.SizeOfXTErrorTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
-
- // Error targets are used in 2 cases:
- // * An actual error case. These rules have an error
- // named errorTargetName. The last entry of the table
- // is usually an error case to catch any packets that
- // somehow fall through every rule.
- // * To mark the start of a user defined chain. These
- // rules have an error with the name of the chain.
- switch name := errorTarget.Name.String(); name {
- case errorTargetName:
- nflog("set entries: error target")
- return stack.ErrorTarget{}, nil
- default:
- // User defined chain.
- nflog("set entries: user-defined target %q", name)
- return stack.UserChainTarget{Name: name}, nil
- }
-
- case redirectTargetName:
- // Redirect target.
- if len(optVal) < linux.SizeOfXTRedirectTarget {
- return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal))
- }
-
- if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
- return nil, fmt.Errorf("netfilter.SetEntries: bad proto %d", p)
- }
-
- var redirectTarget linux.XTRedirectTarget
- buf = optVal[:linux.SizeOfXTRedirectTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
-
- // Copy linux.XTRedirectTarget to stack.RedirectTarget.
- var target stack.RedirectTarget
- nfRange := redirectTarget.NfRange
-
- // RangeSize should be 1.
- if nfRange.RangeSize != 1 {
- return nil, fmt.Errorf("netfilter.SetEntries: bad rangesize %d", nfRange.RangeSize)
- }
-
- // TODO(gvisor.dev/issue/170): Check if the flags are valid.
- // Also check if we need to map ports or IP.
- // For now, redirect target only supports destination port change.
- // Port range and IP range are not supported yet.
- if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
- return nil, fmt.Errorf("netfilter.SetEntries: invalid range flags %d", nfRange.RangeIPV4.Flags)
- }
- target.RangeProtoSpecified = true
-
- target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
- target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:])
-
- // TODO(gvisor.dev/issue/170): Port range is not supported yet.
- if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
- return nil, fmt.Errorf("netfilter.SetEntries: minport != maxport (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
- }
-
- // Convert port from big endian to little endian.
- port := make([]byte, 2)
- binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort)
- target.MinPort = binary.LittleEndian.Uint16(port)
-
- binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort)
- target.MaxPort = binary.LittleEndian.Uint16(port)
- return target, nil
- }
- // Unknown target.
- return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet", target.Name.String())
+ return unmarshalTarget(target, filter, optVal)
}
// JumpTarget implements stack.Target.
@@ -274,9 +376,31 @@ type JumpTarget struct {
// RuleNum is the rule to jump to.
RuleNum int
+
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (jt *JumpTarget) ID() stack.TargetID {
+ return stack.TargetID{
+ NetworkProtocol: jt.NetworkProtocol,
+ }
}
// Action implements stack.Target.Action.
-func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
return stack.RuleJump, jt.RuleNum
}
+
+func ntohs(port uint16) uint16 {
+ buf := make([]byte, 2)
+ binary.BigEndian.PutUint16(buf, port)
+ return usermem.ByteOrder.Uint16(buf)
+}
+
+func htons(port uint16) uint16 {
+ buf := make([]byte, 2)
+ usermem.ByteOrder.PutUint16(buf, port)
+ return binary.BigEndian.Uint16(buf)
+}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 0bfd6c1f4..844acfede 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -97,17 +97,33 @@ func (*TCPMatcher) Name() string {
// Match implements Matcher.Match.
func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
+ // into the stack.Check codepath as matchers are added.
+ switch pkt.NetworkProtocolNumber {
+ case header.IPv4ProtocolNumber:
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return false, false
+ }
- if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return false, false
- }
+ // We don't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
+ }
- // We dont't match fragments.
- if frag := netHeader.FragmentOffset(); frag != 0 {
- if frag == 1 {
- return false, true
+ case header.IPv6ProtocolNumber:
+ // As in Linux, we do not perform an IPv6 fragment check. See
+ // xt_action_param.fragoff in
+ // include/linux/netfilter/x_tables.h.
+ if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.TCPProtocolNumber {
+ return false, false
}
+
+ default:
+ // We don't know the network protocol.
return false, false
}
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 7ed05461d..63201201c 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -94,19 +94,33 @@ func (*UDPMatcher) Name() string {
// Match implements Matcher.Match.
func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
-
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
// into the stack.Check codepath as matchers are added.
- if netHeader.TransportProtocol() != header.UDPProtocolNumber {
- return false, false
- }
+ switch pkt.NetworkProtocolNumber {
+ case header.IPv4ProtocolNumber:
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ if netHeader.TransportProtocol() != header.UDPProtocolNumber {
+ return false, false
+ }
- // We dont't match fragments.
- if frag := netHeader.FragmentOffset(); frag != 0 {
- if frag == 1 {
- return false, true
+ // We don't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
}
+
+ case header.IPv6ProtocolNumber:
+ // As in Linux, we do not perform an IPv6 fragment check. See
+ // xt_action_param.fragoff in
+ // include/linux/netfilter/x_tables.h.
+ if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.UDPProtocolNumber {
+ return false, false
+ }
+
+ default:
+ // We don't know the network protocol.
return false, false
}
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 0546801bf..1f926aa91 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -16,6 +16,8 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/context",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
@@ -36,8 +38,6 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go
index bb205be0d..e8930f031 100644
--- a/pkg/sentry/socket/netlink/provider_vfs2.go
+++ b/pkg/sentry/socket/netlink/provider_vfs2.go
@@ -52,6 +52,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol
vfsfd := &s.vfsfd
mnt := t.Kernel().SocketMount()
d := sockfs.NewDentry(t.Credentials(), mnt)
+ defer d.DecRef(t)
if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
DenyPRead: true,
DenyPWrite: true,
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 68a9b9a96..5ddcd4be5 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -38,8 +40,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const sizeOfInt32 int = 4
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
index a38d25da9..c83b23242 100644
--- a/pkg/sentry/socket/netlink/socket_vfs2.go
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -82,6 +82,13 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV
return fd, nil
}
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *SocketVFS2) Release(ctx context.Context) {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.socketOpsCommon.Release(ctx)
+}
+
// Readiness implements waiter.Waitable.Readiness.
func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.socketOpsCommon.Readiness(mask)
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 1fb777a6c..fae3b6783 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -22,6 +22,8 @@ go_library(
"//pkg/binary",
"//pkg/context",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/safemem",
"//pkg/sentry/arch",
@@ -51,8 +53,6 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 0bf21f7d8..87e30d742 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -40,6 +40,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -62,8 +64,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
func mustCreateMetric(name, description string) *tcpip.StatCounter {
@@ -158,6 +158,9 @@ var Metrics = tcpip.Stats{
OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."),
MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."),
+ IPTablesPreroutingDropped: mustCreateMetric("/netstack/ip/iptables/prerouting_dropped", "Total number of IP packets dropped in the Prerouting chain."),
+ IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Total number of IP packets dropped in the Input chain."),
+ IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Total number of IP packets dropped in the Output chain."),
},
TCP: tcpip.TCPStats{
ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
@@ -195,7 +198,6 @@ var Metrics = tcpip.Stats{
PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."),
PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."),
ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."),
- InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."),
},
}
@@ -857,7 +859,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
// blockingAccept implements a blocking version of accept(2), that is, if no
// connections are ready to be accept, it will block until one becomes ready.
-func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
+func (s *socketOpsCommon) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
// Register for notifications.
e, ch := waiter.NewChannelEntry(nil)
s.EventRegister(&e, waiter.EventIn)
@@ -866,7 +868,7 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite
// Try to accept the connection again; if it fails, then wait until we
// get a notification.
for {
- if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock {
+ if ep, wq, err := s.Endpoint.Accept(peerAddr); err != tcpip.ErrWouldBlock {
return ep, wq, syserr.TranslateNetstackError(err)
}
@@ -879,15 +881,18 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite
// Accept implements the linux syscall accept(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- // Issue the accept request to get the new endpoint.
- ep, wq, terr := s.Endpoint.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, wq, terr := s.Endpoint.Accept(peerAddr)
if terr != nil {
if terr != tcpip.ErrWouldBlock || !blocking {
return 0, nil, 0, syserr.TranslateNetstackError(terr)
}
var err *syserr.Error
- ep, wq, err = s.blockingAccept(t)
+ ep, wq, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -907,13 +912,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
- // Get address of the peer and write it to peer slice.
- var err *syserr.Error
- addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ if peerAddr != nil {
+ addr, addrLen = ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -1187,7 +1187,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
var v tcpip.LingerOption
var linger linux.Linger
if err := ep.GetSockOpt(&v); err != nil {
- return &linger, nil
+ return nil, syserr.TranslateNetstackError(err)
}
if v.Enabled {
@@ -1512,8 +1512,17 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return &vP, nil
case linux.IP6T_ORIGINAL_DST:
- // TODO(gvisor.dev/issue/170): ip6tables.
- return nil, syserr.ErrInvalidArgument
+ if outLen < int(binary.Size(linux.SockAddrInet6{})) {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.OriginalDestinationOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
+ return a.(*linux.SockAddrInet6), nil
case linux.IP6T_SO_GET_INFO:
if outLen < linux.SizeOfIPTGetinfo {
@@ -1555,6 +1564,26 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
return &entries, nil
+ case linux.IP6T_SO_GET_REVISION_TARGET:
+ if outLen < linux.SizeOfXTGetRevision {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv6 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber)
+ if err != nil {
+ return nil, err
+ }
+ return &ret, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1718,6 +1747,26 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
}
return &entries, nil
+ case linux.IPT_SO_GET_REVISION_TARGET:
+ if outLen < linux.SizeOfXTGetRevision {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv4 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber)
+ if err != nil {
+ return nil, err
+ }
+ return &ret, nil
+
default:
emitUnimplementedEventIP(t, name)
}
@@ -1770,10 +1819,16 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int
case linux.SOL_IP:
return setSockOptIP(t, s, ep, name, optVal)
+ case linux.SOL_PACKET:
+ // gVisor doesn't support any SOL_PACKET options just return not
+ // supported. Returning nil here will result in tcpdump thinking AF_PACKET
+ // features are supported and proceed to use them and break.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return syserr.ErrProtocolNotAvailable
+
case linux.SOL_UDP,
linux.SOL_ICMPV6,
- linux.SOL_RAW,
- linux.SOL_PACKET:
+ linux.SOL_RAW:
t.Kernel().EmitUnimplementedEvent(t)
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 1f7d17f5f..4c6791fff 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -18,6 +18,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
@@ -29,8 +31,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// SocketVFS2 encapsulates all the state needed to represent a network stack
@@ -79,6 +79,13 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu
return vfsfd, nil
}
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *SocketVFS2) Release(ctx context.Context) {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.socketOpsCommon.Release(ctx)
+}
+
// Readiness implements waiter.Waitable.Readiness.
func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.socketOpsCommon.Readiness(mask)
@@ -151,14 +158,18 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
// tcpip.Endpoint.
func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
- ep, wq, terr := s.Endpoint.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, wq, terr := s.Endpoint.Accept(peerAddr)
if terr != nil {
if terr != tcpip.ErrWouldBlock || !blocking {
return 0, nil, 0, syserr.TranslateNetstackError(terr)
}
var err *syserr.Error
- ep, wq, err = s.blockingAccept(t)
+ ep, wq, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -176,13 +187,9 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
+ if peerAddr != nil {
// Get address of the peer and write it to peer slice.
- var err *syserr.Error
- addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ addr, addrLen = ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index f0fe18684..1028d2a6e 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
- var rs tcp.ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs)
return inet.TCPBufferSize{
Min: rs.Min,
@@ -166,17 +166,17 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
- rs := tcp.ReceiveBufferSizeOption{
+ rs := tcpip.TCPReceiveBufferSizeRangeOption{
Min: size.Min,
Default: size.Default,
Max: size.Max,
}
- return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, rs)).ToError()
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &rs)).ToError()
}
// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
- var ss tcp.SendBufferSizeOption
+ var ss tcpip.TCPSendBufferSizeRangeOption
err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss)
return inet.TCPBufferSize{
Min: ss.Min,
@@ -187,29 +187,30 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
- ss := tcp.SendBufferSizeOption{
+ ss := tcpip.TCPSendBufferSizeRangeOption{
Min: size.Min,
Default: size.Default,
Max: size.Max,
}
- return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, ss)).ToError()
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &ss)).ToError()
}
// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
func (s *Stack) TCPSACKEnabled() (bool, error) {
- var sack tcp.SACKEnabled
+ var sack tcpip.TCPSACKEnabled
err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack)
return bool(sack), syserr.TranslateNetstackError(err).ToError()
}
// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
- return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError()
+ opt := tcpip.TCPSACKEnabled(enabled)
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError()
}
// TCPRecovery implements inet.Stack.TCPRecovery.
func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
- var recovery tcp.Recovery
+ var recovery tcpip.TCPRecovery
if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil {
return 0, syserr.TranslateNetstackError(err).ToError()
}
@@ -218,7 +219,8 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {
- return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.Recovery(recovery))).ToError()
+ opt := tcpip.TCPRecovery(recovery)
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError()
}
// Statistics implements inet.Stack.Statistics.
@@ -410,3 +412,24 @@ func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint {
func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) {
s.Stack.RestoreCleanupEndpoints(es)
}
+
+// Forwarding implements inet.Stack.Forwarding.
+func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
+ switch protocol {
+ case ipv4.ProtocolNumber, ipv6.ProtocolNumber:
+ return s.Stack.Forwarding(protocol)
+ default:
+ panic(fmt.Sprintf("Forwarding(%v) failed: unsupported protocol", protocol))
+ }
+}
+
+// SetForwarding implements inet.Stack.SetForwarding.
+func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
+ switch protocol {
+ case ipv4.ProtocolNumber, ipv6.ProtocolNumber:
+ s.Stack.SetForwarding(protocol, enable)
+ default:
+ panic(fmt.Sprintf("SetForwarding(%v) failed: unsupported protocol", protocol))
+ }
+ return nil
+}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 04b259d27..fd31479e5 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -35,7 +36,6 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// ControlMessages represents the union of unix control messages and tcpip
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cb953e4dc..cc7408698 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -7,10 +7,21 @@ go_template_instance(
name = "socket_refs",
out = "socket_refs.go",
package = "unix",
- prefix = "socketOpsCommon",
+ prefix = "socketOperations",
template = "//pkg/refs_vfs2:refs_template",
types = {
- "T": "socketOpsCommon",
+ "T": "SocketOperations",
+ },
+)
+
+go_template_instance(
+ name = "socket_vfs2_refs",
+ out = "socket_vfs2_refs.go",
+ package = "unix",
+ prefix = "socketVFS2",
+ template = "//pkg/refs_vfs2:refs_template",
+ types = {
+ "T": "SocketVFS2",
},
)
@@ -20,6 +31,7 @@ go_library(
"device.go",
"io.go",
"socket_refs.go",
+ "socket_vfs2_refs.go",
"unix.go",
"unix_vfs2.go",
],
@@ -29,6 +41,7 @@ go_library(
"//pkg/context",
"//pkg/fspath",
"//pkg/log",
+ "//pkg/marshal",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
@@ -49,6 +62,5 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index e3a75b519..aa4f3c04d 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -391,7 +391,7 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error {
}
// Accept accepts a new connection.
-func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) {
+func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) {
e.Lock()
defer e.Unlock()
@@ -401,6 +401,18 @@ func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) {
select {
case ne := <-e.acceptedChan:
+ if peerAddr != nil {
+ ne.Lock()
+ c := ne.connected
+ ne.Unlock()
+ if c != nil {
+ addr, err := c.GetLocalAddress()
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ *peerAddr = addr
+ }
+ }
return ne, nil
default:
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 4751b2fd8..f8aacca13 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -144,12 +144,12 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi
}
// Listen starts listening on the connection.
-func (e *connectionlessEndpoint) Listen(int) *syserr.Error {
+func (*connectionlessEndpoint) Listen(int) *syserr.Error {
return syserr.ErrNotSupported
}
// Accept accepts a new connection.
-func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) {
+func (*connectionlessEndpoint) Accept(*tcpip.FullAddress) (Endpoint, *syserr.Error) {
return nil, syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index cc9d650fb..d6fc03520 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -151,7 +151,10 @@ type Endpoint interface {
// block if no new connections are available.
//
// The returned Queue is the wait queue for the newly created endpoint.
- Accept() (Endpoint, *syserr.Error)
+ //
+ // peerAddr if not nil will be populated with the address of the connected
+ // peer on a successful accept.
+ Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error)
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
@@ -743,6 +746,9 @@ type baseEndpoint struct {
// path is not empty if the endpoint has been bound,
// or may be used if the endpoint is connected.
path string
+
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
}
// EventRegister implements waiter.Waitable.EventRegister.
@@ -838,8 +844,14 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
return n, err
}
-// SetSockOpt sets a socket option. Currently not supported.
-func (e *baseEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error {
+// SetSockOpt sets a socket option.
+func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.LingerOption:
+ e.Lock()
+ e.linger = *v
+ e.Unlock()
+ }
return nil
}
@@ -942,8 +954,11 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch opt.(type) {
+ switch o := opt.(type) {
case *tcpip.LingerOption:
+ e.Lock()
+ *o = e.linger
+ e.Unlock()
return nil
default:
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 0a7a26495..f80011ce4 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -39,7 +40,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketOperations is a Unix socket. It is similar to a netstack socket,
@@ -55,6 +55,7 @@ type SocketOperations struct {
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ socketOperationsRefs
socketOpsCommon
}
@@ -84,11 +85,27 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
return fs.NewFile(ctx, d, flags, &s)
}
+// DecRef implements RefCounter.DecRef.
+func (s *SocketOperations) DecRef(ctx context.Context) {
+ s.socketOperationsRefs.DecRef(func() {
+ s.ep.Close(ctx)
+ if s.abstractNamespace != nil {
+ s.abstractNamespace.Remove(s.abstractName, s)
+ }
+ })
+}
+
+// Release implemements fs.FileOperations.Release.
+func (s *SocketOperations) Release(ctx context.Context) {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef(ctx)
+}
+
// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
//
// +stateify savable
type socketOpsCommon struct {
- socketOpsCommonRefs
socket.SendReceiveTimeout
ep transport.Endpoint
@@ -101,23 +118,6 @@ type socketOpsCommon struct {
abstractNamespace *kernel.AbstractSocketNamespace
}
-// DecRef implements RefCounter.DecRef.
-func (s *socketOpsCommon) DecRef(ctx context.Context) {
- s.socketOpsCommonRefs.DecRef(func() {
- s.ep.Close(ctx)
- if s.abstractNamespace != nil {
- s.abstractNamespace.Remove(s.abstractName, s)
- }
- })
-}
-
-// Release implemements fs.FileOperations.Release.
-func (s *socketOpsCommon) Release(ctx context.Context) {
- // Release only decrements a reference on s because s may be referenced in
- // the abstract socket namespace.
- s.DecRef(ctx)
-}
-
func (s *socketOpsCommon) isPacket() bool {
switch s.stype {
case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
@@ -205,7 +205,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
// blockingAccept implements a blocking version of accept(2), that is, if no
// connections are ready to be accept, it will block until one becomes ready.
-func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+func (s *SocketOperations) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) {
// Register for notifications.
e, ch := waiter.NewChannelEntry(nil)
s.EventRegister(&e, waiter.EventIn)
@@ -214,7 +214,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *
// Try to accept the connection; if it fails, then wait until we get a
// notification.
for {
- if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock {
return ep, err
}
@@ -227,15 +227,18 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *
// Accept implements the linux syscall accept(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- // Issue the accept request to get the new endpoint.
- ep, err := s.ep.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, err := s.ep.Accept(peerAddr)
if err != nil {
if err != syserr.ErrWouldBlock || !blocking {
return 0, nil, 0, err
}
var err *syserr.Error
- ep, err = s.blockingAccept(t)
+ ep, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -252,13 +255,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
- // Get address of the peer.
- var err *syserr.Error
- addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ if peerAddr != nil {
+ addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 65a285b8f..3345124cc 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
@@ -32,17 +33,19 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketVFS2 implements socket.SocketVFS2 (and by extension,
// vfs.FileDescriptionImpl) for Unix sockets.
+//
+// +stateify savable
type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
vfs.LockFD
+ socketVFS2Refs
socketOpsCommon
}
@@ -53,6 +56,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{})
func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
mnt := t.Kernel().SocketMount()
d := sockfs.NewDentry(t.Credentials(), mnt)
+ defer d.DecRef(t)
fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{})
if err != nil {
@@ -88,6 +92,25 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
return vfsfd, nil
}
+// DecRef implements RefCounter.DecRef.
+func (s *SocketVFS2) DecRef(ctx context.Context) {
+ s.socketVFS2Refs.DecRef(func() {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.ep.Close(ctx)
+ if s.abstractNamespace != nil {
+ s.abstractNamespace.Remove(s.abstractName, s)
+ }
+ })
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *SocketVFS2) Release(ctx context.Context) {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef(ctx)
+}
+
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
@@ -96,7 +119,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
// blockingAccept implements a blocking version of accept(2), that is, if no
// connections are ready to be accept, it will block until one becomes ready.
-func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+func (s *SocketVFS2) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) {
// Register for notifications.
e, ch := waiter.NewChannelEntry(nil)
s.socketOpsCommon.EventRegister(&e, waiter.EventIn)
@@ -105,7 +128,7 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr
// Try to accept the connection; if it fails, then wait until we get a
// notification.
for {
- if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock {
return ep, err
}
@@ -118,15 +141,18 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr
// Accept implements the linux syscall accept(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- // Issue the accept request to get the new endpoint.
- ep, err := s.ep.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, err := s.ep.Accept(peerAddr)
if err != nil {
if err != syserr.ErrWouldBlock || !blocking {
return 0, nil, 0, err
}
var err *syserr.Error
- ep, err = s.blockingAccept(t)
+ ep, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -144,13 +170,8 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
- // Get address of the peer.
- var err *syserr.Error
- addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ if peerAddr != nil {
+ addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go
index a06c9b8ab..245d2c5cf 100644
--- a/pkg/sentry/state/state.go
+++ b/pkg/sentry/state/state.go
@@ -61,8 +61,10 @@ func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error {
log.Infof("Sandbox save started, pausing all tasks.")
k.Pause()
k.ReceiveTaskStates()
- defer k.Unpause()
- defer log.Infof("Tasks resumed after save.")
+ defer func() {
+ k.Unpause()
+ log.Infof("Tasks resumed after save.")
+ }()
w.Stop()
defer w.Start()
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 88d5db9fc..a920180d3 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -28,6 +28,7 @@ go_library(
"//pkg/binary",
"//pkg/bits",
"//pkg/eventchannel",
+ "//pkg/marshal/primitive",
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go
index 5d51a7792..ae3b998c8 100644
--- a/pkg/sentry/strace/epoll.go
+++ b/pkg/sentry/strace/epoll.go
@@ -26,7 +26,7 @@ import (
func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string {
var e linux.EpollEvent
- if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ if _, err := e.CopyIn(t, eventAddr); err != nil {
return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err)
}
var sb strings.Builder
@@ -41,7 +41,7 @@ func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes ui
addr := eventsAddr
for i := uint64(0); i < numEvents; i++ {
var e linux.EpollEvent
- if _, err := t.CopyIn(addr, &e); err != nil {
+ if _, err := e.CopyIn(t, addr); err != nil {
fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err)
continue
}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index 08e97e6c4..cc5f70cd4 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
@@ -166,7 +167,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
}
buf := make([]byte, length)
- if _, err := t.CopyIn(addr, &buf); err != nil {
+ if _, err := t.CopyInBytes(addr, buf); err != nil {
return fmt.Sprintf("%#x (error decoding control: %v)", addr, err)
}
@@ -302,7 +303,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
func msghdr(t *kernel.Task, addr usermem.Addr, printContent bool, maxBytes uint64) string {
var msg slinux.MessageHeader64
- if err := slinux.CopyInMessageHeader64(t, addr, &msg); err != nil {
+ if _, err := msg.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding msghdr: %v)", addr, err)
}
s := fmt.Sprintf(
@@ -380,9 +381,9 @@ func postSockAddr(t *kernel.Task, addr usermem.Addr, lengthPtr usermem.Addr) str
func copySockLen(t *kernel.Task, addr usermem.Addr) (uint32, error) {
// socklen_t is 32-bits.
- var l uint32
- _, err := t.CopyIn(addr, &l)
- return l, err
+ var l primitive.Uint32
+ _, err := l.CopyIn(t, addr)
+ return uint32(l), err
}
func sockLenPointer(t *kernel.Task, addr usermem.Addr) string {
@@ -436,22 +437,22 @@ func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, o
func sockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen uint64, maximumBlobSize uint) string {
switch optLen {
case 1:
- var v uint8
- _, err := t.CopyIn(optVal, &v)
+ var v primitive.Uint8
+ _, err := v.CopyIn(t, optVal)
if err != nil {
return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
}
return fmt.Sprintf("%#x {value=%v}", optVal, v)
case 2:
- var v uint16
- _, err := t.CopyIn(optVal, &v)
+ var v primitive.Uint16
+ _, err := v.CopyIn(t, optVal)
if err != nil {
return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
}
return fmt.Sprintf("%#x {value=%v}", optVal, v)
case 4:
- var v uint32
- _, err := t.CopyIn(optVal, &v)
+ var v primitive.Uint32
+ _, err := v.CopyIn(t, optVal)
if err != nil {
return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 87b239730..396744597 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -17,17 +17,16 @@
package strace
import (
- "encoding/binary"
"fmt"
"strconv"
"strings"
- "syscall"
"time"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -91,7 +90,7 @@ func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, ma
}
b := make([]byte, size)
- amt, err := t.CopyIn(ar.Start, b)
+ amt, err := t.CopyInBytes(ar.Start, b)
if err != nil {
iovs[i] = fmt.Sprintf("{base=%#x, len=%d, %q..., error decoding string: %v}", ar.Start, ar.Length(), b[:amt], err)
continue
@@ -118,7 +117,7 @@ func dump(t *kernel.Task, addr usermem.Addr, size uint, maximumBlobSize uint) st
}
b := make([]byte, size)
- amt, err := t.CopyIn(addr, b)
+ amt, err := t.CopyInBytes(addr, b)
if err != nil {
return fmt.Sprintf("%#x (error decoding string: %s)", addr, err)
}
@@ -199,7 +198,7 @@ func fdVFS2(t *kernel.Task, fd int32) string {
func fdpair(t *kernel.Task, addr usermem.Addr) string {
var fds [2]int32
- _, err := t.CopyIn(addr, &fds)
+ _, err := primitive.CopyInt32SliceIn(t, addr, fds[:])
if err != nil {
return fmt.Sprintf("%#x (error decoding fds: %s)", addr, err)
}
@@ -209,7 +208,7 @@ func fdpair(t *kernel.Task, addr usermem.Addr) string {
func uname(t *kernel.Task, addr usermem.Addr) string {
var u linux.UtsName
- if _, err := t.CopyIn(addr, &u); err != nil {
+ if _, err := u.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding utsname: %s)", addr, err)
}
@@ -222,7 +221,7 @@ func utimensTimespec(t *kernel.Task, addr usermem.Addr) string {
}
var tim linux.Timespec
- if _, err := t.CopyIn(addr, &tim); err != nil {
+ if _, err := tim.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err)
}
@@ -244,7 +243,7 @@ func timespec(t *kernel.Task, addr usermem.Addr) string {
}
var tim linux.Timespec
- if _, err := t.CopyIn(addr, &tim); err != nil {
+ if _, err := tim.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err)
}
return fmt.Sprintf("%#x {sec=%v nsec=%v}", addr, tim.Sec, tim.Nsec)
@@ -256,7 +255,7 @@ func timeval(t *kernel.Task, addr usermem.Addr) string {
}
var tim linux.Timeval
- if _, err := t.CopyIn(addr, &tim); err != nil {
+ if _, err := tim.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding timeval: %s)", addr, err)
}
@@ -268,8 +267,8 @@ func utimbuf(t *kernel.Task, addr usermem.Addr) string {
return "null"
}
- var utim syscall.Utimbuf
- if _, err := t.CopyIn(addr, &utim); err != nil {
+ var utim linux.Utime
+ if _, err := utim.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding utimbuf: %s)", addr, err)
}
@@ -282,7 +281,7 @@ func stat(t *kernel.Task, addr usermem.Addr) string {
}
var stat linux.Stat
- if _, err := t.CopyIn(addr, &stat); err != nil {
+ if _, err := stat.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding stat: %s)", addr, err)
}
return fmt.Sprintf("%#x {dev=%d, ino=%d, mode=%s, nlink=%d, uid=%d, gid=%d, rdev=%d, size=%d, blksize=%d, blocks=%d, atime=%s, mtime=%s, ctime=%s}", addr, stat.Dev, stat.Ino, linux.FileMode(stat.Mode), stat.Nlink, stat.UID, stat.GID, stat.Rdev, stat.Size, stat.Blksize, stat.Blocks, time.Unix(stat.ATime.Sec, stat.ATime.Nsec), time.Unix(stat.MTime.Sec, stat.MTime.Nsec), time.Unix(stat.CTime.Sec, stat.CTime.Nsec))
@@ -294,7 +293,7 @@ func itimerval(t *kernel.Task, addr usermem.Addr) string {
}
interval := timeval(t, addr)
- value := timeval(t, addr+usermem.Addr(binary.Size(linux.Timeval{})))
+ value := timeval(t, addr+usermem.Addr((*linux.Timeval)(nil).SizeBytes()))
return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value)
}
@@ -304,7 +303,7 @@ func itimerspec(t *kernel.Task, addr usermem.Addr) string {
}
interval := timespec(t, addr)
- value := timespec(t, addr+usermem.Addr(binary.Size(linux.Timespec{})))
+ value := timespec(t, addr+usermem.Addr((*linux.Timespec)(nil).SizeBytes()))
return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value)
}
@@ -330,7 +329,7 @@ func rusage(t *kernel.Task, addr usermem.Addr) string {
}
var ru linux.Rusage
- if _, err := t.CopyIn(addr, &ru); err != nil {
+ if _, err := ru.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding rusage: %s)", addr, err)
}
return fmt.Sprintf("%#x %+v", addr, ru)
@@ -342,7 +341,7 @@ func capHeader(t *kernel.Task, addr usermem.Addr) string {
}
var hdr linux.CapUserHeader
- if _, err := t.CopyIn(addr, &hdr); err != nil {
+ if _, err := hdr.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding header: %s)", addr, err)
}
@@ -367,7 +366,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string {
}
var hdr linux.CapUserHeader
- if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ if _, err := hdr.CopyIn(t, hdrAddr); err != nil {
return fmt.Sprintf("%#x (error decoding header: %v)", dataAddr, err)
}
@@ -376,7 +375,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string {
switch hdr.Version {
case linux.LINUX_CAPABILITY_VERSION_1:
var data linux.CapUserData
- if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ if _, err := data.CopyIn(t, dataAddr); err != nil {
return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err)
}
p = uint64(data.Permitted)
@@ -384,7 +383,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string {
e = uint64(data.Effective)
case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3:
var data [2]linux.CapUserData
- if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ if _, err := linux.CopyCapUserDataSliceIn(t, dataAddr, data[:]); err != nil {
return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err)
}
p = uint64(data[0].Permitted) | (uint64(data[1].Permitted) << 32)
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 4a9b04fd0..75752b2e6 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -56,6 +56,7 @@ go_library(
"sys_xattr.go",
"timespec.go",
],
+ marshal = True,
visibility = ["//:sandbox"],
deps = [
"//pkg/abi",
@@ -64,6 +65,8 @@ go_library(
"//pkg/bpf",
"//pkg/context",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/rand",
"//pkg/safemem",
@@ -99,7 +102,5 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index da6bd85e1..5f26697d2 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -138,7 +138,7 @@ var AMD64 = &kernel.SyscallTable{
83: syscalls.Supported("mkdir", Mkdir),
84: syscalls.Supported("rmdir", Rmdir),
85: syscalls.Supported("creat", Creat),
- 86: syscalls.Supported("link", Link),
+ 86: syscalls.PartiallySupported("link", Link, "Limited support with Gofer. Link count and linked files may get out of sync because gVisor is not aware of external hardlinks.", nil),
87: syscalls.Supported("unlink", Unlink),
88: syscalls.Supported("symlink", Symlink),
89: syscalls.Supported("readlink", Readlink),
@@ -317,7 +317,7 @@ var AMD64 = &kernel.SyscallTable{
262: syscalls.Supported("fstatat", Fstatat),
263: syscalls.Supported("unlinkat", Unlinkat),
264: syscalls.Supported("renameat", Renameat),
- 265: syscalls.Supported("linkat", Linkat),
+ 265: syscalls.PartiallySupported("linkat", Linkat, "See link(2).", nil),
266: syscalls.Supported("symlinkat", Symlinkat),
267: syscalls.Supported("readlinkat", Readlinkat),
268: syscalls.Supported("fchmodat", Fchmodat),
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index e9d64dec5..0bf313a13 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -17,6 +17,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -36,7 +37,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
//
// The context pointer _must_ be zero initially.
var idIn uint64
- if _, err := t.CopyIn(idAddr, &idIn); err != nil {
+ if _, err := primitive.CopyUint64In(t, idAddr, &idIn); err != nil {
return 0, nil, err
}
if idIn != 0 {
@@ -49,7 +50,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
// Copy out the new ID.
- if _, err := t.CopyOut(idAddr, &id); err != nil {
+ if _, err := primitive.CopyUint64Out(t, idAddr, id); err != nil {
t.MemoryManager().DestroyAIOContext(t, id)
return 0, nil, err
}
@@ -142,7 +143,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
ev := v.(*linux.IOEvent)
// Copy out the result.
- if _, err := t.CopyOut(eventsAddr, ev); err != nil {
+ if _, err := ev.CopyOut(t, eventsAddr); err != nil {
if count > 0 {
return uintptr(count), nil, nil
}
@@ -338,21 +339,27 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
for i := int32(0); i < nrEvents; i++ {
- // Copy in the address.
- cbAddrNative := t.Arch().Native(0)
- if _, err := t.CopyIn(addr, cbAddrNative); err != nil {
- if i > 0 {
- // Some successful.
- return uintptr(i), nil, nil
+ // Copy in the callback address.
+ var cbAddr usermem.Addr
+ switch t.Arch().Width() {
+ case 8:
+ var cbAddrP primitive.Uint64
+ if _, err := cbAddrP.CopyIn(t, addr); err != nil {
+ if i > 0 {
+ // Some successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
}
- // Nothing done.
- return 0, nil, err
+ cbAddr = usermem.Addr(cbAddrP)
+ default:
+ return 0, nil, syserror.ENOSYS
}
// Copy in this callback.
var cb linux.IOCallback
- cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative))
- if _, err := t.CopyIn(cbAddr, &cb); err != nil {
+ if _, err := cb.CopyIn(t, cbAddr); err != nil {
if i > 0 {
// Some have been successful.
diff --git a/pkg/sentry/syscalls/linux/sys_capability.go b/pkg/sentry/syscalls/linux/sys_capability.go
index adf5ea5f2..d3b85e11b 100644
--- a/pkg/sentry/syscalls/linux/sys_capability.go
+++ b/pkg/sentry/syscalls/linux/sys_capability.go
@@ -45,7 +45,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
dataAddr := args[1].Pointer()
var hdr linux.CapUserHeader
- if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ if _, err := hdr.CopyIn(t, hdrAddr); err != nil {
return 0, nil, err
}
// hdr.Pid doesn't need to be valid if this capget() is a "version probe"
@@ -65,7 +65,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
Permitted: uint32(p),
Inheritable: uint32(i),
}
- _, err = t.CopyOut(dataAddr, &data)
+ _, err = data.CopyOut(t, dataAddr)
return 0, nil, err
case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3:
@@ -88,12 +88,12 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
Inheritable: uint32(i >> 32),
},
}
- _, err = t.CopyOut(dataAddr, &data)
+ _, err = linux.CopyCapUserDataSliceOut(t, dataAddr, data[:])
return 0, nil, err
default:
hdr.Version = linux.HighestCapabilityVersion
- if _, err := t.CopyOut(hdrAddr, &hdr); err != nil {
+ if _, err := hdr.CopyOut(t, hdrAddr); err != nil {
return 0, nil, err
}
if dataAddr != 0 {
@@ -109,7 +109,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
dataAddr := args[1].Pointer()
var hdr linux.CapUserHeader
- if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ if _, err := hdr.CopyIn(t, hdrAddr); err != nil {
return 0, nil, err
}
switch hdr.Version {
@@ -118,7 +118,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.EPERM
}
var data linux.CapUserData
- if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ if _, err := data.CopyIn(t, dataAddr); err != nil {
return 0, nil, err
}
p := auth.CapabilitySet(data.Permitted) & auth.AllCapabilities
@@ -131,7 +131,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.EPERM
}
var data [2]linux.CapUserData
- if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ if _, err := linux.CopyCapUserDataSliceIn(t, dataAddr, data[:]); err != nil {
return 0, nil, err
}
p := (auth.CapabilitySet(data[0].Permitted) | (auth.CapabilitySet(data[1].Permitted) << 32)) & auth.AllCapabilities
@@ -141,7 +141,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
default:
hdr.Version = linux.HighestCapabilityVersion
- if _, err := t.CopyOut(hdrAddr, &hdr); err != nil {
+ if _, err := hdr.CopyOut(t, hdrAddr); err != nil {
return 0, nil, err
}
return 0, nil, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 256422689..98331eb3c 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
@@ -601,19 +602,19 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Shared flags between file and socket.
switch request {
case linux.FIONCLEX:
- t.FDTable().SetFlags(fd, kernel.FDFlags{
+ t.FDTable().SetFlags(t, fd, kernel.FDFlags{
CloseOnExec: false,
})
return 0, nil, nil
case linux.FIOCLEX:
- t.FDTable().SetFlags(fd, kernel.FDFlags{
+ t.FDTable().SetFlags(t, fd, kernel.FDFlags{
CloseOnExec: true,
})
return 0, nil, nil
case linux.FIONBIO:
var set int32
- if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
flags := file.Flags()
@@ -627,7 +628,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FIOASYNC:
var set int32
- if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
flags := file.Flags()
@@ -641,15 +642,14 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FIOSETOWN, linux.SIOCSPGRP:
var set int32
- if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
fSetOwn(t, file, set)
return 0, nil, nil
case linux.FIOGETOWN, linux.SIOCGPGRP:
- who := fGetOwn(t, file)
- _, err := t.CopyOut(args[2].Pointer(), &who)
+ _, err := primitive.CopyInt32Out(t, args[2].Pointer(), fGetOwn(t, file))
return 0, nil, err
default:
@@ -694,7 +694,7 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
// Top it off with a terminator.
- _, err = t.CopyOut(addr+usermem.Addr(bytes), []byte("\x00"))
+ _, err = t.CopyOutBytes(addr+usermem.Addr(bytes), []byte("\x00"))
return uintptr(bytes + 1), nil, err
}
@@ -787,7 +787,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Note that Remove provides a reference on the file that we may use to
// flush. It is still active until we drop the final reference below
// (and other reference-holding operations complete).
- file, _ := t.FDTable().Remove(fd)
+ file, _ := t.FDTable().Remove(t, fd)
if file == nil {
return 0, nil, syserror.EBADF
}
@@ -941,7 +941,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(flags.ToLinuxFDFlags()), nil, nil
case linux.F_SETFD:
flags := args[2].Uint()
- err := t.FDTable().SetFlags(fd, kernel.FDFlags{
+ err := t.FDTable().SetFlags(t, fd, kernel.FDFlags{
CloseOnExec: flags&linux.FD_CLOEXEC != 0,
})
return 0, nil, err
@@ -962,7 +962,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Copy in the lock request.
flockAddr := args[2].Pointer()
var flock linux.Flock
- if _, err := t.CopyIn(flockAddr, &flock); err != nil {
+ if _, err := flock.CopyIn(t, flockAddr); err != nil {
return 0, nil, err
}
@@ -1052,12 +1052,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_GETOWN_EX:
addr := args[2].Pointer()
owner := fGetOwnEx(t, file)
- _, err := t.CopyOut(addr, &owner)
+ _, err := owner.CopyOut(t, addr)
return 0, nil, err
case linux.F_SETOWN_EX:
addr := args[2].Pointer()
var owner linux.FOwnerEx
- _, err := t.CopyIn(addr, &owner)
+ _, err := owner.CopyIn(t, addr)
if err != nil {
return 0, nil, err
}
@@ -1154,6 +1154,10 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/fd.go)
+
+// LINT.IfChange
+
func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error {
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1918,7 +1922,7 @@ func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
ts := defaultSetToSystemTimeSpec()
if timesAddr != 0 {
var times linux.Utime
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := times.CopyIn(t, timesAddr); err != nil {
return 0, nil, err
}
ts = fs.TimeSpec{
@@ -1938,7 +1942,7 @@ func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
ts := defaultSetToSystemTimeSpec()
if timesAddr != 0 {
var times [2]linux.Timeval
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil {
return 0, nil, err
}
ts = fs.TimeSpec{
@@ -1966,7 +1970,7 @@ func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
ts := defaultSetToSystemTimeSpec()
if timesAddr != 0 {
var times [2]linux.Timespec
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := linux.CopyTimespecSliceIn(t, timesAddr, times[:]); err != nil {
return 0, nil, err
}
if !timespecIsValid(times[0]) || !timespecIsValid(times[1]) {
@@ -2000,7 +2004,7 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
ts := defaultSetToSystemTimeSpec()
if timesAddr != 0 {
var times [2]linux.Timeval
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil {
return 0, nil, err
}
if times[0].Usec >= 1e6 || times[0].Usec < 0 ||
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index 12b2fa690..f39ce0639 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -306,8 +306,8 @@ func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
// Despite the syscall using the name 'pid' for this variable, it is
// very much a tid.
tid := args[0].Int()
- head := args[1].Pointer()
- size := args[2].Pointer()
+ headAddr := args[1].Pointer()
+ sizeAddr := args[2].Pointer()
if tid < 0 {
return 0, nil, syserror.EINVAL
@@ -321,12 +321,16 @@ func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
}
// Copy out head pointer.
- if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil {
+ head := t.Arch().Native(uintptr(ot.GetRobustList()))
+ if _, err := head.CopyOut(t, headAddr); err != nil {
return 0, nil, err
}
- // Copy out size, which is a constant.
- if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil {
+ // Copy out size, which is a constant. Note that while size isn't
+ // an address, it is defined as the arch-dependent size_t, so it
+ // needs to be converted to a native-sized int.
+ size := t.Arch().Native(uintptr(linux.SizeOfRobustListHead))
+ if _, err := size.CopyOut(t, sizeAddr); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index 59004cefe..b25f7d881 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -19,7 +19,6 @@ import (
"io"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -93,19 +92,23 @@ func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dir
}
}
-// oldDirentHdr is a fixed sized header matching the fixed size
-// fields found in the old linux dirent struct.
+// oldDirentHdr is a fixed sized header matching the fixed size fields found in
+// the old linux dirent struct.
+//
+// +marshal
type oldDirentHdr struct {
Ino uint64
Off uint64
- Reclen uint16
+ Reclen uint16 `marshal:"unaligned"` // Struct ends mid-word.
}
-// direntHdr is a fixed sized header matching the fixed size
-// fields found in the new linux dirent struct.
+// direntHdr is a fixed sized header matching the fixed size fields found in the
+// new linux dirent struct.
+//
+// +marshal
type direntHdr struct {
OldHdr oldDirentHdr
- Typ uint8
+ Typ uint8 `marshal:"unaligned"` // Struct ends mid-word.
}
// dirent contains the data pointed to by a new linux dirent struct.
@@ -134,20 +137,20 @@ func newDirent(width uint, name string, attr fs.DentAttr, offset uint64) *dirent
// the old linux dirent format.
func smallestDirent(a arch.Context) uint {
d := dirent{}
- return uint(binary.Size(d.Hdr.OldHdr)) + a.Width() + 1
+ return uint(d.Hdr.OldHdr.SizeBytes()) + a.Width() + 1
}
// smallestDirent64 returns the size of the smallest possible dirent using
// the new linux dirent format.
func smallestDirent64(a arch.Context) uint {
d := dirent{}
- return uint(binary.Size(d.Hdr)) + a.Width()
+ return uint(d.Hdr.SizeBytes()) + a.Width()
}
// padRec pads the name field until the rec length is a multiple of the width,
// which must be a power of 2. It returns the padded rec length.
func (d *dirent) padRec(width int) uint16 {
- a := int(binary.Size(d.Hdr)) + len(d.Name)
+ a := d.Hdr.SizeBytes() + len(d.Name)
r := (a + width) &^ (width - 1)
padding := r - a
d.Name = append(d.Name, make([]byte, padding)...)
@@ -157,7 +160,7 @@ func (d *dirent) padRec(width int) uint16 {
// Serialize64 serializes a Dirent struct to a byte slice, keeping the new
// linux dirent format. Returns the number of bytes serialized or an error.
func (d *dirent) Serialize64(w io.Writer) (int, error) {
- n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr))
+ n1, err := d.Hdr.WriteTo(w)
if err != nil {
return 0, err
}
@@ -165,14 +168,14 @@ func (d *dirent) Serialize64(w io.Writer) (int, error) {
if err != nil {
return 0, err
}
- return n1 + n2, nil
+ return int(n1) + n2, nil
}
// Serialize serializes a Dirent struct to a byte slice, using the old linux
// dirent format.
// Returns the number of bytes serialized or an error.
func (d *dirent) Serialize(w io.Writer) (int, error) {
- n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr.OldHdr))
+ n1, err := d.Hdr.OldHdr.WriteTo(w)
if err != nil {
return 0, err
}
@@ -184,7 +187,7 @@ func (d *dirent) Serialize(w io.Writer) (int, error) {
if err != nil {
return 0, err
}
- return n1 + n2 + n3, nil
+ return int(n1) + n2 + n3, nil
}
// direntSerializer implements fs.InodeOperationsInfoSerializer, serializing dirents to an
diff --git a/pkg/sentry/syscalls/linux/sys_identity.go b/pkg/sentry/syscalls/linux/sys_identity.go
index 715ac45e6..a29d307e5 100644
--- a/pkg/sentry/syscalls/linux/sys_identity.go
+++ b/pkg/sentry/syscalls/linux/sys_identity.go
@@ -49,13 +49,13 @@ func Getresuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
ruid := c.RealKUID.In(c.UserNamespace).OrOverflow()
euid := c.EffectiveKUID.In(c.UserNamespace).OrOverflow()
suid := c.SavedKUID.In(c.UserNamespace).OrOverflow()
- if _, err := t.CopyOut(ruidAddr, ruid); err != nil {
+ if _, err := ruid.CopyOut(t, ruidAddr); err != nil {
return 0, nil, err
}
- if _, err := t.CopyOut(euidAddr, euid); err != nil {
+ if _, err := euid.CopyOut(t, euidAddr); err != nil {
return 0, nil, err
}
- if _, err := t.CopyOut(suidAddr, suid); err != nil {
+ if _, err := suid.CopyOut(t, suidAddr); err != nil {
return 0, nil, err
}
return 0, nil, nil
@@ -84,13 +84,13 @@ func Getresgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
rgid := c.RealKGID.In(c.UserNamespace).OrOverflow()
egid := c.EffectiveKGID.In(c.UserNamespace).OrOverflow()
sgid := c.SavedKGID.In(c.UserNamespace).OrOverflow()
- if _, err := t.CopyOut(rgidAddr, rgid); err != nil {
+ if _, err := rgid.CopyOut(t, rgidAddr); err != nil {
return 0, nil, err
}
- if _, err := t.CopyOut(egidAddr, egid); err != nil {
+ if _, err := egid.CopyOut(t, egidAddr); err != nil {
return 0, nil, err
}
- if _, err := t.CopyOut(sgidAddr, sgid); err != nil {
+ if _, err := sgid.CopyOut(t, sgidAddr); err != nil {
return 0, nil, err
}
return 0, nil, nil
@@ -157,7 +157,7 @@ func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
for i, kgid := range kgids {
gids[i] = kgid.In(t.UserNamespace()).OrOverflow()
}
- if _, err := t.CopyOut(args[1].Pointer(), gids); err != nil {
+ if _, err := auth.CopyGIDSliceOut(t, args[1].Pointer(), gids); err != nil {
return 0, nil, err
}
return uintptr(len(gids)), nil, nil
@@ -173,7 +173,7 @@ func Setgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, t.SetExtraGIDs(nil)
}
gids := make([]auth.GID, size)
- if _, err := t.CopyIn(args[1].Pointer(), &gids); err != nil {
+ if _, err := auth.CopyGIDSliceIn(t, args[1].Pointer(), gids); err != nil {
return 0, nil, err
}
return 0, nil, t.SetExtraGIDs(gids)
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index d0109baa4..cd8dfdfa4 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -100,6 +100,15 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if err := file.ConfigureMMap(t, &opts); err != nil {
return 0, nil, err
}
+ } else if shared {
+ // Back shared anonymous mappings with a special mappable.
+ opts.Offset = 0
+ m, err := mm.NewSharedAnonMappable(opts.Length, t.Kernel())
+ if err != nil {
+ return 0, nil, err
+ }
+ opts.MappingIdentity = m // transfers ownership of m to opts
+ opts.Mappable = m
}
rv, err := t.MemoryManager().MMap(t, opts)
@@ -239,7 +248,7 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
return 0, nil, syserror.ENOMEM
}
resident := bytes.Repeat([]byte{1}, int(mapped/usermem.PageSize))
- _, err := t.CopyOut(vec, resident)
+ _, err := t.CopyOutBytes(vec, resident)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index 3149e4aad..849a47476 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -46,9 +47,9 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
return 0, err
}
- if _, err := t.CopyOut(addr, fds); err != nil {
+ if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil {
for _, fd := range fds {
- if file, _ := t.FDTable().Remove(fd); file != nil {
+ if file, _ := t.FDTable().Remove(t, fd); file != nil {
file.DecRef(t)
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go
index 789e2ed5b..254f4c9f9 100644
--- a/pkg/sentry/syscalls/linux/sys_poll.go
+++ b/pkg/sentry/syscalls/linux/sys_poll.go
@@ -162,7 +162,7 @@ func CopyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD
pfd := make([]linux.PollFD, nfds)
if nfds > 0 {
- if _, err := t.CopyIn(addr, &pfd); err != nil {
+ if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil {
return nil, err
}
}
@@ -189,7 +189,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration)
// The poll entries are copied out regardless of whether
// any are set or not. This aligns with the Linux behavior.
if nfds > 0 && err == nil {
- if _, err := t.CopyOut(addr, pfd); err != nil {
+ if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil {
return remainingTimeout, 0, err
}
}
@@ -202,7 +202,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy
set := make([]byte, nBytes)
if addr != 0 {
- if _, err := t.CopyIn(addr, &set); err != nil {
+ if _, err := t.CopyInBytes(addr, set); err != nil {
return nil, err
}
// If we only use part of the last byte, mask out the extraneous bits.
@@ -329,19 +329,19 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
// Copy updated vectors back.
if readFDs != 0 {
- if _, err := t.CopyOut(readFDs, r); err != nil {
+ if _, err := t.CopyOutBytes(readFDs, r); err != nil {
return 0, err
}
}
if writeFDs != 0 {
- if _, err := t.CopyOut(writeFDs, w); err != nil {
+ if _, err := t.CopyOutBytes(writeFDs, w); err != nil {
return 0, err
}
}
if exceptFDs != 0 {
- if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ if _, err := t.CopyOutBytes(exceptFDs, e); err != nil {
return 0, err
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index 64a725296..a892d2c62 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
@@ -43,7 +44,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, nil
case linux.PR_GET_PDEATHSIG:
- _, err := t.CopyOut(args[1].Pointer(), int32(t.ParentDeathSignal()))
+ _, err := primitive.CopyInt32Out(t, args[1].Pointer(), int32(t.ParentDeathSignal()))
return 0, nil, err
case linux.PR_GET_DUMPABLE:
@@ -110,7 +111,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
buf[len] = 0
len++
}
- _, err := t.CopyOut(addr, buf[:len])
+ _, err := t.CopyOutBytes(addr, buf[:len])
if err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go
index d5d5b6959..309c183a3 100644
--- a/pkg/sentry/syscalls/linux/sys_rlimit.go
+++ b/pkg/sentry/syscalls/linux/sys_rlimit.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/limits"
@@ -26,17 +27,13 @@ import (
// rlimit describes an implementation of 'struct rlimit', which may vary from
// system-to-system.
type rlimit interface {
+ marshal.Marshallable
+
// toLimit converts an rlimit to a limits.Limit.
toLimit() *limits.Limit
// fromLimit converts a limits.Limit to an rlimit.
fromLimit(lim limits.Limit)
-
- // copyIn copies an rlimit from the untrusted app to the kernel.
- copyIn(t *kernel.Task, addr usermem.Addr) error
-
- // copyOut copies an rlimit from the kernel to the untrusted app.
- copyOut(t *kernel.Task, addr usermem.Addr) error
}
// newRlimit returns the appropriate rlimit type for 'struct rlimit' on this system.
@@ -50,6 +47,7 @@ func newRlimit(t *kernel.Task) (rlimit, error) {
}
}
+// +marshal
type rlimit64 struct {
Cur uint64
Max uint64
@@ -70,12 +68,12 @@ func (r *rlimit64) fromLimit(lim limits.Limit) {
}
func (r *rlimit64) copyIn(t *kernel.Task, addr usermem.Addr) error {
- _, err := t.CopyIn(addr, r)
+ _, err := r.CopyIn(t, addr)
return err
}
func (r *rlimit64) copyOut(t *kernel.Task, addr usermem.Addr) error {
- _, err := t.CopyOut(addr, *r)
+ _, err := r.CopyOut(t, addr)
return err
}
@@ -140,7 +138,8 @@ func Getrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, err
}
rlim.fromLimit(lim)
- return 0, nil, rlim.copyOut(t, addr)
+ _, err = rlim.CopyOut(t, addr)
+ return 0, nil, err
}
// Setrlimit implements linux syscall setrlimit(2).
@@ -155,7 +154,7 @@ func Setrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if err != nil {
return 0, nil, err
}
- if err := rlim.copyIn(t, addr); err != nil {
+ if _, err := rlim.CopyIn(t, addr); err != nil {
return 0, nil, syserror.EFAULT
}
_, err = prlimit64(t, resource, rlim.toLimit())
diff --git a/pkg/sentry/syscalls/linux/sys_rusage.go b/pkg/sentry/syscalls/linux/sys_rusage.go
index 1674c7445..ac5c98a54 100644
--- a/pkg/sentry/syscalls/linux/sys_rusage.go
+++ b/pkg/sentry/syscalls/linux/sys_rusage.go
@@ -80,7 +80,7 @@ func Getrusage(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
ru := getrusage(t, which)
- _, err := t.CopyOut(addr, &ru)
+ _, err := ru.CopyOut(t, addr)
return 0, nil, err
}
@@ -104,7 +104,7 @@ func Times(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
CUTime: linux.ClockTFromDuration(cs2.UserTime),
CSTime: linux.ClockTFromDuration(cs2.SysTime),
}
- if _, err := t.CopyOut(addr, &r); err != nil {
+ if _, err := r.CopyOut(t, addr); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_sched.go b/pkg/sentry/syscalls/linux/sys_sched.go
index 99f6993f5..bfcf44b6f 100644
--- a/pkg/sentry/syscalls/linux/sys_sched.go
+++ b/pkg/sentry/syscalls/linux/sys_sched.go
@@ -27,8 +27,10 @@ const (
)
// SchedParam replicates struct sched_param in sched.h.
+//
+// +marshal
type SchedParam struct {
- schedPriority int64
+ schedPriority int32
}
// SchedGetparam implements linux syscall sched_getparam(2).
@@ -45,7 +47,7 @@ func SchedGetparam(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.ESRCH
}
r := SchedParam{schedPriority: onlyPriority}
- if _, err := t.CopyOut(param, r); err != nil {
+ if _, err := r.CopyOut(t, param); err != nil {
return 0, nil, err
}
@@ -79,7 +81,7 @@ func SchedSetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ke
return 0, nil, syserror.ESRCH
}
var r SchedParam
- if _, err := t.CopyIn(param, &r); err != nil {
+ if _, err := r.CopyIn(t, param); err != nil {
return 0, nil, syserror.EINVAL
}
if r.schedPriority != onlyPriority {
diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go
index 5b7a66f4d..4fdb4463c 100644
--- a/pkg/sentry/syscalls/linux/sys_seccomp.go
+++ b/pkg/sentry/syscalls/linux/sys_seccomp.go
@@ -24,6 +24,8 @@ import (
)
// userSockFprog is equivalent to Linux's struct sock_fprog on amd64.
+//
+// +marshal
type userSockFprog struct {
// Len is the length of the filter in BPF instructions.
Len uint16
@@ -33,7 +35,7 @@ type userSockFprog struct {
// Filter is a user pointer to the struct sock_filter array that makes up
// the filter program. Filter is a uint64 rather than a usermem.Addr
// because usermem.Addr is actually uintptr, which is not a fixed-size
- // type, and encoding/binary.Read objects to this.
+ // type.
Filter uint64
}
@@ -54,11 +56,11 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr usermem.Addr) error {
}
var fprog userSockFprog
- if _, err := t.CopyIn(addr, &fprog); err != nil {
+ if _, err := fprog.CopyIn(t, addr); err != nil {
return err
}
filter := make([]linux.BPFInstruction, int(fprog.Len))
- if _, err := t.CopyIn(usermem.Addr(fprog.Filter), &filter); err != nil {
+ if _, err := linux.CopyBPFInstructionSliceIn(t, usermem.Addr(fprog.Filter), filter); err != nil {
return err
}
compiledFilter, err := bpf.Compile(filter)
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index 5f54f2456..47dadb800 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -18,6 +18,7 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -66,7 +67,7 @@ func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
ops := make([]linux.Sembuf, nsops)
- if _, err := t.CopyIn(sembufAddr, ops); err != nil {
+ if _, err := linux.CopySembufSliceIn(t, sembufAddr, ops); err != nil {
return 0, nil, err
}
@@ -116,8 +117,8 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
case linux.IPC_SET:
arg := args[3].Pointer()
- s := linux.SemidDS{}
- if _, err := t.CopyIn(arg, &s); err != nil {
+ var s linux.SemidDS
+ if _, err := s.CopyIn(t, arg); err != nil {
return 0, nil, err
}
@@ -188,7 +189,7 @@ func setValAll(t *kernel.Task, id int32, array usermem.Addr) error {
return syserror.EINVAL
}
vals := make([]uint16, set.Size())
- if _, err := t.CopyIn(array, vals); err != nil {
+ if _, err := primitive.CopyUint16SliceIn(t, array, vals); err != nil {
return err
}
creds := auth.CredentialsFromContext(t)
@@ -217,7 +218,7 @@ func getValAll(t *kernel.Task, id int32, array usermem.Addr) error {
if err != nil {
return err
}
- _, err = t.CopyOut(array, vals)
+ _, err = primitive.CopyUint16SliceOut(t, array, vals)
return err
}
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
index f0ae8fa8e..584064143 100644
--- a/pkg/sentry/syscalls/linux/sys_shm.go
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -112,18 +112,18 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
stat, err := segment.IPCStat(t)
if err == nil {
- _, err = t.CopyOut(buf, stat)
+ _, err = stat.CopyOut(t, buf)
}
return 0, nil, err
case linux.IPC_INFO:
params := r.IPCInfo()
- _, err := t.CopyOut(buf, params)
+ _, err := params.CopyOut(t, buf)
return 0, nil, err
case linux.SHM_INFO:
info := r.ShmInfo()
- _, err := t.CopyOut(buf, info)
+ _, err := info.CopyOut(t, buf)
return 0, nil, err
}
@@ -137,11 +137,10 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
switch cmd {
case linux.IPC_SET:
var ds linux.ShmidDS
- _, err = t.CopyIn(buf, &ds)
- if err != nil {
+ if _, err = ds.CopyIn(t, buf); err != nil {
return 0, nil, err
}
- err = segment.Set(t, &ds)
+ err := segment.Set(t, &ds)
return 0, nil, err
case linux.IPC_RMID:
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 38f573c14..9feaca0da 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -19,6 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -29,8 +31,6 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
@@ -67,10 +67,10 @@ const flagsOffset = 48
const sizeOfInt32 = 4
// messageHeader64Len is the length of a MessageHeader64 struct.
-var messageHeader64Len = uint64(binary.Size(MessageHeader64{}))
+var messageHeader64Len = uint64((*MessageHeader64)(nil).SizeBytes())
// multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct.
-var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{}))
+var multipleMessageHeader64Len = uint64((*multipleMessageHeader64)(nil).SizeBytes())
// baseRecvFlags are the flags that are accepted across recvmsg(2),
// recvmmsg(2), and recvfrom(2).
@@ -78,6 +78,8 @@ const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT |
// MessageHeader64 is the 64-bit representation of the msghdr struct used in
// the recvmsg and sendmsg syscalls.
+//
+// +marshal
type MessageHeader64 struct {
// Name is the optional pointer to a network address buffer.
Name uint64
@@ -106,30 +108,14 @@ type MessageHeader64 struct {
// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
// the recvmmsg and sendmmsg syscalls.
+//
+// +marshal
type multipleMessageHeader64 struct {
msgHdr MessageHeader64
msgLen uint32
_ int32
}
-// CopyInMessageHeader64 copies a message header from user to kernel memory.
-func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error {
- b := t.CopyScratchBuffer(52)
- if _, err := t.CopyInBytes(addr, b); err != nil {
- return err
- }
-
- msg.Name = usermem.ByteOrder.Uint64(b[0:])
- msg.NameLen = usermem.ByteOrder.Uint32(b[8:])
- msg.Iov = usermem.ByteOrder.Uint64(b[16:])
- msg.IovLen = usermem.ByteOrder.Uint64(b[24:])
- msg.Control = usermem.ByteOrder.Uint64(b[32:])
- msg.ControlLen = usermem.ByteOrder.Uint64(b[40:])
- msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:]))
-
- return nil
-}
-
// CaptureAddress allocates memory for and copies a socket address structure
// from the untrusted address space range.
func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
@@ -148,10 +134,10 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte,
// writeAddress writes a sockaddr structure and its length to an output buffer
// in the unstrusted address space range. If the address is bigger than the
// buffer, it is truncated.
-func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
// Get the buffer length.
var bufLen uint32
- if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil {
+ if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil {
return err
}
@@ -160,7 +146,7 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user
}
// Write the length unconditionally.
- if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil {
+ if _, err := primitive.CopyUint32Out(t, addrLenPtr, addrLen); err != nil {
return err
}
@@ -173,7 +159,8 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user
}
// Copy as much of the address as will fit in the buffer.
- encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr)
+ encodedAddr := t.CopyScratchBuffer(addr.SizeBytes())
+ addr.MarshalUnsafe(encodedAddr)
if bufLen > uint32(len(encodedAddr)) {
bufLen = uint32(len(encodedAddr))
}
@@ -247,9 +234,9 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
// Copy the file descriptors out.
- if _, err := t.CopyOut(socks, fds); err != nil {
+ if _, err := primitive.CopyInt32SliceOut(t, socks, fds); err != nil {
for _, fd := range fds {
- if file, _ := t.FDTable().Remove(fd); file != nil {
+ if file, _ := t.FDTable().Remove(t, fd); file != nil {
file.DecRef(t)
}
}
@@ -456,8 +443,8 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
// Read the length. Reject negative values.
- optLen := int32(0)
- if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+ var optLen int32
+ if _, err := primitive.CopyInt32In(t, optLenAddr, &optLen); err != nil {
return 0, nil, err
}
if optLen < 0 {
@@ -471,7 +458,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
vLen := int32(binary.Size(v))
- if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
+ if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil {
return 0, nil, err
}
@@ -733,7 +720,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if !ok {
return 0, nil, syserror.EFAULT
}
- if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil {
break
}
count++
@@ -748,7 +735,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
// Capture the message header and io vectors.
var msg MessageHeader64
- if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ if _, err := msg.CopyIn(t, msgPtr); err != nil {
return 0, err
}
@@ -780,7 +767,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
if int(msg.Flags) != mflags {
// Copy out the flags to the caller.
- if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil {
return 0, err
}
}
@@ -817,17 +804,17 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
}
// Copy the control data to the caller.
- if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
+ if _, err := primitive.CopyUint64Out(t, msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
return 0, err
}
if len(controlData) > 0 {
- if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil {
+ if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
// Copy out the flags to the caller.
- if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil {
return 0, err
}
@@ -996,7 +983,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if !ok {
return 0, nil, syserror.EFAULT
}
- if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil {
break
}
count++
@@ -1011,7 +998,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr usermem.Addr, flags int32) (uintptr, error) {
// Capture the message header.
var msg MessageHeader64
- if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ if _, err := msg.CopyIn(t, msgPtr); err != nil {
return 0, err
}
@@ -1022,7 +1009,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
return 0, syserror.ENOBUFS
}
controlData = make([]byte, msg.ControlLen)
- if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil {
+ if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index c69941feb..46616c961 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -141,7 +142,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Copy in the offset.
var offset int64
- if _, err := t.CopyIn(offsetAddr, &offset); err != nil {
+ if _, err := primitive.CopyInt64In(t, offsetAddr, &offset); err != nil {
return 0, nil, err
}
@@ -149,11 +150,11 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
SrcOffset: true,
- SrcStart: offset,
+ SrcStart: int64(offset),
}, outFile.Flags().NonBlocking)
// Copy out the new offset.
- if _, err := t.CopyOut(offsetAddr, n+offset); err != nil {
+ if _, err := primitive.CopyInt64Out(t, offsetAddr, offset+n); err != nil {
return 0, nil, err
}
} else {
@@ -228,7 +229,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
var offset int64
- if _, err := t.CopyIn(outOffset, &offset); err != nil {
+ if _, err := primitive.CopyInt64In(t, outOffset, &offset); err != nil {
return 0, nil, err
}
@@ -246,7 +247,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
var offset int64
- if _, err := t.CopyIn(inOffset, &offset); err != nil {
+ if _, err := primitive.CopyInt64In(t, inOffset, &offset); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index a5826f2dd..cda29a8b5 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -221,7 +221,7 @@ func statx(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr, statxAddr
DevMajor: uint32(devMajor),
DevMinor: devMinor,
}
- _, err := t.CopyOut(statxAddr, &s)
+ _, err := s.CopyOut(t, statxAddr)
return err
}
@@ -283,7 +283,7 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error {
FragmentSize: d.Inode.StableAttr.BlockSize,
// Leave other fields 0 like simple_statfs does.
}
- _, err = t.CopyOut(addr, &statfs)
+ _, err = statfs.CopyOut(t, addr)
return err
}
diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go
index 297de052a..674d341b6 100644
--- a/pkg/sentry/syscalls/linux/sys_sysinfo.go
+++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go
@@ -43,6 +43,6 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
FreeRAM: memFree,
Unit: 1,
}
- _, err := t.CopyOut(addr, si)
+ _, err := si.CopyOut(t, addr)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 101096038..39ca9ea97 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -19,6 +19,7 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
@@ -311,13 +312,13 @@ func wait4(t *kernel.Task, pid int, statusAddr usermem.Addr, options int, rusage
return 0, err
}
if statusAddr != 0 {
- if _, err := t.CopyOut(statusAddr, wr.Status); err != nil {
+ if _, err := primitive.CopyUint32Out(t, statusAddr, wr.Status); err != nil {
return 0, err
}
}
if rusageAddr != 0 {
ru := getrusage(wr.Task, linux.RUSAGE_BOTH)
- if _, err := t.CopyOut(rusageAddr, &ru); err != nil {
+ if _, err := ru.CopyOut(t, rusageAddr); err != nil {
return 0, err
}
}
@@ -395,14 +396,14 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// as well.
if infop != 0 {
var si arch.SignalInfo
- _, err = t.CopyOut(infop, &si)
+ _, err = si.CopyOut(t, infop)
}
}
return 0, nil, err
}
if rusageAddr != 0 {
ru := getrusage(wr.Task, linux.RUSAGE_BOTH)
- if _, err := t.CopyOut(rusageAddr, &ru); err != nil {
+ if _, err := ru.CopyOut(t, rusageAddr); err != nil {
return 0, nil, err
}
}
@@ -441,7 +442,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
default:
t.Warningf("waitid got incomprehensible wait status %d", s)
}
- _, err = t.CopyOut(infop, &si)
+ _, err = si.CopyOut(t, infop)
return 0, nil, err
}
@@ -558,9 +559,7 @@ func Getcpu(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// third argument to this system call is nowadays unused.
if cpu != 0 {
- buf := t.CopyScratchBuffer(4)
- usermem.ByteOrder.PutUint32(buf, uint32(t.CPU()))
- if _, err := t.CopyOutBytes(cpu, buf); err != nil {
+ if _, err := primitive.CopyInt32Out(t, cpu, t.CPU()); err != nil {
return 0, nil, err
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index a2a24a027..c5054d2f1 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -168,7 +169,7 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(r), nil, nil
}
- if _, err := t.CopyOut(addr, r); err != nil {
+ if _, err := r.CopyOut(t, addr); err != nil {
return 0, nil, err
}
return uintptr(r), nil, nil
@@ -334,8 +335,8 @@ func Gettimeofday(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
// Ask the time package for the timezone.
_, offset := time.Now().Zone()
// This int32 array mimics linux's struct timezone.
- timezone := [2]int32{-int32(offset) / 60, 0}
- _, err := t.CopyOut(tz, timezone)
+ timezone := []int32{-int32(offset) / 60, 0}
+ _, err := primitive.CopyInt32SliceOut(t, tz, timezone)
return 0, nil, err
}
return 0, nil, nil
diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go
index a4c400f87..45eef4feb 100644
--- a/pkg/sentry/syscalls/linux/sys_timer.go
+++ b/pkg/sentry/syscalls/linux/sys_timer.go
@@ -21,81 +21,63 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
const nsecPerSec = int64(time.Second)
-// copyItimerValIn copies an ItimerVal from the untrusted app range to the
-// kernel. The ItimerVal may be either 32 or 64 bits.
-// A NULL address is allowed because because Linux allows
-// setitimer(which, NULL, &old_value) which disables the timer.
-// There is a KERN_WARN message saying this misfeature will be removed.
-// However, that hasn't happened as of 3.19, so we continue to support it.
-func copyItimerValIn(t *kernel.Task, addr usermem.Addr) (linux.ItimerVal, error) {
- if addr == usermem.Addr(0) {
- return linux.ItimerVal{}, nil
- }
-
- switch t.Arch().Width() {
- case 8:
- // Native size, just copy directly.
- var itv linux.ItimerVal
- if _, err := t.CopyIn(addr, &itv); err != nil {
- return linux.ItimerVal{}, err
- }
-
- return itv, nil
- default:
- return linux.ItimerVal{}, syserror.ENOSYS
- }
-}
-
-// copyItimerValOut copies an ItimerVal to the untrusted app range.
-// The ItimerVal may be either 32 or 64 bits.
-// A NULL address is allowed, in which case no copy takes place
-func copyItimerValOut(t *kernel.Task, addr usermem.Addr, itv *linux.ItimerVal) error {
- if addr == usermem.Addr(0) {
- return nil
- }
-
- switch t.Arch().Width() {
- case 8:
- // Native size, just copy directly.
- _, err := t.CopyOut(addr, itv)
- return err
- default:
- return syserror.ENOSYS
- }
-}
-
// Getitimer implements linux syscall getitimer(2).
func Getitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ if t.Arch().Width() != 8 {
+ // Definition of linux.ItimerVal assumes 64-bit architecture.
+ return 0, nil, syserror.ENOSYS
+ }
+
timerID := args[0].Int()
- val := args[1].Pointer()
+ addr := args[1].Pointer()
olditv, err := t.Getitimer(timerID)
if err != nil {
return 0, nil, err
}
- return 0, nil, copyItimerValOut(t, val, &olditv)
+ // A NULL address is allowed, in which case no copy out takes place.
+ if addr == 0 {
+ return 0, nil, nil
+ }
+ _, err = olditv.CopyOut(t, addr)
+ return 0, nil, err
}
// Setitimer implements linux syscall setitimer(2).
func Setitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- timerID := args[0].Int()
- newVal := args[1].Pointer()
- oldVal := args[2].Pointer()
+ if t.Arch().Width() != 8 {
+ // Definition of linux.ItimerVal assumes 64-bit architecture.
+ return 0, nil, syserror.ENOSYS
+ }
- newitv, err := copyItimerValIn(t, newVal)
- if err != nil {
- return 0, nil, err
+ timerID := args[0].Int()
+ newAddr := args[1].Pointer()
+ oldAddr := args[2].Pointer()
+
+ var newitv linux.ItimerVal
+ // A NULL address is allowed because because Linux allows
+ // setitimer(which, NULL, &old_value) which disables the timer. There is a
+ // KERN_WARN message saying this misfeature will be removed. However, that
+ // hasn't happened as of 3.19, so we continue to support it.
+ if newAddr != 0 {
+ if _, err := newitv.CopyIn(t, newAddr); err != nil {
+ return 0, nil, err
+ }
}
olditv, err := t.Setitimer(timerID, newitv)
if err != nil {
return 0, nil, err
}
- return 0, nil, copyItimerValOut(t, oldVal, &olditv)
+ // A NULL address is allowed, in which case no copy out takes place.
+ if oldAddr == 0 {
+ return 0, nil, nil
+ }
+ _, err = olditv.CopyOut(t, oldAddr)
+ return 0, nil, err
}
// Alarm implements linux syscall alarm(2).
@@ -131,7 +113,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
var sev *linux.Sigevent
if sevp != 0 {
sev = &linux.Sigevent{}
- if _, err = t.CopyIn(sevp, sev); err != nil {
+ if _, err = sev.CopyIn(t, sevp); err != nil {
return 0, nil, err
}
}
@@ -141,7 +123,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, err
}
- if _, err := t.CopyOut(timerIDp, &id); err != nil {
+ if _, err := id.CopyOut(t, timerIDp); err != nil {
t.IntervalTimerDelete(id)
return 0, nil, err
}
@@ -157,7 +139,7 @@ func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
oldValAddr := args[3].Pointer()
var newVal linux.Itimerspec
- if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ if _, err := newVal.CopyIn(t, newValAddr); err != nil {
return 0, nil, err
}
oldVal, err := t.IntervalTimerSettime(timerID, newVal, flags&linux.TIMER_ABSTIME != 0)
@@ -165,9 +147,8 @@ func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
return 0, nil, err
}
if oldValAddr != 0 {
- if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
- return 0, nil, err
- }
+ _, err = oldVal.CopyOut(t, oldValAddr)
+ return 0, nil, err
}
return 0, nil, nil
}
@@ -181,7 +162,7 @@ func TimerGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
if err != nil {
return 0, nil, err
}
- _, err = t.CopyOut(curValAddr, &curVal)
+ _, err = curVal.CopyOut(t, curValAddr)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go
index 34b03e4ee..cadd9d348 100644
--- a/pkg/sentry/syscalls/linux/sys_timerfd.go
+++ b/pkg/sentry/syscalls/linux/sys_timerfd.go
@@ -81,7 +81,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
}
var newVal linux.Itimerspec
- if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ if _, err := newVal.CopyIn(t, newValAddr); err != nil {
return 0, nil, err
}
newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tf.Clock())
@@ -91,7 +91,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tm, oldS := tf.SetTime(newS)
if oldValAddr != 0 {
oldVal := ktime.ItimerspecFromSetting(tm, oldS)
- if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+ if _, err := oldVal.CopyOut(t, oldValAddr); err != nil {
return 0, nil, err
}
}
@@ -116,6 +116,6 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tm, s := tf.GetTime()
curVal := ktime.ItimerspecFromSetting(tm, s)
- _, err := t.CopyOut(curValAddr, &curVal)
+ _, err := curVal.CopyOut(t, curValAddr)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_tls_amd64.go b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
index b3eb96a1c..6ddd30d5c 100644
--- a/pkg/sentry/syscalls/linux/sys_tls_amd64.go
+++ b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
@@ -18,6 +18,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -30,17 +31,19 @@ func ArchPrctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
case linux.ARCH_GET_FS:
addr := args[1].Pointer()
fsbase := t.Arch().TLS()
- _, err := t.CopyOut(addr, uint64(fsbase))
- if err != nil {
- return 0, nil, err
+ switch t.Arch().Width() {
+ case 8:
+ if _, err := primitive.CopyUint64Out(t, addr, uint64(fsbase)); err != nil {
+ return 0, nil, err
+ }
+ default:
+ return 0, nil, syserror.ENOSYS
}
-
case linux.ARCH_SET_FS:
fsbase := args[1].Uint64()
if !t.Arch().SetTLS(uintptr(fsbase)) {
return 0, nil, syserror.EPERM
}
-
case linux.ARCH_GET_GS, linux.ARCH_SET_GS:
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
index e9d702e8e..66c5974f5 100644
--- a/pkg/sentry/syscalls/linux/sys_utsname.go
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -46,7 +46,7 @@ func Uname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Copy out the result.
va := args[0].Pointer()
- _, err := t.CopyOut(va, u)
+ _, err := u.CopyOut(t, va)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 64696b438..9ee766552 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -44,6 +44,9 @@ go_library(
"//pkg/context",
"//pkg/fspath",
"//pkg/gohacks",
+ "//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fsbridge",
@@ -72,7 +75,5 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
index 42559bf69..6d0a38330 100644
--- a/pkg/sentry/syscalls/linux/vfs2/aio.go
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -17,6 +17,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -38,21 +39,27 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
for i := int32(0); i < nrEvents; i++ {
- // Copy in the address.
- cbAddrNative := t.Arch().Native(0)
- if _, err := t.CopyIn(addr, cbAddrNative); err != nil {
- if i > 0 {
- // Some successful.
- return uintptr(i), nil, nil
+ // Copy in the callback address.
+ var cbAddr usermem.Addr
+ switch t.Arch().Width() {
+ case 8:
+ var cbAddrP primitive.Uint64
+ if _, err := cbAddrP.CopyIn(t, addr); err != nil {
+ if i > 0 {
+ // Some successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
}
- // Nothing done.
- return 0, nil, err
+ cbAddr = usermem.Addr(cbAddrP)
+ default:
+ return 0, nil, syserror.ENOSYS
}
// Copy in this callback.
var cb linux.IOCallback
- cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative))
- if _, err := t.CopyIn(cbAddr, &cb); err != nil {
+ if _, err := cb.CopyIn(t, cbAddr); err != nil {
if i > 0 {
// Some have been successful.
return uintptr(i), nil, nil
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
index c62f03509..d0cbb77eb 100644
--- a/pkg/sentry/syscalls/linux/vfs2/epoll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -24,7 +24,6 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -141,50 +140,26 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, syserror.EINVAL
}
- // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent,
- // maxEvents), so that the buffer can be allocated on the stack.
+ // Allocate space for a few events on the stack for the common case in
+ // which we don't have too many events.
var (
- events [16]linux.EpollEvent
- total int
+ eventsArr [16]linux.EpollEvent
ch chan struct{}
haveDeadline bool
deadline ktime.Time
)
for {
- batchEvents := len(events)
- if batchEvents > maxEvents {
- batchEvents = maxEvents
- }
- n := ep.ReadEvents(events[:batchEvents])
- maxEvents -= n
- if n != 0 {
- // Copy what we read out.
- copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events[:n])
+ events := ep.ReadEvents(eventsArr[:0], maxEvents)
+ if len(events) != 0 {
+ copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events)
copiedEvents := copiedBytes / sizeofEpollEvent // rounded down
- eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent)
- total += copiedEvents
- if err != nil {
- if total != 0 {
- return uintptr(total), nil, nil
- }
- return 0, nil, err
- }
- // If we've filled the application's event buffer, we're done.
- if maxEvents == 0 {
- return uintptr(total), nil, nil
- }
- // Loop if we read a full batch, under the expectation that there
- // may be more events to read.
- if n == batchEvents {
- continue
+ if copiedEvents != 0 {
+ return uintptr(copiedEvents), nil, nil
}
+ return 0, nil, err
}
- // We get here if n != batchEvents. If we read any number of events
- // (just now, or in a previous iteration of this loop), or if timeout
- // is 0 (such that epoll_wait should be non-blocking), return the
- // events we've read so far to the application.
- if total != 0 || timeout == 0 {
- return uintptr(total), nil, nil
+ if timeout == 0 {
+ return 0, nil, nil
}
// In the first iteration of this loop, register with the epoll
// instance for readability events, but then immediately continue the
@@ -207,8 +182,6 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if err == syserror.ETIMEDOUT {
err = nil
}
- // total must be 0 since otherwise we would have returned
- // above.
return 0, nil, err
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 4856554fe..d8b8d9783 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -34,7 +34,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Note that Remove provides a reference on the file that we may use to
// flush. It is still active until we drop the final reference below
// (and other reference-holding operations complete).
- _, file := t.FDTable().Remove(fd)
+ _, file := t.FDTable().Remove(t, fd)
if file == nil {
return 0, nil, syserror.EBADF
}
@@ -137,7 +137,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(flags.ToLinuxFDFlags()), nil, nil
case linux.F_SETFD:
flags := args[2].Uint()
- err := t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
+ err := t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{
CloseOnExec: flags&linux.FD_CLOEXEC != 0,
})
return 0, nil, err
@@ -181,11 +181,11 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if !hasOwner {
return 0, nil, nil
}
- _, err := t.CopyOut(args[2].Pointer(), &owner)
+ _, err := owner.CopyOut(t, args[2].Pointer())
return 0, nil, err
case linux.F_SETOWN_EX:
var owner linux.FOwnerEx
- _, err := t.CopyIn(args[2].Pointer(), &owner)
+ _, err := owner.CopyIn(t, args[2].Pointer())
if err != nil {
return 0, nil, err
}
@@ -286,7 +286,7 @@ func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescrip
// Copy in the lock request.
flockAddr := args[2].Pointer()
var flock linux.Flock
- if _, err := t.CopyIn(flockAddr, &flock); err != nil {
+ if _, err := flock.CopyIn(t, flockAddr); err != nil {
return err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
index 38778a388..2806c3f6f 100644
--- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -34,20 +35,20 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Handle ioctls that apply to all FDs.
switch args[1].Int() {
case linux.FIONCLEX:
- t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
+ t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{
CloseOnExec: false,
})
return 0, nil, nil
case linux.FIOCLEX:
- t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
+ t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{
CloseOnExec: true,
})
return 0, nil, nil
case linux.FIONBIO:
var set int32
- if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
flags := file.StatusFlags()
@@ -60,7 +61,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FIOASYNC:
var set int32
- if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
flags := file.StatusFlags()
@@ -82,12 +83,12 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
who = owner.PID
}
}
- _, err := t.CopyOut(args[2].Pointer(), &who)
+ _, err := primitive.CopyInt32Out(t, args[2].Pointer(), who)
return 0, nil, err
case linux.FIOSETOWN, linux.SIOCSPGRP:
var who int32
- if _, err := t.CopyIn(args[2].Pointer(), &who); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &who); err != nil {
return 0, nil, err
}
ownerType := int32(linux.F_OWNER_PID)
diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go
index dc05c2994..9d9dbf775 100644
--- a/pkg/sentry/syscalls/linux/vfs2/mmap.go
+++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go
@@ -17,6 +17,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/syserror"
@@ -85,6 +86,17 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if err := file.ConfigureMMap(t, &opts); err != nil {
return 0, nil, err
}
+ } else if shared {
+ // Back shared anonymous mappings with an anonymous tmpfs file.
+ opts.Offset = 0
+ file, err := tmpfs.NewZeroFile(t, t.Credentials(), t.Kernel().ShmMount(), opts.Length)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef(t)
+ if err := file.ConfigureMMap(t, &opts); err != nil {
+ return 0, nil, err
+ }
}
rv, err := t.MemoryManager().MMap(t, opts)
diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go
index 9b4848d9e..ee38fdca0 100644
--- a/pkg/sentry/syscalls/linux/vfs2/pipe.go
+++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -51,9 +52,9 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error {
if err != nil {
return err
}
- if _, err := t.CopyOut(addr, fds); err != nil {
+ if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil {
for _, fd := range fds {
- if _, file := t.FDTable().Remove(fd); file != nil {
+ if _, file := t.FDTable().Remove(t, fd); file != nil {
file.DecRef(t)
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
index 79ad64039..c22e4ce54 100644
--- a/pkg/sentry/syscalls/linux/vfs2/poll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -165,7 +165,7 @@ func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD
pfd := make([]linux.PollFD, nfds)
if nfds > 0 {
- if _, err := t.CopyIn(addr, &pfd); err != nil {
+ if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil {
return nil, err
}
}
@@ -192,7 +192,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration)
// The poll entries are copied out regardless of whether
// any are set or not. This aligns with the Linux behavior.
if nfds > 0 && err == nil {
- if _, err := t.CopyOut(addr, pfd); err != nil {
+ if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil {
return remainingTimeout, 0, err
}
}
@@ -205,7 +205,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy
set := make([]byte, nBytes)
if addr != 0 {
- if _, err := t.CopyIn(addr, &set); err != nil {
+ if _, err := t.CopyInBytes(addr, set); err != nil {
return nil, err
}
// If we only use part of the last byte, mask out the extraneous bits.
@@ -332,19 +332,19 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
// Copy updated vectors back.
if readFDs != 0 {
- if _, err := t.CopyOut(readFDs, r); err != nil {
+ if _, err := t.CopyOutBytes(readFDs, r); err != nil {
return 0, err
}
}
if writeFDs != 0 {
- if _, err := t.CopyOut(writeFDs, w); err != nil {
+ if _, err := t.CopyOutBytes(writeFDs, w); err != nil {
return 0, err
}
}
if exceptFDs != 0 {
- if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ if _, err := t.CopyOutBytes(exceptFDs, e); err != nil {
return 0, err
}
}
@@ -497,6 +497,12 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return n, nil, err
}
+// +marshal
+type sigSetWithSize struct {
+ sigsetAddr uint64
+ sizeofSigset uint64
+}
+
// Pselect implements linux syscall pselect(2).
func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
nfds := int(args[0].Int()) // select(2) uses an int.
@@ -538,12 +544,6 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
return n, nil, err
}
-// +marshal
-type sigSetWithSize struct {
- sigsetAddr uint64
- sizeofSigset uint64
-}
-
// copyTimespecInToDuration copies a Timespec from the untrusted app range,
// validates it and converts it to a Duration.
//
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 5e6eb13ba..1ee37e5a8 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -346,7 +346,7 @@ func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr usermem.Addr, opt
return nil
}
var times [2]linux.Timeval
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil {
return err
}
if times[0].Usec < 0 || times[0].Usec > 999999 || times[1].Usec < 0 || times[1].Usec > 999999 {
@@ -410,7 +410,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, op
return nil
}
var times [2]linux.Timespec
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := linux.CopyTimespecSliceIn(t, timesAddr, times[:]); err != nil {
return err
}
if times[0].Nsec != linux.UTIME_OMIT {
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index a5032657a..bfae6b7e9 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -19,6 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -30,8 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// minListenBacklog is the minimum reasonable backlog for listening sockets.
@@ -66,10 +66,10 @@ const flagsOffset = 48
const sizeOfInt32 = 4
// messageHeader64Len is the length of a MessageHeader64 struct.
-var messageHeader64Len = uint64(binary.Size(MessageHeader64{}))
+var messageHeader64Len = uint64((*MessageHeader64)(nil).SizeBytes())
// multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct.
-var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{}))
+var multipleMessageHeader64Len = uint64((*multipleMessageHeader64)(nil).SizeBytes())
// baseRecvFlags are the flags that are accepted across recvmsg(2),
// recvmmsg(2), and recvfrom(2).
@@ -77,6 +77,8 @@ const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT |
// MessageHeader64 is the 64-bit representation of the msghdr struct used in
// the recvmsg and sendmsg syscalls.
+//
+// +marshal
type MessageHeader64 struct {
// Name is the optional pointer to a network address buffer.
Name uint64
@@ -105,30 +107,14 @@ type MessageHeader64 struct {
// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
// the recvmmsg and sendmmsg syscalls.
+//
+// +marshal
type multipleMessageHeader64 struct {
msgHdr MessageHeader64
msgLen uint32
_ int32
}
-// CopyInMessageHeader64 copies a message header from user to kernel memory.
-func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error {
- b := t.CopyScratchBuffer(52)
- if _, err := t.CopyInBytes(addr, b); err != nil {
- return err
- }
-
- msg.Name = usermem.ByteOrder.Uint64(b[0:])
- msg.NameLen = usermem.ByteOrder.Uint32(b[8:])
- msg.Iov = usermem.ByteOrder.Uint64(b[16:])
- msg.IovLen = usermem.ByteOrder.Uint64(b[24:])
- msg.Control = usermem.ByteOrder.Uint64(b[32:])
- msg.ControlLen = usermem.ByteOrder.Uint64(b[40:])
- msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:]))
-
- return nil
-}
-
// CaptureAddress allocates memory for and copies a socket address structure
// from the untrusted address space range.
func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
@@ -147,10 +133,10 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte,
// writeAddress writes a sockaddr structure and its length to an output buffer
// in the unstrusted address space range. If the address is bigger than the
// buffer, it is truncated.
-func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
// Get the buffer length.
var bufLen uint32
- if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil {
+ if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil {
return err
}
@@ -159,7 +145,7 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user
}
// Write the length unconditionally.
- if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil {
+ if _, err := primitive.CopyUint32Out(t, addrLenPtr, addrLen); err != nil {
return err
}
@@ -172,7 +158,8 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user
}
// Copy as much of the address as will fit in the buffer.
- encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr)
+ encodedAddr := t.CopyScratchBuffer(addr.SizeBytes())
+ addr.MarshalUnsafe(encodedAddr)
if bufLen > uint32(len(encodedAddr)) {
bufLen = uint32(len(encodedAddr))
}
@@ -250,9 +237,9 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, err
}
- if _, err := t.CopyOut(addr, fds); err != nil {
+ if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil {
for _, fd := range fds {
- if _, file := t.FDTable().Remove(fd); file != nil {
+ if _, file := t.FDTable().Remove(t, fd); file != nil {
file.DecRef(t)
}
}
@@ -459,8 +446,8 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
// Read the length. Reject negative values.
- optLen := int32(0)
- if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+ var optLen int32
+ if _, err := primitive.CopyInt32In(t, optLenAddr, &optLen); err != nil {
return 0, nil, err
}
if optLen < 0 {
@@ -474,7 +461,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
vLen := int32(binary.Size(v))
- if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
+ if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil {
return 0, nil, err
}
@@ -736,7 +723,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if !ok {
return 0, nil, syserror.EFAULT
}
- if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil {
break
}
count++
@@ -751,7 +738,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
// Capture the message header and io vectors.
var msg MessageHeader64
- if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ if _, err := msg.CopyIn(t, msgPtr); err != nil {
return 0, err
}
@@ -783,7 +770,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
if int(msg.Flags) != mflags {
// Copy out the flags to the caller.
- if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil {
return 0, err
}
}
@@ -820,17 +807,17 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
}
// Copy the control data to the caller.
- if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
+ if _, err := primitive.CopyUint64Out(t, msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
return 0, err
}
if len(controlData) > 0 {
- if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil {
+ if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
// Copy out the flags to the caller.
- if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil {
return 0, err
}
@@ -999,7 +986,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if !ok {
return 0, nil, syserror.EFAULT
}
- if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil {
break
}
count++
@@ -1014,7 +1001,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) {
// Capture the message header.
var msg MessageHeader64
- if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ if _, err := msg.CopyIn(t, msgPtr); err != nil {
return 0, err
}
@@ -1025,7 +1012,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio
return 0, syserror.ENOBUFS
}
controlData = make([]byte, msg.ControlLen)
- if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil {
+ if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 68ce94778..bf5c1171f 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -18,9 +18,12 @@ import (
"io"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -88,7 +91,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if inFile.Options().DenyPRead {
return 0, nil, syserror.EINVAL
}
- if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil {
+ if _, err := primitive.CopyInt64In(t, inOffsetPtr, &inOffset); err != nil {
return 0, nil, err
}
if inOffset < 0 {
@@ -103,7 +106,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if outFile.Options().DenyPWrite {
return 0, nil, syserror.EINVAL
}
- if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil {
+ if _, err := primitive.CopyInt64In(t, outOffsetPtr, &outOffset); err != nil {
return 0, nil, err
}
if outOffset < 0 {
@@ -144,11 +147,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
panic("at least one end of splice must be a pipe")
}
- if n == 0 && err == io.EOF {
- // We reached the end of the file. Eat the error and exit the loop.
- err = nil
- break
- }
if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
break
}
@@ -159,25 +157,26 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Copy updated offsets out.
if inOffsetPtr != 0 {
- if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil {
+ if _, err := primitive.CopyInt64Out(t, inOffsetPtr, inOffset); err != nil {
return 0, nil, err
}
}
if outOffsetPtr != 0 {
- if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil {
+ if _, err := primitive.CopyInt64Out(t, outOffsetPtr, outOffset); err != nil {
return 0, nil, err
}
}
- if n == 0 {
- return 0, nil, err
+ if n != 0 {
+ // On Linux, inotify behavior is not very consistent with splice(2). We try
+ // our best to emulate Linux for very basic calls to splice, where for some
+ // reason, events are generated for output files, but not input files.
+ outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
}
- // On Linux, inotify behavior is not very consistent with splice(2). We try
- // our best to emulate Linux for very basic calls to splice, where for some
- // reason, events are generated for output files, but not input files.
- outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
- return uintptr(n), nil, nil
+ // We can only pass a single file to handleIOError, so pick inFile arbitrarily.
+ // This is used only for debugging purposes.
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "splice", outFile)
}
// Tee implements Linux syscall tee(2).
@@ -249,11 +248,20 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
break
}
}
- if n == 0 {
- return 0, nil, err
+
+ if n != 0 {
+ outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
+
+ // If a partial write is completed, the error is dropped. Log it here.
+ if err != nil && err != io.EOF && err != syserror.ErrWouldBlock {
+ log.Debugf("tee completed a partial write with error: %v", err)
+ err = nil
+ }
}
- outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
- return uintptr(n), nil, nil
+
+ // We can only pass a single file to handleIOError, so pick inFile arbitrarily.
+ // This is used only for debugging purposes.
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "tee", inFile)
}
// Sendfile implements linux system call sendfile(2).
@@ -302,9 +310,12 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if inFile.Options().DenyPRead {
return 0, nil, syserror.ESPIPE
}
- if _, err := t.CopyIn(offsetAddr, &offset); err != nil {
+ var offsetP primitive.Int64
+ if _, err := offsetP.CopyIn(t, offsetAddr); err != nil {
return 0, nil, err
}
+ offset = int64(offsetP)
+
if offset < 0 {
return 0, nil, syserror.EINVAL
}
@@ -343,11 +354,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
for n < count {
var spliceN int64
spliceN, err = outPipeFD.SpliceFromNonPipe(t, inFile, offset, count)
- if spliceN == 0 && err == io.EOF {
- // We reached the end of the file. Eat the error and exit the loop.
- err = nil
- break
- }
if offset != -1 {
offset += spliceN
}
@@ -370,13 +376,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
} else {
readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
}
- if readN == 0 && err != nil {
- if err == io.EOF {
- // We reached the end of the file. Eat the error before exiting the loop.
- err = nil
- }
- break
- }
n += readN
// Write all of the bytes that we read. This may need
@@ -390,16 +389,21 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
err = dw.waitForOut(t)
}
if err != nil {
- // We didn't complete the write. Only
- // report the bytes that were actually
- // written, and rewind the offset.
+ // We didn't complete the write. Only report the bytes that were actually
+ // written, and rewind offsets as needed.
notWritten := int64(len(wbuf))
n -= notWritten
- if offset != -1 {
- // TODO(gvisor.dev/issue/3779): The inFile offset will be incorrect if we
- // roll back, because it has already been advanced by the full amount.
- // Merely seeking on inFile does not work, because there may be concurrent
- // file operations.
+ if offset == -1 {
+ // We modified the offset of the input file itself during the read
+ // operation. Rewind it.
+ if _, seekErr := inFile.Seek(t, -notWritten, linux.SEEK_CUR); seekErr != nil {
+ // Log the error but don't return it, since the write has already
+ // completed successfully.
+ log.Warningf("failed to roll back input file offset: %v", seekErr)
+ }
+ } else {
+ // The sendfile call was provided an offset parameter that should be
+ // adjusted to reflect the number of bytes sent. Rewind it.
offset -= notWritten
}
break
@@ -416,18 +420,26 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if offsetAddr != 0 {
// Copy out the new offset.
- if _, err := t.CopyOut(offsetAddr, offset); err != nil {
+ offsetP := primitive.Uint64(offset)
+ if _, err := offsetP.CopyOut(t, offsetAddr); err != nil {
return 0, nil, err
}
}
- if n == 0 {
- return 0, nil, err
+ if n != 0 {
+ inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
+ outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
+
+ if err != nil && err != io.EOF && err != syserror.ErrWouldBlock {
+ // If a partial write is completed, the error is dropped. Log it here.
+ log.Debugf("sendfile completed a partial write with error: %v", err)
+ err = nil
+ }
}
- inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
- outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
- return uintptr(n), nil, nil
+ // We can only pass a single file to handleIOError, so pick inFile arbitrarily.
+ // This is used only for debugging purposes.
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "sendfile", inFile)
}
// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not
diff --git a/pkg/sentry/syscalls/linux/vfs2/timerfd.go b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
index 7a26890ef..250870c03 100644
--- a/pkg/sentry/syscalls/linux/vfs2/timerfd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
@@ -87,7 +87,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
}
var newVal linux.Itimerspec
- if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ if _, err := newVal.CopyIn(t, newValAddr); err != nil {
return 0, nil, err
}
newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tfd.Clock())
@@ -97,7 +97,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tm, oldS := tfd.SetTime(newS)
if oldValAddr != 0 {
oldVal := ktime.ItimerspecFromSetting(tm, oldS)
- if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+ if _, err := oldVal.CopyOut(t, oldValAddr); err != nil {
return 0, nil, err
}
}
@@ -122,6 +122,6 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tm, s := tfd.GetTime()
curVal := ktime.ItimerspecFromSetting(tm, s)
- _, err := t.CopyOut(curValAddr, &curVal)
+ _, err := curVal.CopyOut(t, curValAddr)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index c576d9475..0df3bd449 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -93,16 +93,16 @@ func Override() {
s.Table[165] = syscalls.Supported("mount", Mount)
s.Table[166] = syscalls.Supported("umount2", Umount2)
s.Table[187] = syscalls.Supported("readahead", Readahead)
- s.Table[188] = syscalls.Supported("setxattr", Setxattr)
+ s.Table[188] = syscalls.Supported("setxattr", SetXattr)
s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr)
s.Table[190] = syscalls.Supported("fsetxattr", Fsetxattr)
- s.Table[191] = syscalls.Supported("getxattr", Getxattr)
+ s.Table[191] = syscalls.Supported("getxattr", GetXattr)
s.Table[192] = syscalls.Supported("lgetxattr", Lgetxattr)
s.Table[193] = syscalls.Supported("fgetxattr", Fgetxattr)
- s.Table[194] = syscalls.Supported("listxattr", Listxattr)
+ s.Table[194] = syscalls.Supported("listxattr", ListXattr)
s.Table[195] = syscalls.Supported("llistxattr", Llistxattr)
s.Table[196] = syscalls.Supported("flistxattr", Flistxattr)
- s.Table[197] = syscalls.Supported("removexattr", Removexattr)
+ s.Table[197] = syscalls.Supported("removexattr", RemoveXattr)
s.Table[198] = syscalls.Supported("lremovexattr", Lremovexattr)
s.Table[199] = syscalls.Supported("fremovexattr", Fremovexattr)
s.Table[209] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"})
@@ -163,16 +163,16 @@ func Override() {
// Override ARM64.
s = linux.ARM64
- s.Table[5] = syscalls.Supported("setxattr", Setxattr)
+ s.Table[5] = syscalls.Supported("setxattr", SetXattr)
s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr)
s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr)
- s.Table[8] = syscalls.Supported("getxattr", Getxattr)
+ s.Table[8] = syscalls.Supported("getxattr", GetXattr)
s.Table[9] = syscalls.Supported("lgetxattr", Lgetxattr)
s.Table[10] = syscalls.Supported("fgetxattr", Fgetxattr)
- s.Table[11] = syscalls.Supported("listxattr", Listxattr)
+ s.Table[11] = syscalls.Supported("listxattr", ListXattr)
s.Table[12] = syscalls.Supported("llistxattr", Llistxattr)
s.Table[13] = syscalls.Supported("flistxattr", Flistxattr)
- s.Table[14] = syscalls.Supported("removexattr", Removexattr)
+ s.Table[14] = syscalls.Supported("removexattr", RemoveXattr)
s.Table[15] = syscalls.Supported("lremovexattr", Lremovexattr)
s.Table[16] = syscalls.Supported("fremovexattr", Fremovexattr)
s.Table[17] = syscalls.Supported("getcwd", Getcwd)
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
index ef99246ed..e05723ef9 100644
--- a/pkg/sentry/syscalls/linux/vfs2/xattr.go
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -26,8 +26,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// Listxattr implements Linux syscall listxattr(2).
-func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+// ListXattr implements Linux syscall listxattr(2).
+func ListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return listxattr(t, args, followFinalSymlink)
}
@@ -51,7 +51,7 @@ func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSyml
}
defer tpop.Release(t)
- names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size))
+ names, err := t.Kernel().VFS().ListXattrAt(t, t.Credentials(), &tpop.pop, uint64(size))
if err != nil {
return 0, nil, err
}
@@ -74,7 +74,7 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
defer file.DecRef(t)
- names, err := file.Listxattr(t, uint64(size))
+ names, err := file.ListXattr(t, uint64(size))
if err != nil {
return 0, nil, err
}
@@ -85,8 +85,8 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return uintptr(n), nil, nil
}
-// Getxattr implements Linux syscall getxattr(2).
-func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+// GetXattr implements Linux syscall getxattr(2).
+func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return getxattr(t, args, followFinalSymlink)
}
@@ -116,7 +116,7 @@ func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli
return 0, nil, err
}
- value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetxattrOptions{
+ value, err := t.Kernel().VFS().GetXattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetXattrOptions{
Name: name,
Size: uint64(size),
})
@@ -148,7 +148,7 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, err
}
- value, err := file.Getxattr(t, &vfs.GetxattrOptions{Name: name, Size: uint64(size)})
+ value, err := file.GetXattr(t, &vfs.GetXattrOptions{Name: name, Size: uint64(size)})
if err != nil {
return 0, nil, err
}
@@ -159,8 +159,8 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return uintptr(n), nil, nil
}
-// Setxattr implements Linux syscall setxattr(2).
-func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+// SetXattr implements Linux syscall setxattr(2).
+func SetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return 0, nil, setxattr(t, args, followFinalSymlink)
}
@@ -199,7 +199,7 @@ func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli
return err
}
- return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{
+ return t.Kernel().VFS().SetXattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetXattrOptions{
Name: name,
Value: value,
Flags: uint32(flags),
@@ -233,15 +233,15 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, err
}
- return 0, nil, file.Setxattr(t, &vfs.SetxattrOptions{
+ return 0, nil, file.SetXattr(t, &vfs.SetXattrOptions{
Name: name,
Value: value,
Flags: uint32(flags),
})
}
-// Removexattr implements Linux syscall removexattr(2).
-func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+// RemoveXattr implements Linux syscall removexattr(2).
+func RemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return 0, nil, removexattr(t, args, followFinalSymlink)
}
@@ -269,7 +269,7 @@ func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSy
return err
}
- return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name)
+ return t.Kernel().VFS().RemoveXattrAt(t, t.Credentials(), &tpop.pop, name)
}
// Fremovexattr implements Linux syscall fremovexattr(2).
@@ -288,7 +288,7 @@ func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
return 0, nil, err
}
- return 0, nil, file.Removexattr(t, name)
+ return 0, nil, file.RemoveXattr(t, name)
}
func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index 5a0e3e6b5..bdfd3ca8f 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -52,6 +52,8 @@ const (
)
// anonFilesystemType implements FilesystemType.
+//
+// +stateify savable
type anonFilesystemType struct{}
// GetFilesystem implements FilesystemType.GetFilesystem.
@@ -69,12 +71,15 @@ func (anonFilesystemType) Name() string {
//
// Since all Dentries in anonFilesystem are non-directories, all FilesystemImpl
// methods that would require an anonDentry to be a directory return ENOTDIR.
+//
+// +stateify savable
type anonFilesystem struct {
vfsfs Filesystem
devMinor uint32
}
+// +stateify savable
type anonDentry struct {
vfsd Dentry
@@ -245,32 +250,32 @@ func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath
return nil, syserror.ECONNREFUSED
}
-// ListxattrAt implements FilesystemImpl.ListxattrAt.
-func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) {
+// ListXattrAt implements FilesystemImpl.ListXattrAt.
+func (fs *anonFilesystem) ListXattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) {
if !rp.Done() {
return nil, syserror.ENOTDIR
}
return nil, nil
}
-// GetxattrAt implements FilesystemImpl.GetxattrAt.
-func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) {
+// GetXattrAt implements FilesystemImpl.GetXattrAt.
+func (fs *anonFilesystem) GetXattrAt(ctx context.Context, rp *ResolvingPath, opts GetXattrOptions) (string, error) {
if !rp.Done() {
return "", syserror.ENOTDIR
}
return "", syserror.ENOTSUP
}
-// SetxattrAt implements FilesystemImpl.SetxattrAt.
-func (fs *anonFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error {
+// SetXattrAt implements FilesystemImpl.SetXattrAt.
+func (fs *anonFilesystem) SetXattrAt(ctx context.Context, rp *ResolvingPath, opts SetXattrOptions) error {
if !rp.Done() {
return syserror.ENOTDIR
}
return syserror.EPERM
}
-// RemovexattrAt implements FilesystemImpl.RemovexattrAt.
-func (fs *anonFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error {
+// RemoveXattrAt implements FilesystemImpl.RemoveXattrAt.
+func (fs *anonFilesystem) RemoveXattrAt(ctx context.Context, rp *ResolvingPath, name string) error {
if !rp.Done() {
return syserror.ENOTDIR
}
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index a69a5b2f1..320ab7ce1 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -89,6 +89,8 @@ func (d *Dentry) Impl() DentryImpl {
// DentryImpl contains implementation details for a Dentry. Implementations of
// DentryImpl should contain their associated Dentry by value as their first
// field.
+//
+// +stateify savable
type DentryImpl interface {
// IncRef increments the Dentry's reference count. A Dentry with a non-zero
// reference count must remain coherent with the state of the filesystem.
diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go
index 1e9dffc8f..dde2ad79b 100644
--- a/pkg/sentry/vfs/device.go
+++ b/pkg/sentry/vfs/device.go
@@ -22,6 +22,8 @@ import (
)
// DeviceKind indicates whether a device is a block or character device.
+//
+// +stateify savable
type DeviceKind uint32
const (
@@ -44,6 +46,7 @@ func (kind DeviceKind) String() string {
}
}
+// +stateify savable
type devTuple struct {
kind DeviceKind
major uint32
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 1b5af9f73..8f36c3e3b 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -27,6 +27,8 @@ import (
var epollCycleMu sync.Mutex
// EpollInstance represents an epoll instance, as described by epoll(7).
+//
+// +stateify savable
type EpollInstance struct {
vfsfd FileDescription
FileDescriptionDefaultImpl
@@ -38,11 +40,11 @@ type EpollInstance struct {
// interest is the set of file descriptors that are registered with the
// EpollInstance for monitoring. interest is protected by interestMu.
- interestMu sync.Mutex
+ interestMu sync.Mutex `state:"nosave"`
interest map[epollInterestKey]*epollInterest
// mu protects fields in registered epollInterests.
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
// ready is the set of file descriptors that may be "ready" for I/O. Note
// that this must be an ordered list, not a map: "If more than maxevents
@@ -55,6 +57,7 @@ type EpollInstance struct {
ready epollInterestList
}
+// +stateify savable
type epollInterestKey struct {
// file is the registered FileDescription. No reference is held on file;
// instead, when the last reference is dropped, FileDescription.DecRef()
@@ -67,6 +70,8 @@ type epollInterestKey struct {
}
// epollInterest represents an EpollInstance's interest in a file descriptor.
+//
+// +stateify savable
type epollInterest struct {
// epoll is the owning EpollInstance. epoll is immutable.
epoll *EpollInstance
@@ -331,11 +336,9 @@ func (ep *EpollInstance) removeLocked(epi *epollInterest) {
ep.mu.Unlock()
}
-// ReadEvents reads up to len(events) ready events into events and returns the
-// number of events read.
-//
-// Preconditions: len(events) != 0.
-func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int {
+// ReadEvents appends up to maxReady events to events and returns the updated
+// slice of events.
+func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent, maxEvents int) []linux.EpollEvent {
i := 0
// Hot path: avoid defer.
ep.mu.Lock()
@@ -368,16 +371,16 @@ func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int {
requeue.PushBack(epi)
}
// Report ievents.
- events[i] = linux.EpollEvent{
+ events = append(events, linux.EpollEvent{
Events: ievents.ToLinux(),
Data: epi.userData,
- }
+ })
i++
- if i == len(events) {
+ if i == maxEvents {
break
}
}
ep.ready.PushBackList(&requeue)
ep.mu.Unlock()
- return i
+ return events
}
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 22a54fa48..1eba0270f 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -37,11 +37,13 @@ import (
// FileDescription methods require that a reference is held.
//
// FileDescription is analogous to Linux's struct file.
+//
+// +stateify savable
type FileDescription struct {
FileDescriptionRefs
// flagsMu protects statusFlags and asyncHandler below.
- flagsMu sync.Mutex
+ flagsMu sync.Mutex `state:"nosave"`
// statusFlags contains status flags, "initialized by open(2) and possibly
// modified by fcntl()" - fcntl(2). statusFlags can be read using atomic
@@ -56,7 +58,7 @@ type FileDescription struct {
// epolls is the set of epollInterests registered for this FileDescription.
// epolls is protected by epollMu.
- epollMu sync.Mutex
+ epollMu sync.Mutex `state:"nosave"`
epolls map[*epollInterest]struct{}
// vd is the filesystem location at which this FileDescription was opened.
@@ -88,6 +90,8 @@ type FileDescription struct {
}
// FileDescriptionOptions contains options to FileDescription.Init().
+//
+// +stateify savable
type FileDescriptionOptions struct {
// If AllowDirectIO is true, allow O_DIRECT to be set on the file.
AllowDirectIO bool
@@ -101,7 +105,7 @@ type FileDescriptionOptions struct {
// If UseDentryMetadata is true, calls to FileDescription methods that
// interact with file and filesystem metadata (Stat, SetStat, StatFS,
- // Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling
+ // ListXattr, GetXattr, SetXattr, RemoveXattr) are implemented by calling
// the corresponding FilesystemImpl methods instead of the corresponding
// FileDescriptionImpl methods.
//
@@ -326,6 +330,9 @@ type FileDescriptionImpl interface {
// Allocate grows the file to offset + length bytes.
// Only mode == 0 is supported currently.
//
+ // Allocate should return EISDIR on directories, ESPIPE on pipes, and ENODEV on
+ // other files where it is not supported.
+ //
// Preconditions: The FileDescription was opened for writing.
Allocate(ctx context.Context, mode, offset, length uint64) error
@@ -420,19 +427,19 @@ type FileDescriptionImpl interface {
// Ioctl implements the ioctl(2) syscall.
Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error)
- // Listxattr returns all extended attribute names for the file.
- Listxattr(ctx context.Context, size uint64) ([]string, error)
+ // ListXattr returns all extended attribute names for the file.
+ ListXattr(ctx context.Context, size uint64) ([]string, error)
- // Getxattr returns the value associated with the given extended attribute
+ // GetXattr returns the value associated with the given extended attribute
// for the file.
- Getxattr(ctx context.Context, opts GetxattrOptions) (string, error)
+ GetXattr(ctx context.Context, opts GetXattrOptions) (string, error)
- // Setxattr changes the value associated with the given extended attribute
+ // SetXattr changes the value associated with the given extended attribute
// for the file.
- Setxattr(ctx context.Context, opts SetxattrOptions) error
+ SetXattr(ctx context.Context, opts SetXattrOptions) error
- // Removexattr removes the given extended attribute from the file.
- Removexattr(ctx context.Context, name string) error
+ // RemoveXattr removes the given extended attribute from the file.
+ RemoveXattr(ctx context.Context, name string) error
// LockBSD tries to acquire a BSD-style advisory file lock.
LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error
@@ -448,6 +455,8 @@ type FileDescriptionImpl interface {
}
// Dirent holds the information contained in struct linux_dirent64.
+//
+// +stateify savable
type Dirent struct {
// Name is the filename.
Name string
@@ -635,25 +644,25 @@ func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
return fd.impl.Ioctl(ctx, uio, args)
}
-// Listxattr returns all extended attribute names for the file represented by
+// ListXattr returns all extended attribute names for the file represented by
// fd.
//
// If the size of the list (including a NUL terminating byte after every entry)
// would exceed size, ERANGE may be returned. Note that implementations
// are free to ignore size entirely and return without error). In all cases,
// if size is 0, the list should be returned without error, regardless of size.
-func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) {
if fd.opts.UseDentryMetadata {
vfsObj := fd.vd.mount.vfs
rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
Root: fd.vd,
Start: fd.vd,
})
- names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size)
+ names, err := fd.vd.mount.fs.impl.ListXattrAt(ctx, rp, size)
vfsObj.putResolvingPath(ctx, rp)
return names, err
}
- names, err := fd.impl.Listxattr(ctx, size)
+ names, err := fd.impl.ListXattr(ctx, size)
if err == syserror.ENOTSUP {
// Linux doesn't actually return ENOTSUP in this case; instead,
// fs/xattr.c:vfs_listxattr() falls back to allowing the security
@@ -664,57 +673,57 @@ func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string
return names, err
}
-// Getxattr returns the value associated with the given extended attribute for
+// GetXattr returns the value associated with the given extended attribute for
// the file represented by fd.
//
// If the size of the return value exceeds opts.Size, ERANGE may be returned
// (note that implementations are free to ignore opts.Size entirely and return
// without error). In all cases, if opts.Size is 0, the value should be
// returned without error, regardless of size.
-func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions) (string, error) {
+func (fd *FileDescription) GetXattr(ctx context.Context, opts *GetXattrOptions) (string, error) {
if fd.opts.UseDentryMetadata {
vfsObj := fd.vd.mount.vfs
rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
Root: fd.vd,
Start: fd.vd,
})
- val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
+ val, err := fd.vd.mount.fs.impl.GetXattrAt(ctx, rp, *opts)
vfsObj.putResolvingPath(ctx, rp)
return val, err
}
- return fd.impl.Getxattr(ctx, *opts)
+ return fd.impl.GetXattr(ctx, *opts)
}
-// Setxattr changes the value associated with the given extended attribute for
+// SetXattr changes the value associated with the given extended attribute for
// the file represented by fd.
-func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions) error {
+func (fd *FileDescription) SetXattr(ctx context.Context, opts *SetXattrOptions) error {
if fd.opts.UseDentryMetadata {
vfsObj := fd.vd.mount.vfs
rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
Root: fd.vd,
Start: fd.vd,
})
- err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
+ err := fd.vd.mount.fs.impl.SetXattrAt(ctx, rp, *opts)
vfsObj.putResolvingPath(ctx, rp)
return err
}
- return fd.impl.Setxattr(ctx, *opts)
+ return fd.impl.SetXattr(ctx, *opts)
}
-// Removexattr removes the given extended attribute from the file represented
+// RemoveXattr removes the given extended attribute from the file represented
// by fd.
-func (fd *FileDescription) Removexattr(ctx context.Context, name string) error {
+func (fd *FileDescription) RemoveXattr(ctx context.Context, name string) error {
if fd.opts.UseDentryMetadata {
vfsObj := fd.vd.mount.vfs
rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
Root: fd.vd,
Start: fd.vd,
})
- err := fd.vd.mount.fs.impl.RemovexattrAt(ctx, rp, name)
+ err := fd.vd.mount.fs.impl.RemoveXattrAt(ctx, rp, name)
vfsObj.putResolvingPath(ctx, rp)
return err
}
- return fd.impl.Removexattr(ctx, name)
+ return fd.impl.RemoveXattr(ctx, name)
}
// SyncFS instructs the filesystem containing fd to execute the semantics of
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 6b8b4ad49..48ca9de44 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -42,6 +42,8 @@ import (
// FileDescriptionDefaultImpl may be embedded by implementations of
// FileDescriptionImpl to obtain implementations of many FileDescriptionImpl
// methods with default behavior analogous to Linux's.
+//
+// +stateify savable
type FileDescriptionDefaultImpl struct{}
// OnClose implements FileDescriptionImpl.OnClose analogously to
@@ -57,7 +59,11 @@ func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, err
}
// Allocate implements FileDescriptionImpl.Allocate analogously to
-// fallocate called on regular file, directory or FIFO in Linux.
+// fallocate called on an invalid type of file in Linux.
+//
+// Note that directories can rely on this implementation even though they
+// should technically return EISDIR. Allocate should never be called for a
+// directory, because it requires a writable fd.
func (FileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error {
return syserror.ENODEV
}
@@ -134,34 +140,36 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg
return 0, syserror.ENOTTY
}
-// Listxattr implements FileDescriptionImpl.Listxattr analogously to
+// ListXattr implements FileDescriptionImpl.ListXattr analogously to
// inode_operations::listxattr == NULL in Linux.
-func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context, size uint64) ([]string, error) {
- // This isn't exactly accurate; see FileDescription.Listxattr.
+func (FileDescriptionDefaultImpl) ListXattr(ctx context.Context, size uint64) ([]string, error) {
+ // This isn't exactly accurate; see FileDescription.ListXattr.
return nil, syserror.ENOTSUP
}
-// Getxattr implements FileDescriptionImpl.Getxattr analogously to
+// GetXattr implements FileDescriptionImpl.GetXattr analogously to
// inode::i_opflags & IOP_XATTR == 0 in Linux.
-func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) {
+func (FileDescriptionDefaultImpl) GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) {
return "", syserror.ENOTSUP
}
-// Setxattr implements FileDescriptionImpl.Setxattr analogously to
+// SetXattr implements FileDescriptionImpl.SetXattr analogously to
// inode::i_opflags & IOP_XATTR == 0 in Linux.
-func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error {
+func (FileDescriptionDefaultImpl) SetXattr(ctx context.Context, opts SetXattrOptions) error {
return syserror.ENOTSUP
}
-// Removexattr implements FileDescriptionImpl.Removexattr analogously to
+// RemoveXattr implements FileDescriptionImpl.RemoveXattr analogously to
// inode::i_opflags & IOP_XATTR == 0 in Linux.
-func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error {
+func (FileDescriptionDefaultImpl) RemoveXattr(ctx context.Context, name string) error {
return syserror.ENOTSUP
}
// DirectoryFileDescriptionDefaultImpl may be embedded by implementations of
// FileDescriptionImpl that always represent directories to obtain
// implementations of non-directory I/O methods that return EISDIR.
+//
+// +stateify savable
type DirectoryFileDescriptionDefaultImpl struct{}
// Allocate implements DirectoryFileDescriptionDefaultImpl.Allocate.
@@ -192,6 +200,8 @@ func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src userme
// DentryMetadataFileDescriptionImpl may be embedded by implementations of
// FileDescriptionImpl for which FileDescriptionOptions.UseDentryMetadata is
// true to obtain implementations of Stat and SetStat that panic.
+//
+// +stateify savable
type DentryMetadataFileDescriptionImpl struct{}
// Stat implements FileDescriptionImpl.Stat.
@@ -206,12 +216,16 @@ func (DentryMetadataFileDescriptionImpl) SetStat(ctx context.Context, opts SetSt
// DynamicBytesSource represents a data source for a
// DynamicBytesFileDescriptionImpl.
+//
+// +stateify savable
type DynamicBytesSource interface {
// Generate writes the file's contents to buf.
Generate(ctx context.Context, buf *bytes.Buffer) error
}
// StaticData implements DynamicBytesSource over a static string.
+//
+// +stateify savable
type StaticData struct {
Data string
}
@@ -238,14 +252,24 @@ type WritableDynamicBytesSource interface {
//
// DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first
// use.
+//
+// +stateify savable
type DynamicBytesFileDescriptionImpl struct {
data DynamicBytesSource // immutable
- mu sync.Mutex // protects the following fields
- buf bytes.Buffer
+ mu sync.Mutex `state:"nosave"` // protects the following fields
+ buf bytes.Buffer `state:".([]byte)"`
off int64
lastRead int64 // offset at which the last Read, PRead, or Seek ended
}
+func (fd *DynamicBytesFileDescriptionImpl) saveBuf() []byte {
+ return fd.buf.Bytes()
+}
+
+func (fd *DynamicBytesFileDescriptionImpl) loadBuf(p []byte) {
+ fd.buf.Write(p)
+}
+
// SetDataSource must be called exactly once on fd before first use.
func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) {
fd.data = data
@@ -378,6 +402,8 @@ func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.M
// LockFD may be used by most implementations of FileDescriptionImpl.Lock*
// functions. Caller must call Init().
+//
+// +stateify savable
type LockFD struct {
locks *FileLocks
}
@@ -405,6 +431,8 @@ func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
// NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface
// returning ENOLCK.
+//
+// +stateify savable
type NoLockFD struct{}
// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 46851f638..c93d94634 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -416,26 +416,26 @@ type FilesystemImpl interface {
// ResolvingPath.Resolve*(), then !rp.Done().
UnlinkAt(ctx context.Context, rp *ResolvingPath) error
- // ListxattrAt returns all extended attribute names for the file at rp.
+ // ListXattrAt returns all extended attribute names for the file at rp.
//
// Errors:
//
// - If extended attributes are not supported by the filesystem,
- // ListxattrAt returns ENOTSUP.
+ // ListXattrAt returns ENOTSUP.
//
// - If the size of the list (including a NUL terminating byte after every
// entry) would exceed size, ERANGE may be returned. Note that
// implementations are free to ignore size entirely and return without
// error). In all cases, if size is 0, the list should be returned without
// error, regardless of size.
- ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error)
+ ListXattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error)
- // GetxattrAt returns the value associated with the given extended
+ // GetXattrAt returns the value associated with the given extended
// attribute for the file at rp.
//
// Errors:
//
- // - If extended attributes are not supported by the filesystem, GetxattrAt
+ // - If extended attributes are not supported by the filesystem, GetXattrAt
// returns ENOTSUP.
//
// - If an extended attribute named opts.Name does not exist, ENODATA is
@@ -445,30 +445,30 @@ type FilesystemImpl interface {
// returned (note that implementations are free to ignore opts.Size entirely
// and return without error). In all cases, if opts.Size is 0, the value
// should be returned without error, regardless of size.
- GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error)
+ GetXattrAt(ctx context.Context, rp *ResolvingPath, opts GetXattrOptions) (string, error)
- // SetxattrAt changes the value associated with the given extended
+ // SetXattrAt changes the value associated with the given extended
// attribute for the file at rp.
//
// Errors:
//
- // - If extended attributes are not supported by the filesystem, SetxattrAt
+ // - If extended attributes are not supported by the filesystem, SetXattrAt
// returns ENOTSUP.
//
// - If XATTR_CREATE is set in opts.Flag and opts.Name already exists,
// EEXIST is returned. If XATTR_REPLACE is set and opts.Name does not exist,
// ENODATA is returned.
- SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error
+ SetXattrAt(ctx context.Context, rp *ResolvingPath, opts SetXattrOptions) error
- // RemovexattrAt removes the given extended attribute from the file at rp.
+ // RemoveXattrAt removes the given extended attribute from the file at rp.
//
// Errors:
//
// - If extended attributes are not supported by the filesystem,
- // RemovexattrAt returns ENOTSUP.
+ // RemoveXattrAt returns ENOTSUP.
//
// - If name does not exist, ENODATA is returned.
- RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error
+ RemoveXattrAt(ctx context.Context, rp *ResolvingPath, name string) error
// BoundEndpointAt returns the Unix socket endpoint bound at the path rp.
//
@@ -506,6 +506,8 @@ type FilesystemImpl interface {
// PrependPathAtVFSRootError is returned by implementations of
// FilesystemImpl.PrependPath() when they encounter the contextual VFS root.
+//
+// +stateify savable
type PrependPathAtVFSRootError struct{}
// Error implements error.Error.
@@ -516,6 +518,8 @@ func (PrependPathAtVFSRootError) Error() string {
// PrependPathAtNonMountRootError is returned by implementations of
// FilesystemImpl.PrependPath() when they encounter an independent ancestor
// Dentry that is not the Mount root.
+//
+// +stateify savable
type PrependPathAtNonMountRootError struct{}
// Error implements error.Error.
@@ -526,6 +530,8 @@ func (PrependPathAtNonMountRootError) Error() string {
// PrependPathSyntheticError is returned by implementations of
// FilesystemImpl.PrependPath() for which prepended names do not represent real
// paths.
+//
+// +stateify savable
type PrependPathSyntheticError struct{}
// Error implements error.Error.
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
index f2298f7f6..bc19db1d5 100644
--- a/pkg/sentry/vfs/filesystem_type.go
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -55,10 +55,13 @@ type registeredFilesystemType struct {
// RegisterFilesystemTypeOptions contains options to
// VirtualFilesystem.RegisterFilesystem().
+//
+// +stateify savable
type RegisterFilesystemTypeOptions struct {
- // If AllowUserMount is true, allow calls to VirtualFilesystem.MountAt()
- // for which MountOptions.InternalMount == false to use this filesystem
- // type.
+ // AllowUserMount determines whether users are allowed to mount a file system
+ // of this type, i.e. through mount(2). If AllowUserMount is true, allow calls
+ // to VirtualFilesystem.MountAt() for which MountOptions.InternalMount == false
+ // to use this filesystem type.
AllowUserMount bool
// If AllowUserList is true, make this filesystem type visible in
diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go
index 8882fa84a..2d27d9d35 100644
--- a/pkg/sentry/vfs/genericfstree/genericfstree.go
+++ b/pkg/sentry/vfs/genericfstree/genericfstree.go
@@ -27,6 +27,8 @@ import (
)
// Dentry is a required type parameter that is a struct with the given fields.
+//
+// +stateify savable
type Dentry struct {
// vfsd is the embedded vfs.Dentry corresponding to this vfs.DentryImpl.
vfsd vfs.Dentry
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
index aff220a61..3f0b8f45b 100644
--- a/pkg/sentry/vfs/inotify.go
+++ b/pkg/sentry/vfs/inotify.go
@@ -37,6 +37,8 @@ const inotifyEventBaseSize = 16
//
// The way events are labelled appears somewhat arbitrary, but they must match
// Linux so that IN_EXCL_UNLINK behaves as it does in Linux.
+//
+// +stateify savable
type EventType uint8
// PathEvent and InodeEvent correspond to FSNOTIFY_EVENT_PATH and
diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go
index 42666eebf..55783d4eb 100644
--- a/pkg/sentry/vfs/lock.go
+++ b/pkg/sentry/vfs/lock.go
@@ -33,6 +33,8 @@ import (
// Note that in Linux these two types of locks are _not_ cooperative, because
// race and deadlock conditions make merging them prohibitive. We do the same
// and keep them oblivious to each other.
+//
+// +stateify savable
type FileLocks struct {
// bsd is a set of BSD-style advisory file wide locks, see flock(2).
bsd fslock.Locks
diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go
index cc1e7d764..638b5d830 100644
--- a/pkg/sentry/vfs/memxattr/xattr.go
+++ b/pkg/sentry/vfs/memxattr/xattr.go
@@ -33,8 +33,8 @@ type SimpleExtendedAttributes struct {
xattrs map[string]string
}
-// Getxattr returns the value at 'name'.
-func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, error) {
+// GetXattr returns the value at 'name'.
+func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string, error) {
x.mu.RLock()
value, ok := x.xattrs[opts.Name]
x.mu.RUnlock()
@@ -49,8 +49,8 @@ func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string,
return value, nil
}
-// Setxattr sets 'value' at 'name'.
-func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error {
+// SetXattr sets 'value' at 'name'.
+func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error {
x.mu.Lock()
defer x.mu.Unlock()
if x.xattrs == nil {
@@ -72,8 +72,8 @@ func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error {
return nil
}
-// Listxattr returns all names in xattrs.
-func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) {
+// ListXattr returns all names in xattrs.
+func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) {
// Keep track of the size of the buffer needed in listxattr(2) for the list.
listSize := 0
x.mu.RLock()
@@ -90,8 +90,8 @@ func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) {
return names, nil
}
-// Removexattr removes the xattr at 'name'.
-func (x *SimpleExtendedAttributes) Removexattr(name string) error {
+// RemoveXattr removes the xattr at 'name'.
+func (x *SimpleExtendedAttributes) RemoveXattr(name string) error {
x.mu.Lock()
defer x.mu.Unlock()
if _, ok := x.xattrs[name]; !ok {
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index db5fb3bb1..dfc3ae6c0 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -18,14 +18,12 @@ import (
"bytes"
"fmt"
"math"
- "path"
"sort"
"strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -67,7 +65,7 @@ type Mount struct {
//
// Invariant: key.parent != nil iff key.point != nil. key.point belongs to
// key.parent.fs.
- key mountKey
+ key mountKey `state:".(VirtualDentry)"`
// ns is the namespace in which this Mount was mounted. ns is protected by
// VirtualFilesystem.mountMu.
@@ -154,13 +152,13 @@ type MountNamespace struct {
// NewMountNamespace returns a new mount namespace with a root filesystem
// configured by the given arguments. A reference is taken on the returned
// MountNamespace.
-func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) {
+func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *MountOptions) (*MountNamespace, error) {
rft := vfs.getFilesystemType(fsTypeName)
if rft == nil {
ctx.Warningf("Unknown filesystem type: %s", fsTypeName)
return nil, syserror.ENODEV
}
- fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
+ fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions)
if err != nil {
return nil, err
}
@@ -169,7 +167,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
mountpoints: make(map[*Dentry]uint32),
}
mntns.EnableLeakCheck()
- mntns.root = newMount(vfs, fs, root, mntns, &MountOptions{})
+ mntns.root = newMount(vfs, fs, root, mntns, opts)
return mntns, nil
}
@@ -347,6 +345,7 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti
return nil
}
+// +stateify savable
type umountRecursiveOptions struct {
// If eager is true, ensure that future calls to Mount.tryIncMountedRef()
// on umounted mounts fail.
@@ -416,7 +415,7 @@ func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns
}
}
mnt.IncRef() // dropped by callers of umountRecursiveLocked
- mnt.storeKey(vd)
+ mnt.setKey(vd)
if vd.mount.children == nil {
vd.mount.children = make(map[*Mount]struct{})
}
@@ -441,13 +440,13 @@ func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns
// * vfs.mounts.seq must be in a writer critical section.
// * mnt.parent() != nil.
func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry {
- vd := mnt.loadKey()
+ vd := mnt.getKey()
if checkInvariants {
if vd.mount != nil {
panic("VFS.disconnectLocked called on disconnected mount")
}
}
- mnt.storeKey(VirtualDentry{})
+ mnt.loadKey(VirtualDentry{})
delete(vd.mount.children, mnt)
atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1
mnt.ns.mountpoints[vd.dentry]--
@@ -740,11 +739,23 @@ func (mntns *MountNamespace) Root() VirtualDentry {
//
// Preconditions: taskRootDir.Ok().
func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) {
- vfs.mountMu.Lock()
- defer vfs.mountMu.Unlock()
rootMnt := taskRootDir.mount
+
+ vfs.mountMu.Lock()
mounts := rootMnt.submountsLocked()
+ // Take a reference on mounts since we need to drop vfs.mountMu before
+ // calling vfs.PathnameReachable() (=> FilesystemImpl.PrependPath()).
+ for _, mnt := range mounts {
+ mnt.IncRef()
+ }
+ vfs.mountMu.Unlock()
+ defer func() {
+ for _, mnt := range mounts {
+ mnt.DecRef(ctx)
+ }
+ }()
sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID })
+
for _, mnt := range mounts {
// Get the path to this mount relative to task root.
mntRootVD := VirtualDentry{
@@ -755,7 +766,7 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi
if err != nil {
// For some reason we didn't get a path. Log a warning
// and run with empty path.
- ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err)
+ ctx.Warningf("VFS.GenerateProcMounts: error getting pathname for mount root %+v: %v", mnt.root, err)
path = ""
}
if path == "" {
@@ -789,11 +800,25 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi
//
// Preconditions: taskRootDir.Ok().
func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) {
- vfs.mountMu.Lock()
- defer vfs.mountMu.Unlock()
rootMnt := taskRootDir.mount
+
+ vfs.mountMu.Lock()
mounts := rootMnt.submountsLocked()
+ // Take a reference on mounts since we need to drop vfs.mountMu before
+ // calling vfs.PathnameReachable() (=> FilesystemImpl.PrependPath()) or
+ // vfs.StatAt() (=> FilesystemImpl.StatAt()).
+ for _, mnt := range mounts {
+ mnt.IncRef()
+ }
+ vfs.mountMu.Unlock()
+ defer func() {
+ for _, mnt := range mounts {
+ mnt.DecRef(ctx)
+ }
+ }()
sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID })
+
+ creds := auth.CredentialsFromContext(ctx)
for _, mnt := range mounts {
// Get the path to this mount relative to task root.
mntRootVD := VirtualDentry{
@@ -804,7 +829,7 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo
if err != nil {
// For some reason we didn't get a path. Log a warning
// and run with empty path.
- ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err)
+ ctx.Warningf("VFS.GenerateProcMountInfo: error getting pathname for mount root %+v: %v", mnt.root, err)
path = ""
}
if path == "" {
@@ -817,9 +842,10 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo
Root: mntRootVD,
Start: mntRootVD,
}
- statx, err := vfs.StatAt(ctx, auth.NewAnonymousCredentials(), pop, &StatOptions{})
+ statx, err := vfs.StatAt(ctx, creds, pop, &StatOptions{})
if err != nil {
// Well that's not good. Ignore this mount.
+ ctx.Warningf("VFS.GenerateProcMountInfo: failed to stat mount root %+v: %v", mnt.root, err)
break
}
@@ -831,6 +857,9 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo
fmt.Fprintf(buf, "%d ", mnt.ID)
// (2) Parent ID (or this ID if there is no parent).
+ // Note that even if the call to mnt.parent() races with Mount
+ // destruction (which is possible since we're not holding vfs.mountMu),
+ // its Mount.ID will still be valid.
pID := mnt.ID
if p := mnt.parent(); p != nil {
pID = p.ID
@@ -879,30 +908,6 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo
}
}
-// MakeSyntheticMountpoint creates parent directories of target if they do not
-// exist and attempts to create a directory for the mountpoint. If a
-// non-directory file already exists there then we allow it.
-func (vfs *VirtualFilesystem) MakeSyntheticMountpoint(ctx context.Context, target string, root VirtualDentry, creds *auth.Credentials) error {
- mkdirOpts := &MkdirOptions{Mode: 0777, ForSyntheticMountpoint: true}
-
- // Make sure the parent directory of target exists.
- if err := vfs.MkdirAllAt(ctx, path.Dir(target), root, creds, mkdirOpts); err != nil {
- return fmt.Errorf("failed to create parent directory of mountpoint %q: %w", target, err)
- }
-
- // Attempt to mkdir the final component. If a file (of any type) exists
- // then we let allow mounting on top of that because we do not require the
- // target to be an existing directory, unlike Linux mount(2).
- if err := vfs.MkdirAt(ctx, creds, &PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(target),
- }, mkdirOpts); err != nil && err != syserror.EEXIST {
- return fmt.Errorf("failed to create mountpoint %q: %w", target, err)
- }
- return nil
-}
-
// manglePath replaces ' ', '\t', '\n', and '\\' with their octal equivalents.
// See Linux fs/seq_file.c:mangle_path.
func manglePath(p string) string {
diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go
index 3335e4057..cb8c56bd3 100644
--- a/pkg/sentry/vfs/mount_test.go
+++ b/pkg/sentry/vfs/mount_test.go
@@ -38,7 +38,7 @@ func TestMountTableInsertLookup(t *testing.T) {
mt.Init()
mount := &Mount{}
- mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}})
+ mount.setKey(VirtualDentry{&Mount{}, &Dentry{}})
mt.Insert(mount)
if m := mt.Lookup(mount.parent(), mount.point()); m != mount {
@@ -79,7 +79,7 @@ const enableComparativeBenchmarks = false
func newBenchMount() *Mount {
mount := &Mount{}
- mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}})
+ mount.loadKey(VirtualDentry{&Mount{}, &Dentry{}})
return mount
}
@@ -94,7 +94,7 @@ func BenchmarkMountTableParallelLookup(b *testing.B) {
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
mt.Insert(mount)
- keys = append(keys, mount.loadKey())
+ keys = append(keys, mount.saveKey())
}
var ready sync.WaitGroup
@@ -146,7 +146,7 @@ func BenchmarkMountMapParallelLookup(b *testing.B) {
keys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- key := mount.loadKey()
+ key := mount.saveKey()
ms[key] = mount
keys = append(keys, key)
}
@@ -201,7 +201,7 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) {
keys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- key := mount.loadKey()
+ key := mount.getKey()
ms.Store(key, mount)
keys = append(keys, key)
}
@@ -283,7 +283,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) {
ms := make(map[VirtualDentry]*Mount)
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- ms[mount.loadKey()] = mount
+ ms[mount.getKey()] = mount
}
negkeys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
@@ -318,7 +318,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) {
var ms sync.Map
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- ms.Store(mount.loadKey(), mount)
+ ms.Store(mount.saveKey(), mount)
}
negkeys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
@@ -372,7 +372,7 @@ func BenchmarkMountMapInsert(b *testing.B) {
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- ms[mount.loadKey()] = mount
+ ms[mount.saveKey()] = mount
}
}
@@ -392,7 +392,7 @@ func BenchmarkMountSyncMapInsert(b *testing.B) {
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- ms.Store(mount.loadKey(), mount)
+ ms.Store(mount.saveKey(), mount)
}
}
@@ -425,13 +425,13 @@ func BenchmarkMountMapRemove(b *testing.B) {
ms := make(map[VirtualDentry]*Mount)
for i := range mounts {
mount := mounts[i]
- ms[mount.loadKey()] = mount
+ ms[mount.saveKey()] = mount
}
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- delete(ms, mount.loadKey())
+ delete(ms, mount.saveKey())
}
}
@@ -447,12 +447,12 @@ func BenchmarkMountSyncMapRemove(b *testing.B) {
var ms sync.Map
for i := range mounts {
mount := mounts[i]
- ms.Store(mount.loadKey(), mount)
+ ms.Store(mount.saveKey(), mount)
}
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- ms.Delete(mount.loadKey())
+ ms.Delete(mount.saveKey())
}
}
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index da2a2e9c4..b7d122d22 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -34,6 +34,8 @@ import (
// structurally identical to VirtualDentry, but stores its fields as
// unsafe.Pointer since mutators synchronize with VFS path traversal using
// seqcounts.
+//
+// This is explicitly not savable.
type mountKey struct {
parent unsafe.Pointer // *Mount
point unsafe.Pointer // *Dentry
@@ -47,19 +49,23 @@ func (mnt *Mount) point() *Dentry {
return (*Dentry)(atomic.LoadPointer(&mnt.key.point))
}
-func (mnt *Mount) loadKey() VirtualDentry {
+func (mnt *Mount) getKey() VirtualDentry {
return VirtualDentry{
mount: mnt.parent(),
dentry: mnt.point(),
}
}
+func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
+
// Invariant: mnt.key.parent == nil. vd.Ok().
-func (mnt *Mount) storeKey(vd VirtualDentry) {
+func (mnt *Mount) setKey(vd VirtualDentry) {
atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
}
+func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
+
// mountTable maps (mount parent, mount point) pairs to mounts. It supports
// efficient concurrent lookup, even in the presence of concurrent mutators
// (provided mutation is sufficiently uncommon).
@@ -92,6 +98,7 @@ type mountTable struct {
// length and cap in separate uint32s) for ~free.
size uint64
+ // FIXME(gvisor.dev/issue/1663): Slots need to be saved.
slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init
}
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index dfc8573fd..bc79e5ecc 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -21,6 +21,8 @@ import (
// GetDentryOptions contains options to VirtualFilesystem.GetDentryAt() and
// FilesystemImpl.GetDentryAt().
+//
+// +stateify savable
type GetDentryOptions struct {
// If CheckSearchable is true, FilesystemImpl.GetDentryAt() must check that
// the returned Dentry is a directory for which creds has search
@@ -30,6 +32,8 @@ type GetDentryOptions struct {
// MkdirOptions contains options to VirtualFilesystem.MkdirAt() and
// FilesystemImpl.MkdirAt().
+//
+// +stateify savable
type MkdirOptions struct {
// Mode is the file mode bits for the created directory.
Mode linux.FileMode
@@ -56,6 +60,8 @@ type MkdirOptions struct {
// MknodOptions contains options to VirtualFilesystem.MknodAt() and
// FilesystemImpl.MknodAt().
+//
+// +stateify savable
type MknodOptions struct {
// Mode is the file type and mode bits for the created file.
Mode linux.FileMode
@@ -72,6 +78,8 @@ type MknodOptions struct {
// MountFlags contains flags as specified for mount(2), e.g. MS_NOEXEC.
// MS_RDONLY is not part of MountFlags because it's tracked in Mount.writers.
+//
+// +stateify savable
type MountFlags struct {
// NoExec is equivalent to MS_NOEXEC.
NoExec bool
@@ -93,6 +101,8 @@ type MountFlags struct {
}
// MountOptions contains options to VirtualFilesystem.MountAt().
+//
+// +stateify savable
type MountOptions struct {
// Flags contains flags as specified for mount(2), e.g. MS_NOEXEC.
Flags MountFlags
@@ -103,13 +113,17 @@ type MountOptions struct {
// GetFilesystemOptions contains options to FilesystemType.GetFilesystem().
GetFilesystemOptions GetFilesystemOptions
- // If InternalMount is true, allow the use of filesystem types for which
- // RegisterFilesystemTypeOptions.AllowUserMount == false.
+ // InternalMount indicates whether the mount operation is coming from the
+ // application, i.e. through mount(2). If InternalMount is true, allow the use
+ // of filesystem types for which RegisterFilesystemTypeOptions.AllowUserMount
+ // == false.
InternalMount bool
}
// OpenOptions contains options to VirtualFilesystem.OpenAt() and
// FilesystemImpl.OpenAt().
+//
+// +stateify savable
type OpenOptions struct {
// Flags contains access mode and flags as specified for open(2).
//
@@ -135,6 +149,8 @@ type OpenOptions struct {
// ReadOptions contains options to FileDescription.PRead(),
// FileDescriptionImpl.PRead(), FileDescription.Read(), and
// FileDescriptionImpl.Read().
+//
+// +stateify savable
type ReadOptions struct {
// Flags contains flags as specified for preadv2(2).
Flags uint32
@@ -142,6 +158,8 @@ type ReadOptions struct {
// RenameOptions contains options to VirtualFilesystem.RenameAt() and
// FilesystemImpl.RenameAt().
+//
+// +stateify savable
type RenameOptions struct {
// Flags contains flags as specified for renameat2(2).
Flags uint32
@@ -153,6 +171,8 @@ type RenameOptions struct {
// SetStatOptions contains options to VirtualFilesystem.SetStatAt(),
// FilesystemImpl.SetStatAt(), FileDescription.SetStat(), and
// FileDescriptionImpl.SetStat().
+//
+// +stateify savable
type SetStatOptions struct {
// Stat is the metadata that should be set. Only fields indicated by
// Stat.Mask should be set.
@@ -174,6 +194,8 @@ type SetStatOptions struct {
// BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt()
// and FilesystemImpl.BoundEndpointAt().
+//
+// +stateify savable
type BoundEndpointOptions struct {
// Addr is the path of the file whose socket endpoint is being retrieved.
// It is generally irrelevant: most endpoints are stored at a dentry that
@@ -190,10 +212,12 @@ type BoundEndpointOptions struct {
Addr string
}
-// GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(),
-// FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and
-// FileDescriptionImpl.Getxattr().
-type GetxattrOptions struct {
+// GetXattrOptions contains options to VirtualFilesystem.GetXattrAt(),
+// FilesystemImpl.GetXattrAt(), FileDescription.GetXattr(), and
+// FileDescriptionImpl.GetXattr().
+//
+// +stateify savable
+type GetXattrOptions struct {
// Name is the name of the extended attribute to retrieve.
Name string
@@ -204,10 +228,12 @@ type GetxattrOptions struct {
Size uint64
}
-// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(),
-// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and
-// FileDescriptionImpl.Setxattr().
-type SetxattrOptions struct {
+// SetXattrOptions contains options to VirtualFilesystem.SetXattrAt(),
+// FilesystemImpl.SetXattrAt(), FileDescription.SetXattr(), and
+// FileDescriptionImpl.SetXattr().
+//
+// +stateify savable
+type SetXattrOptions struct {
// Name is the name of the extended attribute being mutated.
Name string
@@ -221,6 +247,8 @@ type SetxattrOptions struct {
// StatOptions contains options to VirtualFilesystem.StatAt(),
// FilesystemImpl.StatAt(), FileDescription.Stat(), and
// FileDescriptionImpl.Stat().
+//
+// +stateify savable
type StatOptions struct {
// Mask is the set of fields in the returned Statx that the FilesystemImpl
// or FileDescriptionImpl should provide. Bits are as in linux.Statx.Mask.
@@ -238,6 +266,8 @@ type StatOptions struct {
}
// UmountOptions contains options to VirtualFilesystem.UmountAt().
+//
+// +stateify savable
type UmountOptions struct {
// Flags contains flags as specified for umount2(2).
Flags uint32
@@ -246,6 +276,8 @@ type UmountOptions struct {
// WriteOptions contains options to FileDescription.PWrite(),
// FileDescriptionImpl.PWrite(), FileDescription.Write(), and
// FileDescriptionImpl.Write().
+//
+// +stateify savable
type WriteOptions struct {
// Flags contains flags as specified for pwritev2(2).
Flags uint32
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index 014b928ed..d48520d58 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -16,6 +16,7 @@ package vfs
import (
"math"
+ "strings"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -25,6 +26,8 @@ import (
)
// AccessTypes is a bitmask of Unix file permissions.
+//
+// +stateify savable
type AccessTypes uint16
// Bits in AccessTypes.
@@ -284,3 +287,40 @@ func CheckLimit(ctx context.Context, offset, size int64) (int64, error) {
}
return size, nil
}
+
+// CheckXattrPermissions checks permissions for extended attribute access.
+// This is analogous to fs/xattr.c:xattr_permission(). Some key differences:
+// * Does not check for read-only filesystem property.
+// * Does not check inode immutability or append only mode. In both cases EPERM
+// must be returned by filesystem implementations.
+// * Does not do inode permission checks. Filesystem implementations should
+// handle inode permission checks as they may differ across implementations.
+func CheckXattrPermissions(creds *auth.Credentials, ats AccessTypes, mode linux.FileMode, kuid auth.KUID, name string) error {
+ switch {
+ case strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX):
+ // The trusted.* namespace can only be accessed by privileged
+ // users.
+ if creds.HasCapability(linux.CAP_SYS_ADMIN) {
+ return nil
+ }
+ if ats.MayWrite() {
+ return syserror.EPERM
+ }
+ return syserror.ENODATA
+ case strings.HasPrefix(name, linux.XATTR_USER_PREFIX):
+ // In the user.* namespace, only regular files and directories can have
+ // extended attributes. For sticky directories, only the owner and
+ // privileged users can write attributes.
+ filetype := mode.FileType()
+ if filetype != linux.ModeRegular && filetype != linux.ModeDirectory {
+ if ats.MayWrite() {
+ return syserror.EPERM
+ }
+ return syserror.ENODATA
+ }
+ if filetype == linux.ModeDirectory && mode&linux.ModeSticky != 0 && ats.MayWrite() && !CanActAsOwner(creds, kuid) {
+ return syserror.EPERM
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 3304372d9..e4fd55012 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -35,6 +35,8 @@ import (
// FilesystemImpl methods.
//
// ResolvingPath is loosely analogous to Linux's struct nameidata.
+//
+// +stateify savable
type ResolvingPath struct {
vfs *VirtualFilesystem
root VirtualDentry // refs borrowed from PathOperation
@@ -88,6 +90,7 @@ func init() {
// so error "constants" are really mutable vars, necessitating somewhat
// expensive interface object comparisons.
+// +stateify savable
type resolveMountRootOrJumpError struct{}
// Error implements error.Error.
@@ -95,6 +98,7 @@ func (resolveMountRootOrJumpError) Error() string {
return "resolving mount root or jump"
}
+// +stateify savable
type resolveMountPointError struct{}
// Error implements error.Error.
@@ -102,6 +106,7 @@ func (resolveMountPointError) Error() string {
return "resolving mount point"
}
+// +stateify savable
type resolveAbsSymlinkError struct{}
// Error implements error.Error.
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index ec27562d6..5bd756ea5 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -163,6 +163,8 @@ func (vfs *VirtualFilesystem) Init(ctx context.Context) error {
// PathOperation is passed to VFS methods by pointer to reduce memory copying:
// it's somewhat large and should never escape. (Options structs are passed by
// pointer to VFS and FileDescription methods for the same reason.)
+//
+// +stateify savable
type PathOperation struct {
// Root is the VFS root. References on Root are borrowed from the provider
// of the PathOperation.
@@ -297,6 +299,8 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
// MkdirAt creates a directory at the given path.
func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error {
if !pop.Path.Begin.Ok() {
+ // pop.Path should not be empty in operations that create/delete files.
+ // This is consistent with mkdirat(dirfd, "", mode).
if pop.Path.Absolute {
return syserror.EEXIST
}
@@ -333,6 +337,8 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
// error from the syserror package.
func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error {
if !pop.Path.Begin.Ok() {
+ // pop.Path should not be empty in operations that create/delete files.
+ // This is consistent with mknodat(dirfd, "", mode, dev).
if pop.Path.Absolute {
return syserror.EEXIST
}
@@ -518,6 +524,8 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
// RmdirAt removes the directory at the given path.
func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
if !pop.Path.Begin.Ok() {
+ // pop.Path should not be empty in operations that create/delete files.
+ // This is consistent with unlinkat(dirfd, "", AT_REMOVEDIR).
if pop.Path.Absolute {
return syserror.EBUSY
}
@@ -599,6 +607,8 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti
// SymlinkAt creates a symbolic link at the given path with the given target.
func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error {
if !pop.Path.Begin.Ok() {
+ // pop.Path should not be empty in operations that create/delete files.
+ // This is consistent with symlinkat(oldpath, newdirfd, "").
if pop.Path.Absolute {
return syserror.EEXIST
}
@@ -631,6 +641,8 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
// UnlinkAt deletes the non-directory file at the given path.
func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
if !pop.Path.Begin.Ok() {
+ // pop.Path should not be empty in operations that create/delete files.
+ // This is consistent with unlinkat(dirfd, "", 0).
if pop.Path.Absolute {
return syserror.EBUSY
}
@@ -662,12 +674,6 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
// BoundEndpointAt gets the bound endpoint at the given path, if one exists.
func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *BoundEndpointOptions) (transport.BoundEndpoint, error) {
- if !pop.Path.Begin.Ok() {
- if pop.Path.Absolute {
- return nil, syserror.ECONNREFUSED
- }
- return nil, syserror.ENOENT
- }
rp := vfs.getResolvingPath(creds, pop)
for {
bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts)
@@ -687,12 +693,12 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C
}
}
-// ListxattrAt returns all extended attribute names for the file at the given
+// ListXattrAt returns all extended attribute names for the file at the given
// path.
-func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) {
+func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) {
rp := vfs.getResolvingPath(creds, pop)
for {
- names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size)
+ names, err := rp.mount.fs.impl.ListXattrAt(ctx, rp, size)
if err == nil {
vfs.putResolvingPath(ctx, rp)
return names, nil
@@ -712,12 +718,12 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede
}
}
-// GetxattrAt returns the value associated with the given extended attribute
+// GetXattrAt returns the value associated with the given extended attribute
// for the file at the given path.
-func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetxattrOptions) (string, error) {
+func (vfs *VirtualFilesystem) GetXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetXattrOptions) (string, error) {
rp := vfs.getResolvingPath(creds, pop)
for {
- val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
+ val, err := rp.mount.fs.impl.GetXattrAt(ctx, rp, *opts)
if err == nil {
vfs.putResolvingPath(ctx, rp)
return val, nil
@@ -729,12 +735,12 @@ func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Creden
}
}
-// SetxattrAt changes the value associated with the given extended attribute
+// SetXattrAt changes the value associated with the given extended attribute
// for the file at the given path.
-func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error {
+func (vfs *VirtualFilesystem) SetXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetXattrOptions) error {
rp := vfs.getResolvingPath(creds, pop)
for {
- err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
+ err := rp.mount.fs.impl.SetXattrAt(ctx, rp, *opts)
if err == nil {
vfs.putResolvingPath(ctx, rp)
return nil
@@ -746,11 +752,11 @@ func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Creden
}
}
-// RemovexattrAt removes the given extended attribute from the file at rp.
-func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error {
+// RemoveXattrAt removes the given extended attribute from the file at rp.
+func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error {
rp := vfs.getResolvingPath(creds, pop)
for {
- err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name)
+ err := rp.mount.fs.impl.RemoveXattrAt(ctx, rp, name)
if err == nil {
vfs.putResolvingPath(ctx, rp)
return nil
@@ -815,6 +821,30 @@ func (vfs *VirtualFilesystem) MkdirAllAt(ctx context.Context, currentPath string
return nil
}
+// MakeSyntheticMountpoint creates parent directories of target if they do not
+// exist and attempts to create a directory for the mountpoint. If a
+// non-directory file already exists there then we allow it.
+func (vfs *VirtualFilesystem) MakeSyntheticMountpoint(ctx context.Context, target string, root VirtualDentry, creds *auth.Credentials) error {
+ mkdirOpts := &MkdirOptions{Mode: 0777, ForSyntheticMountpoint: true}
+
+ // Make sure the parent directory of target exists.
+ if err := vfs.MkdirAllAt(ctx, path.Dir(target), root, creds, mkdirOpts); err != nil {
+ return fmt.Errorf("failed to create parent directory of mountpoint %q: %w", target, err)
+ }
+
+ // Attempt to mkdir the final component. If a file (of any type) exists
+ // then we let allow mounting on top of that because we do not require the
+ // target to be an existing directory, unlike Linux mount(2).
+ if err := vfs.MkdirAt(ctx, creds, &PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(target),
+ }, mkdirOpts); err != nil && err != syserror.EEXIST {
+ return fmt.Errorf("failed to create mountpoint %q: %w", target, err)
+ }
+ return nil
+}
+
// A VirtualDentry represents a node in a VFS tree, by combining a Dentry
// (which represents a node in a Filesystem's tree) and a Mount (which
// represents the Filesystem's position in a VFS mount tree).
diff --git a/pkg/state/types.go b/pkg/state/types.go
index 215ef80f8..84aed8732 100644
--- a/pkg/state/types.go
+++ b/pkg/state/types.go
@@ -107,6 +107,14 @@ func lookupNameFields(typ reflect.Type) (string, []string, bool) {
}
return name, nil, true
}
+ // Sanity check the type.
+ if raceEnabled {
+ if _, ok := reverseTypeDatabase[typ]; !ok {
+ // The type was not registered? Must be an embedded
+ // structure or something else.
+ return "", nil, false
+ }
+ }
// Extract the name from the object.
name := t.StateTypeName()
fields := t.StateFields()
@@ -313,6 +321,9 @@ var primitiveTypeDatabase = func() map[string]reflect.Type {
// globalTypeDatabase is used for dispatching interfaces on decode.
var globalTypeDatabase = map[string]reflect.Type{}
+// reverseTypeDatabase is a reverse mapping.
+var reverseTypeDatabase = map[reflect.Type]string{}
+
// Register registers a type.
//
// This must be called on init and only done once.
@@ -358,4 +369,7 @@ func Register(t Type) {
Failf("conflicting name for %T: matches interfaceType", t)
}
globalTypeDatabase[name] = typ
+ if raceEnabled {
+ reverseTypeDatabase[typ] = name
+ }
}
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 4d47207f7..68535c3b1 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -38,6 +38,7 @@ go_library(
"race_unsafe.go",
"rwmutex_unsafe.go",
"seqcount.go",
+ "spin_unsafe.go",
"sync.go",
],
marshal = False,
diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go
index eda6fb131..2184cb5ab 100644
--- a/pkg/sync/seqatomic_unsafe.go
+++ b/pkg/sync/seqatomic_unsafe.go
@@ -25,41 +25,35 @@ import (
type Value struct{}
// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
-// with any writer critical sections in sc.
-func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value {
- // This function doesn't use SeqAtomicTryLoad because doing so is
- // measurably, significantly (~20%) slower; Go is awful at inlining.
- var val Value
+// with any writer critical sections in seq.
+//
+//go:nosplit
+func SeqAtomicLoad(seq *sync.SeqCount, ptr *Value) Value {
for {
- epoch := sc.BeginRead()
- if sync.RaceEnabled {
- // runtime.RaceDisable() doesn't actually stop the race detector,
- // so it can't help us here. Instead, call runtime.memmove
- // directly, which is not instrumented by the race detector.
- sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
- } else {
- // This is ~40% faster for short reads than going through memmove.
- val = *ptr
- }
- if sc.ReadOk(epoch) {
- break
+ if val, ok := SeqAtomicTryLoad(seq, seq.BeginRead(), ptr); ok {
+ return val
}
}
- return val
}
// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section
-// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read
-// would race with a writer critical section, SeqAtomicTryLoad returns
+// in seq initiated by a call to seq.BeginRead() that returned epoch. If the
+// read would race with a writer critical section, SeqAtomicTryLoad returns
// (unspecified, false).
-func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) {
- var val Value
+//
+//go:nosplit
+func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (val Value, ok bool) {
if sync.RaceEnabled {
+ // runtime.RaceDisable() doesn't actually stop the race detector, so it
+ // can't help us here. Instead, call runtime.memmove directly, which is
+ // not instrumented by the race detector.
sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
} else {
+ // This is ~40% faster for short reads than going through memmove.
val = *ptr
}
- return val, sc.ReadOk(epoch)
+ ok = seq.ReadOk(epoch)
+ return
}
func init() {
diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go
index a1e895352..2c5d3df99 100644
--- a/pkg/sync/seqcount.go
+++ b/pkg/sync/seqcount.go
@@ -8,7 +8,6 @@ package sync
import (
"fmt"
"reflect"
- "runtime"
"sync/atomic"
)
@@ -43,9 +42,7 @@ type SeqCount struct {
}
// SeqCountEpoch tracks writer critical sections in a SeqCount.
-type SeqCountEpoch struct {
- val uint32
-}
+type SeqCountEpoch uint32
// We assume that:
//
@@ -83,12 +80,25 @@ type SeqCountEpoch struct {
// using this pattern. Most users of SeqCount will need to use the
// SeqAtomicLoad function template in seqatomic.go.
func (s *SeqCount) BeginRead() SeqCountEpoch {
- epoch := atomic.LoadUint32(&s.epoch)
- for epoch&1 != 0 {
- runtime.Gosched()
- epoch = atomic.LoadUint32(&s.epoch)
+ if epoch := atomic.LoadUint32(&s.epoch); epoch&1 == 0 {
+ return SeqCountEpoch(epoch)
+ }
+ return s.beginReadSlow()
+}
+
+func (s *SeqCount) beginReadSlow() SeqCountEpoch {
+ i := 0
+ for {
+ if canSpin(i) {
+ i++
+ doSpin()
+ } else {
+ goyield()
+ }
+ if epoch := atomic.LoadUint32(&s.epoch); epoch&1 == 0 {
+ return SeqCountEpoch(epoch)
+ }
}
- return SeqCountEpoch{epoch}
}
// ReadOk returns true if the reader critical section initiated by a previous
@@ -99,7 +109,7 @@ func (s *SeqCount) BeginRead() SeqCountEpoch {
// Reader critical sections do not need to be explicitly terminated; the last
// call to ReadOk is implicitly the end of the reader critical section.
func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool {
- return atomic.LoadUint32(&s.epoch) == epoch.val
+ return atomic.LoadUint32(&s.epoch) == uint32(epoch)
}
// BeginWrite indicates the beginning of a writer critical section.
diff --git a/pkg/sync/spin_unsafe.go b/pkg/sync/spin_unsafe.go
new file mode 100644
index 000000000..cafb2d065
--- /dev/null
+++ b/pkg/sync/spin_unsafe.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.13
+// +build !go1.17
+
+// Check go:linkname function signatures when updating Go version.
+
+package sync
+
+import (
+ _ "unsafe" // for go:linkname
+)
+
+//go:linkname canSpin sync.runtime_canSpin
+func canSpin(i int) bool
+
+//go:linkname doSpin sync.runtime_doSpin
+func doSpin()
+
+//go:linkname goyield runtime.goyield
+func goyield()
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index fe9f50169..f516c8e46 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -33,6 +33,7 @@ var (
EBADFD = error(syscall.EBADFD)
EBUSY = error(syscall.EBUSY)
ECHILD = error(syscall.ECHILD)
+ ECONNABORTED = error(syscall.ECONNABORTED)
ECONNREFUSED = error(syscall.ECONNREFUSED)
ECONNRESET = error(syscall.ECONNRESET)
EDEADLK = error(syscall.EDEADLK)
diff --git a/pkg/syserror/syserror_test.go b/pkg/syserror/syserror_test.go
index 29719752e..7036467c4 100644
--- a/pkg/syserror/syserror_test.go
+++ b/pkg/syserror/syserror_test.go
@@ -24,27 +24,20 @@ import (
var globalError error
-func returnErrnoAsError() error {
- return syscall.EINVAL
-}
-
-func returnError() error {
- return syserror.EINVAL
-}
-
-func BenchmarkReturnErrnoAsError(b *testing.B) {
+func BenchmarkAssignErrno(b *testing.B) {
for i := b.N; i > 0; i-- {
- returnErrnoAsError()
+ globalError = syscall.EINVAL
}
}
-func BenchmarkReturnError(b *testing.B) {
+func BenchmarkAssignError(b *testing.B) {
for i := b.N; i > 0; i-- {
- returnError()
+ globalError = syserror.EINVAL
}
}
func BenchmarkCompareErrno(b *testing.B) {
+ globalError = syscall.EAGAIN
j := 0
for i := b.N; i > 0; i-- {
if globalError == syscall.EINVAL {
@@ -54,6 +47,7 @@ func BenchmarkCompareErrno(b *testing.B) {
}
func BenchmarkCompareError(b *testing.B) {
+ globalError = syserror.EAGAIN
j := 0
for i := b.N; i > 0; i-- {
if globalError == syserror.EINVAL {
@@ -63,6 +57,7 @@ func BenchmarkCompareError(b *testing.B) {
}
func BenchmarkSwitchErrno(b *testing.B) {
+ globalError = syscall.EPERM
j := 0
for i := b.N; i > 0; i-- {
switch globalError {
@@ -77,6 +72,7 @@ func BenchmarkSwitchErrno(b *testing.B) {
}
func BenchmarkSwitchError(b *testing.B) {
+ globalError = syserror.EPERM
j := 0
for i := b.N; i > 0; i-- {
switch globalError {
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 68a954a10..4f551cd92 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -245,7 +245,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
// Accept implements net.Conn.Accept.
func (l *TCPListener) Accept() (net.Conn, error) {
- n, wq, err := l.ep.Accept()
+ n, wq, err := l.ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
@@ -254,7 +254,7 @@ func (l *TCPListener) Accept() (net.Conn, error) {
defer l.wq.EventUnregister(&waitEntry)
for {
- n, wq, err = l.ep.Accept()
+ n, wq, err = l.ep.Accept(nil)
if err != tcpip.ErrWouldBlock {
break
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index c975ad9cf..12b061def 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -61,8 +61,8 @@ func TestTimeouts(t *testing.T) {
func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
// Create the stack and add a NIC.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
})
if err := s.CreateNIC(NICID, loopback.New()); err != nil {
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index 563bc78ea..c326fab54 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -14,6 +14,8 @@ go_library(
go_test(
name = "buffer_test",
size = "small",
- srcs = ["view_test.go"],
+ srcs = [
+ "view_test.go",
+ ],
library = ":buffer",
)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index b769094dc..71b2f1bda 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -118,18 +118,62 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.HopLimit()
}
if v != ttl {
- t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
+ t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
+ }
+ }
+}
+
+// IPFullLength creates a checker for the full IP packet length. The
+// expected size is checked against both the Total Length in the
+// header and the number of bytes received.
+func IPFullLength(packetLength uint16) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ var v uint16
+ var l uint16
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ v = ip.TotalLength()
+ l = uint16(len(ip))
+ case header.IPv6:
+ v = ip.PayloadLength() + header.IPv6FixedHeaderSize
+ l = uint16(len(ip))
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
+ }
+ if l != packetLength {
+ t.Errorf("bad packet length, got = %d, want = %d", l, packetLength)
+ }
+ if v != packetLength {
+ t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength)
+ }
+ }
+}
+
+// IPv4HeaderLength creates a checker that checks the IPv4 Header length.
+func IPv4HeaderLength(headerLength int) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ if hl := ip.HeaderLength(); hl != uint8(headerLength) {
+ t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength)
+ }
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip)
}
}
}
// PayloadLen creates a checker that checks the payload length.
-func PayloadLen(plen int) NetworkChecker {
+func PayloadLen(payloadLength int) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- if l := len(h[0].Payload()); l != plen {
- t.Errorf("Bad payload length, got %v, want %v", l, plen)
+ if l := len(h[0].Payload()); l != payloadLength {
+ t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength)
}
}
}
@@ -143,7 +187,7 @@ func FragmentOffset(offset uint16) NetworkChecker {
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.FragmentOffset(); v != offset {
- t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset)
}
}
}
@@ -158,7 +202,7 @@ func FragmentFlags(flags uint8) NetworkChecker {
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.Flags(); v != flags {
- t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags)
}
}
}
@@ -208,7 +252,7 @@ func TOS(tos uint8, label uint32) NetworkChecker {
t.Helper()
if v, l := h[0].TOS(); v != tos || l != label {
- t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
+ t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label)
}
}
}
@@ -234,7 +278,7 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
t.Helper()
if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
ipv6Frag := header.IPv6Fragment(h[0].Payload())
@@ -261,7 +305,7 @@ func TCP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
}
// Verify the checksum.
@@ -297,7 +341,7 @@ func UDP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
udp := header.UDP(last.Payload())
@@ -316,7 +360,7 @@ func SrcPort(port uint16) TransportChecker {
t.Helper()
if p := h.SourcePort(); p != port {
- t.Errorf("Bad source port, got %v, want %v", p, port)
+ t.Errorf("Bad source port, got = %d, want = %d", p, port)
}
}
}
@@ -327,7 +371,7 @@ func DstPort(port uint16) TransportChecker {
t.Helper()
if p := h.DestinationPort(); p != port {
- t.Errorf("Bad destination port, got %v, want %v", p, port)
+ t.Errorf("Bad destination port, got = %d, want = %d", p, port)
}
}
}
@@ -339,7 +383,7 @@ func NoChecksum(noChecksum bool) TransportChecker {
udp, ok := h.(header.UDP)
if !ok {
- return
+ t.Fatalf("UDP header not found in h: %T", h)
}
if b := udp.Checksum() == 0; b != noChecksum {
@@ -348,50 +392,84 @@ func NoChecksum(noChecksum bool) TransportChecker {
}
}
-// SeqNum creates a checker that checks the sequence number.
-func SeqNum(seq uint32) TransportChecker {
+// TCPSeqNum creates a checker that checks the sequence number.
+func TCPSeqNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if s := tcp.SequenceNumber(); s != seq {
- t.Errorf("Bad sequence number, got %v, want %v", s, seq)
+ t.Errorf("Bad sequence number, got = %d, want = %d", s, seq)
}
}
}
-// AckNum creates a checker that checks the ack number.
-func AckNum(seq uint32) TransportChecker {
+// TCPAckNum creates a checker that checks the ack number.
+func TCPAckNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if s := tcp.AckNumber(); s != seq {
- t.Errorf("Bad ack number, got %v, want %v", s, seq)
+ t.Errorf("Bad ack number, got = %d, want = %d", s, seq)
}
}
}
-// Window creates a checker that checks the tcp window.
-func Window(window uint16) TransportChecker {
+// TCPWindow creates a checker that checks the tcp window.
+func TCPWindow(window uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in hdr : %T", h)
}
if w := tcp.WindowSize(); w != window {
- t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
+ t.Errorf("Bad window, got %d, want %d", w, window)
+ }
+ }
+}
+
+// TCPWindowGreaterThanEq creates a checker that checks that the TCP window
+// is greater than or equal to the provided value.
+func TCPWindowGreaterThanEq(window uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ t.Fatalf("TCP header not found in h: %T", h)
+ }
+
+ if w := tcp.WindowSize(); w < window {
+ t.Errorf("Bad window, got %d, want > %d", w, window)
+ }
+ }
+}
+
+// TCPWindowLessThanEq creates a checker that checks that the tcp window
+// is less than or equal to the provided value.
+func TCPWindowLessThanEq(window uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ t.Fatalf("TCP header not found in h: %T", h)
+ }
+
+ if w := tcp.WindowSize(); w > window {
+ t.Errorf("Bad window, got %d, want < %d", w, window)
}
}
}
@@ -403,7 +481,7 @@ func TCPFlags(flags uint8) TransportChecker {
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if f := tcp.Flags(); f != flags {
@@ -420,7 +498,7 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker {
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if f := tcp.Flags(); (f & mask) != (flags & mask) {
@@ -458,7 +536,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
case header.TCPOptionMSS:
v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
if wantOpts.MSS != v {
- t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
+ t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS)
}
foundMSS = true
i += 4
@@ -468,7 +546,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
}
v := int(opts[i+2])
if v != wantOpts.WS {
- t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
+ t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS)
}
foundWS = true
i += 3
@@ -517,7 +595,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
t.Error("TS option specified but the timestamp value is zero")
}
if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+ t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr)
}
if wantOpts.SACKPermitted && !foundSACKPermitted {
t.Errorf("SACKPermitted option not found. Options: %x", opts)
@@ -555,7 +633,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
}
if opts[i+1] != 10 {
- t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
+ t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1])
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = binary.BigEndian.Uint32(opts[i+6:])
@@ -575,13 +653,13 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
}
if wantTS != foundTS {
- t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
+ t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS)
}
if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
+ t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal)
}
if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+ t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr)
}
}
}
@@ -645,7 +723,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
}
if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
- t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
+ t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks)
}
}
}
@@ -690,10 +768,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
@@ -705,10 +783,76 @@ func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident.
+func ICMPv4Ident(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Ident(); got != want {
+ t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence.
+func ICMPv4Seq(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Sequence(); got != want {
+ t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
+// This assumes that the payload exactly makes up the rest of the slice.
+func ICMPv4Checksum() TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ heldChecksum := icmpv4.Checksum()
+ icmpv4.SetChecksum(0)
+ newChecksum := ^header.Checksum(icmpv4, 0)
+ icmpv4.SetChecksum(heldChecksum)
+ if heldChecksum != newChecksum {
+ t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum)
+ }
+ }
+}
+
+// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet.
+func ICMPv4Payload(want []byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ payload := icmpv4.Payload()
+ if diff := cmp.Diff(payload, want); diff != "" {
+ t.Errorf("got ICMP payload mismatch (-want +got):\n%s", diff)
}
}
}
@@ -748,10 +892,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
@@ -763,10 +907,10 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
}
}
}
diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD
new file mode 100644
index 000000000..114d43df3
--- /dev/null
+++ b/pkg/tcpip/faketime/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "faketime",
+ srcs = ["faketime.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "@com_github_dpjacques_clockwork//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "faketime_test",
+ size = "small",
+ srcs = [
+ "faketime_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip/faketime",
+ ],
+)
diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/faketime/faketime.go
index 92c8cb534..f7a4fbde1 100644
--- a/pkg/tcpip/stack/fake_time_test.go
+++ b/pkg/tcpip/faketime/faketime.go
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package stack
+// Package faketime provides a fake clock that implements tcpip.Clock interface.
+package faketime
import (
"container/heap"
@@ -23,7 +24,29 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-type fakeClock struct {
+// NullClock implements a clock that never advances.
+type NullClock struct{}
+
+var _ tcpip.Clock = (*NullClock)(nil)
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (*NullClock) NowNanoseconds() int64 {
+ return 0
+}
+
+// NowMonotonic implements tcpip.Clock.NowMonotonic.
+func (*NullClock) NowMonotonic() int64 {
+ return 0
+}
+
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer {
+ return nil
+}
+
+// ManualClock implements tcpip.Clock and only advances manually with Advance
+// method.
+type ManualClock struct {
clock clockwork.FakeClock
// mu protects the fields below.
@@ -39,34 +62,35 @@ type fakeClock struct {
waitGroups map[time.Time]*sync.WaitGroup
}
-func newFakeClock() *fakeClock {
- return &fakeClock{
+// NewManualClock creates a new ManualClock instance.
+func NewManualClock() *ManualClock {
+ return &ManualClock{
clock: clockwork.NewFakeClock(),
times: &timeHeap{},
waitGroups: make(map[time.Time]*sync.WaitGroup),
}
}
-var _ tcpip.Clock = (*fakeClock)(nil)
+var _ tcpip.Clock = (*ManualClock)(nil)
// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (fc *fakeClock) NowNanoseconds() int64 {
- return fc.clock.Now().UnixNano()
+func (mc *ManualClock) NowNanoseconds() int64 {
+ return mc.clock.Now().UnixNano()
}
// NowMonotonic implements tcpip.Clock.NowMonotonic.
-func (fc *fakeClock) NowMonotonic() int64 {
- return fc.NowNanoseconds()
+func (mc *ManualClock) NowMonotonic() int64 {
+ return mc.NowNanoseconds()
}
// AfterFunc implements tcpip.Clock.AfterFunc.
-func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
- until := fc.clock.Now().Add(d)
- wg := fc.addWait(until)
- return &fakeTimer{
- clock: fc,
+func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ until := mc.clock.Now().Add(d)
+ wg := mc.addWait(until)
+ return &manualTimer{
+ clock: mc,
until: until,
- timer: fc.clock.AfterFunc(d, func() {
+ timer: mc.clock.AfterFunc(d, func() {
defer wg.Done()
f()
}),
@@ -75,110 +99,113 @@ func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
// addWait adds an additional wait to the WaitGroup for parallel execution of
// all work scheduled for t. Returns a reference to the WaitGroup modified.
-func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup {
- fc.mu.RLock()
- wg, ok := fc.waitGroups[t]
- fc.mu.RUnlock()
+func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup {
+ mc.mu.RLock()
+ wg, ok := mc.waitGroups[t]
+ mc.mu.RUnlock()
if ok {
wg.Add(1)
return wg
}
- fc.mu.Lock()
- heap.Push(fc.times, t)
- fc.mu.Unlock()
+ mc.mu.Lock()
+ heap.Push(mc.times, t)
+ mc.mu.Unlock()
wg = &sync.WaitGroup{}
wg.Add(1)
- fc.mu.Lock()
- fc.waitGroups[t] = wg
- fc.mu.Unlock()
+ mc.mu.Lock()
+ mc.waitGroups[t] = wg
+ mc.mu.Unlock()
return wg
}
// removeWait removes a wait from the WaitGroup for parallel execution of all
// work scheduled for t.
-func (fc *fakeClock) removeWait(t time.Time) {
- fc.mu.RLock()
- defer fc.mu.RUnlock()
+func (mc *ManualClock) removeWait(t time.Time) {
+ mc.mu.RLock()
+ defer mc.mu.RUnlock()
- wg := fc.waitGroups[t]
+ wg := mc.waitGroups[t]
wg.Done()
}
-// advance executes all work that have been scheduled to execute within d from
-// the current fake time. Blocks until all work has completed execution.
-func (fc *fakeClock) advance(d time.Duration) {
+// Advance executes all work that have been scheduled to execute within d from
+// the current time. Blocks until all work has completed execution.
+func (mc *ManualClock) Advance(d time.Duration) {
// Block until all the work is done
- until := fc.clock.Now().Add(d)
+ until := mc.clock.Now().Add(d)
for {
- fc.mu.Lock()
- if fc.times.Len() == 0 {
- fc.mu.Unlock()
- return
+ mc.mu.Lock()
+ if mc.times.Len() == 0 {
+ mc.mu.Unlock()
+ break
}
- t := heap.Pop(fc.times).(time.Time)
+ t := heap.Pop(mc.times).(time.Time)
if t.After(until) {
// No work to do
- heap.Push(fc.times, t)
- fc.mu.Unlock()
- return
+ heap.Push(mc.times, t)
+ mc.mu.Unlock()
+ break
}
- fc.mu.Unlock()
+ mc.mu.Unlock()
- diff := t.Sub(fc.clock.Now())
- fc.clock.Advance(diff)
+ diff := t.Sub(mc.clock.Now())
+ mc.clock.Advance(diff)
- fc.mu.RLock()
- wg := fc.waitGroups[t]
- fc.mu.RUnlock()
+ mc.mu.RLock()
+ wg := mc.waitGroups[t]
+ mc.mu.RUnlock()
wg.Wait()
- fc.mu.Lock()
- delete(fc.waitGroups, t)
- fc.mu.Unlock()
+ mc.mu.Lock()
+ delete(mc.waitGroups, t)
+ mc.mu.Unlock()
+ }
+ if now := mc.clock.Now(); until.After(now) {
+ mc.clock.Advance(until.Sub(now))
}
}
-type fakeTimer struct {
- clock *fakeClock
+type manualTimer struct {
+ clock *ManualClock
timer clockwork.Timer
mu sync.RWMutex
until time.Time
}
-var _ tcpip.Timer = (*fakeTimer)(nil)
+var _ tcpip.Timer = (*manualTimer)(nil)
// Reset implements tcpip.Timer.Reset.
-func (ft *fakeTimer) Reset(d time.Duration) {
- if !ft.timer.Reset(d) {
+func (t *manualTimer) Reset(d time.Duration) {
+ if !t.timer.Reset(d) {
return
}
- ft.mu.Lock()
- defer ft.mu.Unlock()
+ t.mu.Lock()
+ defer t.mu.Unlock()
- ft.clock.removeWait(ft.until)
- ft.until = ft.clock.clock.Now().Add(d)
- ft.clock.addWait(ft.until)
+ t.clock.removeWait(t.until)
+ t.until = t.clock.clock.Now().Add(d)
+ t.clock.addWait(t.until)
}
// Stop implements tcpip.Timer.Stop.
-func (ft *fakeTimer) Stop() bool {
- if !ft.timer.Stop() {
+func (t *manualTimer) Stop() bool {
+ if !t.timer.Stop() {
return false
}
- ft.mu.RLock()
- defer ft.mu.RUnlock()
+ t.mu.RLock()
+ defer t.mu.RUnlock()
- ft.clock.removeWait(ft.until)
+ t.clock.removeWait(t.until)
return true
}
diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go
new file mode 100644
index 000000000..c2704df2c
--- /dev/null
+++ b/pkg/tcpip/faketime/faketime_test.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package faketime_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+)
+
+func TestManualClockAdvance(t *testing.T) {
+ const timeout = time.Millisecond
+ clock := faketime.NewManualClock()
+ start := clock.NowMonotonic()
+ clock.Advance(timeout)
+ if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want {
+ t.Errorf("got = %d, want = %d", got, want)
+ }
+}
+
+func TestManualClockAfterFunc(t *testing.T) {
+ const (
+ timeout1 = time.Millisecond // timeout for counter1
+ timeout2 = 2 * time.Millisecond // timeout for counter2
+ )
+ tests := []struct {
+ name string
+ advance time.Duration
+ wantCounter1 int
+ wantCounter2 int
+ }{
+ {
+ name: "before timeout1",
+ advance: timeout1 - 1,
+ wantCounter1: 0,
+ wantCounter2: 0,
+ },
+ {
+ name: "timeout1",
+ advance: timeout1,
+ wantCounter1: 1,
+ wantCounter2: 0,
+ },
+ {
+ name: "timeout2",
+ advance: timeout2,
+ wantCounter1: 1,
+ wantCounter2: 1,
+ },
+ {
+ name: "after timeout2",
+ advance: timeout2 + 1,
+ wantCounter1: 1,
+ wantCounter2: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ counter1 := 0
+ counter2 := 0
+ clock.AfterFunc(timeout1, func() {
+ counter1++
+ })
+ clock.AfterFunc(timeout2, func() {
+ counter2++
+ })
+ start := clock.NowMonotonic()
+ clock.Advance(test.advance)
+ if got, want := counter1, test.wantCounter1; got != want {
+ t.Errorf("got counter1 = %d, want = %d", got, want)
+ }
+ if got, want := counter2, test.wantCounter2; got != want {
+ t.Errorf("got counter2 = %d, want = %d", got, want)
+ }
+ if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want {
+ t.Errorf("got elapsed = %d, want = %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index be03fb086..504408878 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -31,6 +31,27 @@ const (
// ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
ICMPv4MinimumSize = 8
+ // ICMPv4MinimumErrorPayloadSize Is the smallest number of bytes of an
+ // errant packet's transport layer that an ICMP error type packet should
+ // attempt to send as per RFC 792 (see each type) and RFC 1122
+ // section 3.2.2 which states:
+ // Every ICMP error message includes the Internet header and at
+ // least the first 8 data octets of the datagram that triggered
+ // the error; more than 8 octets MAY be sent; this header and data
+ // MUST be unchanged from the received datagram.
+ //
+ // RFC 792 shows:
+ // 0 1 2 3
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Type | Code | Checksum |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | unused |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Internet Header + 64 bits of Original Data Datagram |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ ICMPv4MinimumErrorPayloadSize = 8
+
// ICMPv4ProtocolNumber is the ICMP transport protocol number.
ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
@@ -39,15 +60,19 @@ const (
icmpv4ChecksumOffset = 2
// icmpv4MTUOffset is the offset of the MTU field
- // in a ICMPv4FragmentationNeeded message.
+ // in an ICMPv4FragmentationNeeded message.
icmpv4MTUOffset = 6
// icmpv4IdentOffset is the offset of the ident field
- // in a ICMPv4EchoRequest/Reply message.
+ // in an ICMPv4EchoRequest/Reply message.
icmpv4IdentOffset = 4
+ // icmpv4PointerOffset is the offset of the pointer field
+ // in an ICMPv4ParamProblem message.
+ icmpv4PointerOffset = 4
+
// icmpv4SequenceOffset is the offset of the sequence field
- // in a ICMPv4EchoRequest/Reply message.
+ // in an ICMPv4EchoRequest/Reply message.
icmpv4SequenceOffset = 6
)
@@ -72,15 +97,23 @@ const (
ICMPv4InfoReply ICMPv4Type = 16
)
+// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792.
+const (
+ ICMPv4TTLExceeded ICMPv4Code = 0
+)
+
// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792.
const (
- ICMPv4TTLExceeded ICMPv4Code = 0
+ ICMPv4NetUnreachable ICMPv4Code = 0
ICMPv4HostUnreachable ICMPv4Code = 1
ICMPv4ProtoUnreachable ICMPv4Code = 2
ICMPv4PortUnreachable ICMPv4Code = 3
ICMPv4FragmentationNeeded ICMPv4Code = 4
)
+// ICMPv4UnusedCode is a code to use in ICMP messages where no code is needed.
+const ICMPv4UnusedCode ICMPv4Code = 0
+
// Type is the ICMP type field.
func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 20b01d8f4..6be31beeb 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -54,9 +54,17 @@ const (
// address.
ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize
- // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
+ // ICMPv6EchoMinimumSize is the minimum size of a valid echo packet.
ICMPv6EchoMinimumSize = 8
+ // ICMPv6ErrorHeaderSize is the size of an ICMP error packet header,
+ // as per RFC 4443, Apendix A, item 4 and the errata.
+ // ... all ICMP error messages shall have exactly
+ // 32 bits of type-specific data, so that receivers can reliably find
+ // the embedded invoking packet even when they don't recognize the
+ // ICMP message Type.
+ ICMPv6ErrorHeaderSize = 8
+
// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
// destination unreachable packet.
ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
@@ -69,6 +77,10 @@ const (
// in an ICMPv6 message.
icmpv6ChecksumOffset = 2
+ // icmpv6PointerOffset is the offset of the pointer
+ // in an ICMPv6 Parameter problem message.
+ icmpv6PointerOffset = 4
+
// icmpv6MTUOffset is the offset of the MTU field in an ICMPv6
// PacketTooBig message.
icmpv6MTUOffset = 4
@@ -89,9 +101,10 @@ const (
NDPHopLimit = 255
)
-// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
+// ICMPv6Type is the ICMP type field described in RFC 4443.
type ICMPv6Type byte
+// Values for use in the Type field of ICMPv6 packet from RFC 4433.
const (
ICMPv6DstUnreachable ICMPv6Type = 1
ICMPv6PacketTooBig ICMPv6Type = 2
@@ -109,7 +122,18 @@ const (
ICMPv6RedirectMsg ICMPv6Type = 137
)
-// ICMPv6Code is the ICMP code field described in RFC 4443.
+// IsErrorType returns true if the receiver is an ICMP error type.
+func (typ ICMPv6Type) IsErrorType() bool {
+ // Per RFC 4443 section 2.1:
+ // ICMPv6 messages are grouped into two classes: error messages and
+ // informational messages. Error messages are identified as such by a
+ // zero in the high-order bit of their message Type field values. Thus,
+ // error messages have message types from 0 to 127; informational
+ // messages have message types from 128 to 255.
+ return typ&0x80 == 0
+}
+
+// ICMPv6Code is the ICMP Code field described in RFC 4443.
type ICMPv6Code byte
// ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443
@@ -132,9 +156,14 @@ const (
// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4.
const (
+ // ICMPv6ErroneousHeader indicates an erroneous header field was encountered.
ICMPv6ErroneousHeader ICMPv6Code = 0
- ICMPv6UnknownHeader ICMPv6Code = 1
- ICMPv6UnknownOption ICMPv6Code = 2
+
+ // ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered.
+ ICMPv6UnknownHeader ICMPv6Code = 1
+
+ // ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered.
+ ICMPv6UnknownOption ICMPv6Code = 2
)
// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use
@@ -153,6 +182,16 @@ func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) }
+// TypeSpecific returns the type specific data field.
+func (b ICMPv6) TypeSpecific() uint32 {
+ return binary.BigEndian.Uint32(b[icmpv6PointerOffset:])
+}
+
+// SetTypeSpecific sets the type specific data field.
+func (b ICMPv6) SetTypeSpecific(val uint32) {
+ binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val)
+}
+
// Checksum is the ICMP checksum field.
func (b ICMPv6) Checksum() uint16 {
return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:])
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 680eafd16..b07d9991d 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -80,7 +80,8 @@ type IPv4Fields struct {
type IPv4 []byte
const (
- // IPv4MinimumSize is the minimum size of a valid IPv4 packet.
+ // IPv4MinimumSize is the minimum size of a valid IPv4 packet;
+ // i.e. a packet header with no options.
IPv4MinimumSize = 20
// IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
@@ -88,6 +89,16 @@ const (
// units, the header cannot exceed 15*4 = 60 bytes.
IPv4MaximumHeaderSize = 60
+ // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload.
+ //
+ // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4
+ // header size). But RFC 791 section 3.2 discusses the design of the IPv4
+ // fragment "allows 2**13 = 8192 fragments of 8 octets each for a total of
+ // 65,536 octets. Note that this is consistent with the the datagram total
+ // length field (of course, the header is counted in the total length and not
+ // in the fragments)."
+ IPv4MaximumPayloadSize = 65536
+
// MinIPFragmentPayloadSize is the minimum number of payload bytes that
// the first fragment must carry when an IPv4 packet is fragmented.
MinIPFragmentPayloadSize = 8
@@ -317,7 +328,7 @@ func IsV4MulticastAddress(addr tcpip.Address) bool {
}
// IsV4LoopbackAddress determines if the provided address is an IPv4 loopback
-// address (belongs to 127.0.0.1/8 subnet).
+// address (belongs to 127.0.0.0/8 subnet). See RFC 1122 section 3.2.1.3.
func IsV4LoopbackAddress(addr tcpip.Address) bool {
if len(addr) != IPv4AddressSize {
return false
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index ea3823898..ef454b313 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -34,6 +34,9 @@ const (
hopLimit = 7
v6SrcAddr = 8
v6DstAddr = v6SrcAddr + IPv6AddressSize
+
+ // IPv6FixedHeaderSize is the size of the fixed header.
+ IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
@@ -69,11 +72,15 @@ type IPv6 []byte
const (
// IPv6MinimumSize is the minimum size of a valid IPv6 packet.
- IPv6MinimumSize = 40
+ IPv6MinimumSize = IPv6FixedHeaderSize
// IPv6AddressSize is the size, in bytes, of an IPv6 address.
IPv6AddressSize = 16
+ // IPv6MaximumPayloadSize is the maximum size of a valid IPv6 payload per
+ // RFC 8200 Section 4.5.
+ IPv6MaximumPayloadSize = 65535
+
// IPv6ProtocolNumber is IPv6's network protocol number.
IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 3499d8399..583c2c5d3 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -149,6 +149,19 @@ func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
// obtained before modification is no longer used.
type IPv6OptionsExtHdrOptionsIterator struct {
reader bytes.Reader
+
+ // optionOffset is the number of bytes from the first byte of the
+ // options field to the beginning of the current option.
+ optionOffset uint32
+
+ // nextOptionOffset is the offset of the next option.
+ nextOptionOffset uint32
+}
+
+// OptionOffset returns the number of bytes parsed while processing the
+// option field of the current Extension Header.
+func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 {
+ return i.optionOffset
}
// IPv6OptionUnknownAction is the action that must be taken if the processing
@@ -226,6 +239,7 @@ func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
// the options data, or an error occured.
func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
for {
+ i.optionOffset = i.nextOptionOffset
temp, err := i.reader.ReadByte()
if err != nil {
// If we can't read the first byte of a new option, then we know the
@@ -238,6 +252,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
// know the option does not have Length and Data fields. End processing of
// the Pad1 option and continue processing the buffer as a new option.
if id == ipv6Pad1ExtHdrOptionIdentifier {
+ i.nextOptionOffset = i.optionOffset + 1
continue
}
@@ -254,41 +269,40 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
}
- // Special-case the variable length padding option to avoid a copy.
- if id == ipv6PadNExtHdrOptionIdentifier {
- // Do we have enough bytes in the reader for the PadN option?
- if n := i.reader.Len(); n < int(length) {
- // Reset the reader to effectively consume the remaining buffer.
- i.reader.Reset(nil)
-
- // We return the same error as if we failed to read a non-padding option
- // so consumers of this iterator don't need to differentiate between
- // padding and non-padding options.
- return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
- }
+ // Do we have enough bytes in the reader for the next option?
+ if n := i.reader.Len(); n < int(length) {
+ // Reset the reader to effectively consume the remaining buffer.
+ i.reader.Reset(nil)
+
+ // We return the same error as if we failed to read a non-padding option
+ // so consumers of this iterator don't need to differentiate between
+ // padding and non-padding options.
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
+ }
+
+ i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */
+ switch id {
+ case ipv6PadNExtHdrOptionIdentifier:
+ // Special-case the variable length padding option to avoid a copy.
if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil {
panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
}
-
- // End processing of the PadN option and continue processing the buffer as
- // a new option.
continue
- }
-
- bytes := make([]byte, length)
- if n, err := io.ReadFull(&i.reader, bytes); err != nil {
- // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
- // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
- // Length field found in the option.
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
+ default:
+ bytes := make([]byte, length)
+ if n, err := io.ReadFull(&i.reader, bytes); err != nil {
+ // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
+ // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
+ // Length field found in the option.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
}
-
- return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
+ return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
}
-
- return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
}
}
@@ -382,6 +396,29 @@ type IPv6PayloadIterator struct {
// Indicates to the iterator that it should return the remaining payload as a
// raw payload on the next call to Next.
forceRaw bool
+
+ // headerOffset is the offset of the beginning of the current extension
+ // header starting from the beginning of the fixed header.
+ headerOffset uint32
+
+ // parseOffset is the byte offset into the current extension header of the
+ // field we are currently examining. It can be added to the header offset
+ // if the absolute offset within the packet is required.
+ parseOffset uint32
+
+ // nextOffset is the offset of the next header.
+ nextOffset uint32
+}
+
+// HeaderOffset returns the offset to the start of the extension
+// header most recently processed.
+func (i IPv6PayloadIterator) HeaderOffset() uint32 {
+ return i.headerOffset
+}
+
+// ParseOffset returns the number of bytes successfully parsed.
+func (i IPv6PayloadIterator) ParseOffset() uint32 {
+ return i.headerOffset + i.parseOffset
}
// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
@@ -397,7 +434,8 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa
nextHdrIdentifier: nextHdrIdentifier,
payload: payload.Clone(nil),
// We need a buffer of size 1 for calls to bufio.Reader.ReadByte.
- reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ nextOffset: IPv6FixedHeaderSize,
}
}
@@ -434,6 +472,8 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
// Next is unable to return anything because the iterator has reached the end of
// the payload, or an error occured.
func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
+ i.headerOffset = i.nextOffset
+ i.parseOffset = 0
// We could be forced to return i as a raw header when the previous header was
// a fragment extension header as the data following the fragment extension
// header may not be complete.
@@ -461,7 +501,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
return IPv6RoutingExtHdr(bytes), false, nil
case IPv6FragmentExtHdrIdentifier:
var data [6]byte
- // We ignore the returned bytes becauase we know the fragment extension
+ // We ignore the returned bytes because we know the fragment extension
// header specific data will fit in data.
nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
if err != nil {
@@ -519,10 +559,12 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
if err != nil {
return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
}
+ i.parseOffset++
var length uint8
length, err = i.reader.ReadByte()
i.payload.TrimFront(1)
+
if err != nil {
if fragmentHdr {
return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
@@ -534,6 +576,17 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
length = 0
}
+ // Make parseOffset point to the first byte of the Extension Header
+ // specific data.
+ i.parseOffset++
+
+ // length is in 8 byte chunks but doesn't include the first one.
+ // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement
+ // in section 4.8 for new extension headers at the top of page 24.
+ // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet
+ // units, not including the first 8 octets.
+ i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit)
+
bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
if bytes == nil {
bytes = make([]byte, bytesLen)
diff --git a/pkg/tcpip/header/parse/BUILD b/pkg/tcpip/header/parse/BUILD
new file mode 100644
index 000000000..2adee9288
--- /dev/null
+++ b/pkg/tcpip/header/parse/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "parse",
+ srcs = ["parse.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
new file mode 100644
index 000000000..5ca75c834
--- /dev/null
+++ b/pkg/tcpip/header/parse/parse.go
@@ -0,0 +1,168 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package parse provides utilities to parse packets.
+package parse
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// ARP populates pkt's network header with an ARP header found in
+// pkt.Data.
+//
+// Returns true if the header was successfully parsed.
+func ARP(pkt *stack.PacketBuffer) bool {
+ _, ok := pkt.NetworkHeader().Consume(header.ARPSize)
+ if ok {
+ pkt.NetworkProtocolNumber = header.ARPProtocolNumber
+ }
+ return ok
+}
+
+// IPv4 parses an IPv4 packet found in pkt.Data and populates pkt's network
+// header with the IPv4 header.
+//
+// Returns true if the header was successfully parsed.
+func IPv4(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return false
+ }
+ ipHdr := header.IPv4(hdr)
+
+ // Header may have options, determine the true header length.
+ headerLen := int(ipHdr.HeaderLength())
+ if headerLen < header.IPv4MinimumSize {
+ // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in
+ // order for the packet to be valid. Figure out if we want to reject this
+ // case.
+ headerLen = header.IPv4MinimumSize
+ }
+ hdr, ok = pkt.NetworkHeader().Consume(headerLen)
+ if !ok {
+ return false
+ }
+ ipHdr = header.IPv4(hdr)
+
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
+ return true
+}
+
+// IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network
+// header with the IPv6 header.
+func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return 0, 0, 0, false, false
+ }
+ ipHdr := header.IPv6(hdr)
+
+ // dataClone consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ views := [8]buffer.View{}
+ dataClone := pkt.Data.Clone(views[:])
+ dataClone.TrimFront(header.IPv6MinimumSize)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
+
+ // Iterate over the IPv6 extensions to find their length.
+ var nextHdr tcpip.TransportProtocolNumber
+ var extensionsSize int
+
+traverseExtensions:
+ for {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ break
+ }
+
+ // If we exhaust the extension list, the entire packet is the IPv6 header
+ // and (possibly) extensions.
+ if done {
+ extensionsSize = dataClone.Size()
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6FragmentExtHdr:
+ if fragID == 0 && fragOffset == 0 && !fragMore {
+ fragID = extHdr.ID()
+ fragOffset = extHdr.FragmentOffset()
+ fragMore = extHdr.More()
+ }
+
+ case header.IPv6RawPayloadHeader:
+ // We've found the payload after any extensions.
+ extensionsSize = dataClone.Size() - extHdr.Buf.Size()
+ nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
+ break traverseExtensions
+
+ default:
+ // Any other extension is a no-op, keep looping until we find the payload.
+ }
+ }
+
+ // Put the IPv6 header with extensions in pkt.NetworkHeader().
+ hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize)
+ if !ok {
+ panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
+ }
+ ipHdr = header.IPv6(hdr)
+ pkt.Data.CapLength(int(ipHdr.PayloadLength()))
+ pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+
+ return nextHdr, fragID, fragOffset, fragMore, true
+}
+
+// UDP parses a UDP packet found in pkt.Data and populates pkt's transport
+// header with the UDP header.
+//
+// Returns true if the header was successfully parsed.
+func UDP(pkt *stack.PacketBuffer) bool {
+ _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize)
+ pkt.TransportProtocolNumber = header.UDPProtocolNumber
+ return ok
+}
+
+// TCP parses a TCP packet found in pkt.Data and populates pkt's transport
+// header with the TCP header.
+//
+// Returns true if the header was successfully parsed.
+func TCP(pkt *stack.PacketBuffer) bool {
+ // TCP header is variable length, peek at it first.
+ hdrLen := header.TCPMinimumSize
+ hdr, ok := pkt.Data.PullUp(hdrLen)
+ if !ok {
+ return false
+ }
+
+ // If the header has options, pull those up as well.
+ if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
+ // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of
+ // packets.
+ hdrLen = offset
+ }
+
+ _, ok = pkt.TransportHeader().Consume(hdrLen)
+ pkt.TransportProtocolNumber = header.TCPProtocolNumber
+ return ok
+}
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 9339d637f..98bdd29db 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -16,6 +16,7 @@ package header
import (
"encoding/binary"
+ "math"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -55,6 +56,10 @@ const (
// UDPMinimumSize is the minimum size of a valid UDP packet.
UDPMinimumSize = 8
+ // UDPMaximumSize is the maximum size of a valid UDP packet. The length field
+ // in the UDP header is 16 bits as per RFC 768.
+ UDPMaximumSize = math.MaxUint16
+
// UDPProtocolNumber is UDP's transport protocol number.
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 14b527bc2..6c410c5a6 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -18,3 +18,14 @@ go_library(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "rawfile_test",
+ srcs = [
+ "errors_test.go",
+ ],
+ library = "rawfile",
+ deps = [
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
index a0a873c84..604868fd8 100644
--- a/pkg/tcpip/link/rawfile/errors.go
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -31,10 +31,12 @@ var translations [maxErrno]*tcpip.Error
// *tcpip.Error.
//
// Valid, but unrecognized errnos will be translated to
-// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos.
+// tcpip.ErrInvalidEndpointState (EINVAL).
func TranslateErrno(e syscall.Errno) *tcpip.Error {
- if err := translations[e]; err != nil {
- return err
+ if e > 0 && e < syscall.Errno(len(translations)) {
+ if err := translations[e]; err != nil {
+ return err
+ }
}
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go
new file mode 100644
index 000000000..e4cdc66bd
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/errors_test.go
@@ -0,0 +1,53 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package rawfile
+
+import (
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestTranslateErrno(t *testing.T) {
+ for _, test := range []struct {
+ errno syscall.Errno
+ translated *tcpip.Error
+ }{
+ {
+ errno: syscall.Errno(0),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.Errno(maxErrno),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.Errno(514),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.EEXIST,
+ translated: tcpip.ErrDuplicateAddress,
+ },
+ } {
+ got := TranslateErrno(test.errno)
+ if got != test.translated {
+ t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated)
+ }
+ }
+}
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
index 7cbc305e7..4aac12a8c 100644
--- a/pkg/tcpip/link/sniffer/BUILD
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -14,6 +14,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/link/nested",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 4fb127978..560477926 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -195,49 +196,52 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
var transProto uint8
src := tcpip.Address("unknown")
dst := tcpip.Address("unknown")
- id := 0
- size := uint16(0)
+ var size uint16
+ var id uint32
var fragmentOffset uint16
var moreFragments bool
- // Examine the packet using a new VV. Backing storage must not be written.
- vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
-
+ // Clone the packet buffer to not modify the original.
+ //
+ // We don't clone the original packet buffer so that the new packet buffer
+ // does not have any of its headers set.
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())})
switch protocol {
case header.IPv4ProtocolNumber:
- hdr, ok := vv.PullUp(header.IPv4MinimumSize)
- if !ok {
+ if ok := parse.IPv4(pkt); !ok {
return
}
- ipv4 := header.IPv4(hdr)
+
+ ipv4 := header.IPv4(pkt.NetworkHeader().View())
fragmentOffset = ipv4.FragmentOffset()
moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
src = ipv4.SourceAddress()
dst = ipv4.DestinationAddress()
transProto = ipv4.Protocol()
size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
- vv.TrimFront(int(ipv4.HeaderLength()))
- id = int(ipv4.ID())
+ id = uint32(ipv4.ID())
case header.IPv6ProtocolNumber:
- hdr, ok := vv.PullUp(header.IPv6MinimumSize)
+ proto, fragID, fragOffset, fragMore, ok := parse.IPv6(pkt)
if !ok {
return
}
- ipv6 := header.IPv6(hdr)
+
+ ipv6 := header.IPv6(pkt.NetworkHeader().View())
src = ipv6.SourceAddress()
dst = ipv6.DestinationAddress()
- transProto = ipv6.NextHeader()
+ transProto = uint8(proto)
size = ipv6.PayloadLength()
- vv.TrimFront(header.IPv6MinimumSize)
+ id = fragID
+ moreFragments = fragMore
+ fragmentOffset = fragOffset
case header.ARPProtocolNumber:
- hdr, ok := vv.PullUp(header.ARPSize)
- if !ok {
+ if parse.ARP(pkt) {
return
}
- vv.TrimFront(header.ARPSize)
- arp := header.ARP(hdr)
+
+ arp := header.ARP(pkt.NetworkHeader().View())
log.Infof(
"%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
@@ -259,7 +263,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
switch tcpip.TransportProtocolNumber(transProto) {
case header.ICMPv4ProtocolNumber:
transName = "icmp"
- hdr, ok := vv.PullUp(header.ICMPv4MinimumSize)
+ hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
break
}
@@ -296,7 +300,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6ProtocolNumber:
transName = "icmp"
- hdr, ok := vv.PullUp(header.ICMPv6MinimumSize)
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
if !ok {
break
}
@@ -331,11 +335,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.UDPProtocolNumber:
transName = "udp"
- hdr, ok := vv.PullUp(header.UDPMinimumSize)
- if !ok {
+ if ok := parse.UDP(pkt); !ok {
break
}
- udp := header.UDP(hdr)
+
+ udp := header.UDP(pkt.TransportHeader().View())
if fragmentOffset == 0 {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
@@ -345,19 +349,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.TCPProtocolNumber:
transName = "tcp"
- hdr, ok := vv.PullUp(header.TCPMinimumSize)
- if !ok {
+ if ok := parse.TCP(pkt); !ok {
break
}
- tcp := header.TCP(hdr)
+
+ tcp := header.TCP(pkt.TransportHeader().View())
if fragmentOffset == 0 {
offset := int(tcp.DataOffset())
if offset < header.TCPMinimumSize {
details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
break
}
- if offset > vv.Size() && !moreFragments {
- details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size())
+ if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments {
+ details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size)
break
}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 46083925c..59710352b 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -9,6 +9,7 @@ go_test(
"ip_test.go",
],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
@@ -17,6 +18,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
],
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index 82c073e32..b40dde96b 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7aaee08c4..b47a7be51 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -15,20 +15,15 @@
// Package arp implements the ARP network protocol. It is used to resolve
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
// addresses of its stack with the local network.
-//
-// To use it in the networking stack, pass arp.NewProtocol() as one of the
-// network protocols when calling stack.New. Then add an "arp" address to every
-// NIC on the stack that should respond to ARP requests. That is:
-//
-// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
-// // handle err
-// }
package arp
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -40,15 +35,57 @@ const (
ProtocolAddress = tcpip.Address("arp")
)
-// endpoint implements stack.NetworkEndpoint.
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- protocol *protocol
- nicID tcpip.NICID
+ stack.AddressableEndpointState
+
+ protocol *protocol
+
+ // enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ nic stack.NetworkInterface
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
}
+func (e *endpoint) Enable() *tcpip.Error {
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ e.setEnabled(true)
+ return nil
+}
+
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
+}
+
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the endpoint.
+func (e *endpoint) setEnabled(v bool) {
+ if v {
+ atomic.StoreUint32(&e.enabled, 1)
+ } else {
+ atomic.StoreUint32(&e.enabled, 0)
+ }
+}
+
+func (e *endpoint) Disable() {
+ e.setEnabled(false)
+}
+
// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
func (e *endpoint) DefaultTTL() uint8 {
return 0
@@ -59,19 +96,13 @@ func (e *endpoint) MTU() uint32 {
return lmtu - uint32(e.MaxHeaderLength())
}
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
-}
-
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.AddressableEndpointState.Cleanup()
+}
func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
@@ -92,6 +123,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
}
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
+ return
+ }
+
h := header.ARP(pkt.NetworkHeader().View())
if !h.IsValid() {
return
@@ -102,15 +137,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
localAddr := tcpip.Address(h.ProtocolAddressTarget())
if e.nud == nil {
- if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
} else {
- if r.Stack().CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
@@ -136,7 +171,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
return
}
@@ -168,14 +203,16 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
- return &endpoint{
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
protocol: p,
- nicID: nicID,
- linkEP: sender,
+ nic: nic,
+ linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
}
+ e.AddressableEndpointState.Init(e)
+ return e
}
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
@@ -234,14 +271,14 @@ func (*protocol) Wait() {}
// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- _, ok = pkt.NetworkHeader().Consume(header.ARPSize)
- if !ok {
- return 0, false, false
- }
- return 0, false, true
+ return 0, false, parse.ARP(pkt)
}
// NewProtocol returns an ARP network protocol.
-func NewProtocol() stack.NetworkProtocol {
+//
+// Note, to make sure that the ARP endpoint receives ARP packets, the "arp"
+// address must be added to every NIC that should respond to ARP requests. See
+// ProtocolAddress for more details.
+func NewProtocol(*stack.Stack) stack.NetworkProtocol {
return &protocol{}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 9c9a859e3..626af975a 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -176,8 +176,8 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext {
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
NUDConfigs: c,
NUDDisp: &d,
UseNeighborCache: useNeighborCache,
@@ -442,7 +442,7 @@ func TestLinkAddressRequest(t *testing.T) {
}
for _, test := range tests {
- p := arp.NewProtocol()
+ p := arp.NewProtocol(nil)
linkRes, ok := p.(stack.LinkAddressResolver)
if !ok {
t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 96c5f42f8..e247f06a4 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -43,5 +43,6 @@ go_test(
library = ":fragmentation",
deps = [
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 6a4843f92..e1909fab0 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -81,6 +81,8 @@ type Fragmentation struct {
size int
timeout time.Duration
blockSize uint16
+ clock tcpip.Clock
+ releaseJob *tcpip.Job
}
// NewFragmentation creates a new Fragmentation.
@@ -97,7 +99,7 @@ type Fragmentation struct {
// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
-func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation {
if lowMemoryLimit >= highMemoryLimit {
lowMemoryLimit = highMemoryLimit
}
@@ -110,13 +112,17 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
blockSize = minBlockSize
}
- return &Fragmentation{
+ f := &Fragmentation{
reassemblers: make(map[FragmentID]*reassembler),
highLimit: highMemoryLimit,
lowLimit: lowMemoryLimit,
timeout: reassemblingTimeout,
blockSize: blockSize,
+ clock: clock,
}
+ f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked)
+
+ return f
}
// Process processes an incoming fragment belonging to an ID and returns a
@@ -155,15 +161,17 @@ func (f *Fragmentation) Process(
f.mu.Lock()
r, ok := f.reassemblers[id]
- if ok && r.tooOld(f.timeout) {
- // This is very likely to be an id-collision or someone performing a slow-rate attack.
- f.release(r)
- ok = false
- }
if !ok {
- r = newReassembler(id)
+ r = newReassembler(id, f.clock)
f.reassemblers[id] = r
+ wasEmpty := f.rList.Empty()
f.rList.PushFront(r)
+ if wasEmpty {
+ // If we have just pushed a first reassembler into an empty list, we
+ // should kickstart the release job. The release job will keep
+ // rescheduling itself until the list becomes empty.
+ f.releaseReassemblersLocked()
+ }
}
f.mu.Unlock()
@@ -211,3 +219,27 @@ func (f *Fragmentation) release(r *reassembler) {
f.size = 0
}
}
+
+// releaseReassemblersLocked releases already-expired reassemblers, then
+// schedules the job to call back itself for the remaining reassemblers if
+// any. This function must be called with f.mu locked.
+func (f *Fragmentation) releaseReassemblersLocked() {
+ now := f.clock.NowMonotonic()
+ for {
+ // The reassembler at the end of the list is the oldest.
+ r := f.rList.Back()
+ if r == nil {
+ // The list is empty.
+ break
+ }
+ elapsed := time.Duration(now-r.creationTime) * time.Nanosecond
+ if f.timeout > elapsed {
+ // If the oldest reassembler has not expired, schedule the release
+ // job so that this function is called back when it has expired.
+ f.releaseJob.Schedule(f.timeout - elapsed)
+ break
+ }
+ // If the oldest reassembler has already expired, release it.
+ f.release(r)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 416604659..189b223c5 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -21,6 +21,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
// vv is a helper to build VectorisedView from different strings.
@@ -95,7 +96,7 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout, &faketime.NullClock{})
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv)
@@ -131,25 +132,126 @@ func TestFragmentationProcess(t *testing.T) {
}
func TestReassemblingTimeout(t *testing.T) {
- timeout := time.Millisecond
- f := NewFragmentation(minBlockSize, 1024, 512, timeout)
- // Send first fragment with id = 0, first = 0, last = 0, and more = true.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
- // Sleep more than the timeout.
- time.Sleep(2 * timeout)
- // Send another fragment that completes a packet.
- // However, no packet should be reassembled because the fragment arrived after the timeout.
- _, _, done, err := f.Process(FragmentID{}, 1, 1, false, 0xFF, vv(1, "1"))
- if err != nil {
- t.Fatalf("f.Process(0, 1, 1, false, 0xFF, vv(1, \"1\")) failed: %v", err)
+ const (
+ reassemblyTimeout = time.Millisecond
+ protocol = 0xff
+ )
+
+ type fragment struct {
+ first uint16
+ last uint16
+ more bool
+ data string
}
- if done {
- t.Errorf("Fragmentation does not respect the reassembling timeout.")
+
+ type event struct {
+ // name is a nickname of this event.
+ name string
+
+ // clockAdvance is a duration to advance the clock. The clock advances
+ // before a fragment specified in the fragment field is processed.
+ clockAdvance time.Duration
+
+ // fragment is a fragment to process. This can be nil if there is no
+ // fragment to process.
+ fragment *fragment
+
+ // expectDone is true if the fragmentation instance should report the
+ // reassembly is done after the fragment is processd.
+ expectDone bool
+
+ // sizeAfterEvent is the expected size of the fragmentation instance after
+ // the event.
+ sizeAfterEvent int
+ }
+
+ half1 := &fragment{first: 0, last: 0, more: true, data: "0"}
+ half2 := &fragment{first: 1, last: 1, more: false, data: "1"}
+
+ tests := []struct {
+ name string
+ events []event
+ }{
+ {
+ name: "half1 and half2 are reassembled successfully",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: true,
+ sizeAfterEvent: 0,
+ },
+ },
+ },
+ {
+ name: "half1 timeout, half2 timeout",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half1 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half1 reassembly timeout",
+ clockAdvance: 1,
+ sizeAfterEvent: 0,
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2 reassembly timeout",
+ clockAdvance: 1,
+ sizeAfterEvent: 0,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock)
+ for _, event := range test.events {
+ clock.Advance(event.clockAdvance)
+ if frag := event.fragment; frag != nil {
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data))
+ if err != nil {
+ t.Fatalf("%s: f.Process failed: %s", event.name, err)
+ }
+ if done != event.expectDone {
+ t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone)
+ }
+ }
+ if got, want := f.size, event.sizeAfterEvent; got != want {
+ t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want)
+ }
+ }
+ })
}
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"))
// Send first fragment with id = 1.
@@ -173,7 +275,7 @@ func TestMemoryLimits(t *testing.T) {
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
// Send the same packet again.
@@ -268,7 +370,7 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout)
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout, &faketime.NullClock{})
_, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data))
if !errors.Is(err, test.err) {
t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index f044867dc..9bb051a30 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -18,9 +18,9 @@ import (
"container/heap"
"fmt"
"math"
- "time"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -40,15 +40,15 @@ type reassembler struct {
deleted int
heap fragHeap
done bool
- creationTime time.Time
+ creationTime int64
}
-func newReassembler(id FragmentID) *reassembler {
+func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
id: id,
holes: make([]hole, 0, 16),
heap: make(fragHeap, 0, 8),
- creationTime: time.Now(),
+ creationTime: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
first: 0,
@@ -116,10 +116,6 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buf
return res, r.proto, true, consumed, nil
}
-func (r *reassembler) tooOld(timeout time.Duration) bool {
- return time.Now().Sub(r.creationTime) > timeout
-}
-
func (r *reassembler) checkDoneOrMark() bool {
r.mu.Lock()
prev := r.done
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index dff7c9dcb..a0a04a027 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -18,6 +18,8 @@ import (
"math"
"reflect"
"testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
type updateHolesInput struct {
@@ -94,7 +96,7 @@ var holesTestCases = []struct {
func TestUpdateHoles(t *testing.T) {
for _, c := range holesTestCases {
- r := newReassembler(FragmentID{})
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
for _, i := range c.in {
r.updateHoles(i.first, i.last, i.more)
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index e45dd17f8..6861cfdaf 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -17,6 +17,7 @@ package ip_test
import (
"testing"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -25,26 +26,35 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
- localIpv4Addr = "\x0a\x00\x00\x01"
- localIpv4PrefixLen = 24
- remoteIpv4Addr = "\x0a\x00\x00\x02"
- ipv4SubnetAddr = "\x0a\x00\x00\x00"
- ipv4SubnetMask = "\xff\xff\xff\x00"
- ipv4Gateway = "\x0a\x00\x00\x03"
- localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- localIpv6PrefixLen = 120
- remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
- ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
- nicID = 1
+ localIPv4Addr = "\x0a\x00\x00\x01"
+ remoteIPv4Addr = "\x0a\x00\x00\x02"
+ ipv4SubnetAddr = "\x0a\x00\x00\x00"
+ ipv4SubnetMask = "\xff\xff\xff\x00"
+ ipv4Gateway = "\x0a\x00\x00\x03"
+ localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
+ ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ nicID = 1
)
+var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: localIPv4Addr,
+ PrefixLen: 24,
+}
+
+var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: localIPv6Addr,
+ PrefixLen: 120,
+}
+
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
// The former is used to pretend that it's a link endpoint so that we can
// inspect packets written by the network endpoints. The latter is used to
@@ -98,9 +108,10 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) {
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
+ return stack.TransportPacketHandled
}
// DeliverTransportControlPacket is called by network endpoints after parsing
@@ -194,8 +205,8 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
s.AddAddress(nicID, ipv4.ProtocolNumber, local)
@@ -210,8 +221,8 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
s.AddAddress(nicID, ipv6.ProtocolNumber, local)
@@ -224,33 +235,294 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
}
-func buildDummyStack(t *testing.T) *stack.Stack {
+func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) {
t.Helper()
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
e := channel.New(0, 1280, "")
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, localIpv4Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, localIpv4Addr, err)
+ v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
+ if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, localIpv6Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, localIpv6Addr, err)
+ v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
+ if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
}
+ return s, e
+}
+
+func buildDummyStack(t *testing.T) *stack.Stack {
+ t.Helper()
+
+ s, _ := buildDummyStackWithLinkEndpoint(t)
return s
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ tester testObject
+
+ mu struct {
+ sync.RWMutex
+ disabled bool
+ }
+}
+
+func (*testInterface) ID() tcpip.NICID {
+ return nicID
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (t *testInterface) Enabled() bool {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return !t.mu.disabled
+}
+
+func (t *testInterface) setEnabled(v bool) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.mu.disabled = !v
+}
+
+func (t *testInterface) LinkEndpoint() stack.LinkEndpoint {
+ return &t.tester
+}
+
+func TestSourceAddressValidation(t *testing.T) {
+ rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv4Addr,
+ })
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ tests := []struct {
+ name string
+ srcAddress tcpip.Address
+ rxICMP func(*channel.Endpoint, tcpip.Address)
+ valid bool
+ }{
+ {
+ name: "IPv4 valid",
+ srcAddress: "\x01\x02\x03\x04",
+ rxICMP: rxIPv4ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv6 valid",
+ srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10",
+ rxICMP: rxIPv6ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv4 unspecified",
+ srcAddress: header.IPv4Any,
+ rxICMP: rxIPv4ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv6 unspecified",
+ srcAddress: header.IPv4Any,
+ rxICMP: rxIPv6ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv4 multicast",
+ srcAddress: "\xe0\x00\x00\x01",
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv6 multicast",
+ srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ rxICMP: rxIPv6ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv4 broadcast",
+ srcAddress: header.IPv4Broadcast,
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv4 subnet broadcast",
+ srcAddress: func() tcpip.Address {
+ subnet := localIPv4AddrWithPrefix.Subnet()
+ return subnet.Broadcast()
+ }(),
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, e := buildDummyStackWithLinkEndpoint(t)
+ test.rxICMP(e, test.srcAddress)
+
+ var wantValid uint64
+ if test.valid {
+ wantValid = 1
+ }
+
+ if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want {
+ t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid {
+ t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid)
+ }
+ })
+ }
+}
+
+func TestEnableWhenNICDisabled(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ }{
+ {
+ name: "IPv4",
+ protocolFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ },
+ {
+ name: "IPv6",
+ protocolFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var nic testInterface
+ nic.setEnabled(false)
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory},
+ })
+ p := s.NetworkProtocolInstance(test.protoNum)
+
+ // We pass nil for all parameters except the NetworkInterface and Stack
+ // since Enable only depends on these.
+ ep := p.NewEndpoint(&nic, nil, nil, nil)
+
+ // The endpoint should initially be disabled, regardless the NIC's enabled
+ // status.
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+ nic.setEnabled(true)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Attempting to enable the endpoint while the NIC is disabled should
+ // fail.
+ nic.setEnabled(false)
+ if err := ep.Enable(); err != tcpip.ErrNotPermitted {
+ t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted)
+ }
+ // ep should consider the NIC's enabled status when determining its own
+ // enabled status so we "enable" the NIC to read just the endpoint's
+ // enabled status.
+ nic.setEnabled(true)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Enabling the interface after the NIC has been enabled should succeed.
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+ if !ep.Enabled() {
+ t.Fatal("got ep.Enabled() = false, want = true")
+ }
+
+ // ep should consider the NIC's enabled status when determining its own
+ // enabled status.
+ nic.setEnabled(false)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Disabling the endpoint when the NIC is enabled should make the endpoint
+ // disabled.
+ nic.setEnabled(true)
+ ep.Disable()
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+ })
+ }
+}
+
func TestIPv4Send(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, buildDummyStack(t))
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
defer ep.Close()
// Allocate and initialize the payload view.
@@ -266,12 +538,12 @@ func TestIPv4Send(t *testing.T) {
})
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv4Addr
- o.dstAddr = remoteIpv4Addr
- o.contents = payload
+ nic.tester.protocol = 123
+ nic.tester.srcAddr = localIPv4Addr
+ nic.tester.dstAddr = remoteIPv4Addr
+ nic.tester.contents = payload
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -285,11 +557,21 @@ func TestIPv4Send(t *testing.T) {
}
func TestIPv4Receive(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
totalLen := header.IPv4MinimumSize + 30
view := buffer.NewView(totalLen)
ip := header.IPv4(view)
@@ -298,8 +580,8 @@ func TestIPv4Receive(t *testing.T) {
TotalLength: uint16(totalLen),
TTL: 20,
Protocol: 10,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
@@ -308,12 +590,12 @@ func TestIPv4Receive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[header.IPv4MinimumSize:totalLen]
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIPv4Addr
+ nic.tester.dstAddr = localIPv4Addr
+ nic.tester.contents = view[header.IPv4MinimumSize:totalLen]
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -324,8 +606,8 @@ func TestIPv4Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
@@ -349,17 +631,26 @@ func TestIPv4ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
}
- r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb")
+ r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb")
if err != nil {
t.Fatal(err)
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
- proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
view := buffer.NewView(dataOffset + 8)
@@ -371,7 +662,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
TTL: 20,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: "\x0a\x00\x00\xbb",
- DstAddr: localIpv4Addr,
+ DstAddr: localIPv4Addr,
})
// Create the ICMP header.
@@ -389,8 +680,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
TTL: 20,
Protocol: 10,
FragmentOffset: c.fragmentOffset,
- SrcAddr: localIpv4Addr,
- DstAddr: remoteIpv4Addr,
+ SrcAddr: localIPv4Addr,
+ DstAddr: remoteIPv4Addr,
})
// Make payload be non-zero.
@@ -400,27 +691,37 @@ func TestIPv4ReceiveControl(t *testing.T) {
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIPv4Addr
+ nic.tester.dstAddr = localIPv4Addr
+ nic.tester.contents = view[dataOffset:]
+ nic.tester.typ = c.expectedTyp
+ nic.tester.extra = c.expectedExtra
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.tester.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
}
})
}
}
func TestIPv4FragmentationReceive(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
totalLen := header.IPv4MinimumSize + 24
frag1 := buffer.NewView(totalLen)
@@ -432,8 +733,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
Protocol: 10,
FragmentOffset: 0,
Flags: header.IPv4FlagMoreFragments,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -448,8 +749,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
TTL: 20,
Protocol: 10,
FragmentOffset: 24,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -457,12 +758,12 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIPv4Addr
+ nic.tester.dstAddr = localIPv4Addr
+ nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -475,8 +776,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
+ if nic.tester.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls)
}
// Send second segment.
@@ -487,17 +788,26 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
func TestIPv6Send(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t))
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
// Allocate and initialize the payload view.
payload := buffer.NewView(100)
for i := 0; i < len(payload); i++ {
@@ -511,12 +821,12 @@ func TestIPv6Send(t *testing.T) {
})
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv6Addr
- o.dstAddr = remoteIpv6Addr
- o.contents = payload
+ nic.tester.protocol = 123
+ nic.tester.srcAddr = localIPv6Addr
+ nic.tester.dstAddr = remoteIPv6Addr
+ nic.tester.contents = payload
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -530,11 +840,20 @@ func TestIPv6Send(t *testing.T) {
}
func TestIPv6Receive(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
totalLen := header.IPv6MinimumSize + 30
view := buffer.NewView(totalLen)
ip := header.IPv6(view)
@@ -542,8 +861,8 @@ func TestIPv6Receive(t *testing.T) {
PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
NextHeader: 10,
HopLimit: 20,
- SrcAddr: remoteIpv6Addr,
- DstAddr: localIpv6Addr,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
})
// Make payload be non-zero.
@@ -552,12 +871,12 @@ func TestIPv6Receive(t *testing.T) {
}
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[header.IPv6MinimumSize:totalLen]
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIPv6Addr
+ nic.tester.dstAddr = localIPv6Addr
+ nic.tester.contents = view[header.IPv6MinimumSize:totalLen]
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -569,8 +888,8 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
@@ -601,7 +920,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
{"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
}
r, err := buildIPv6Route(
- localIpv6Addr,
+ localIPv6Addr,
"\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
)
if err != nil {
@@ -609,11 +928,20 @@ func TestIPv6ReceiveControl(t *testing.T) {
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
if c.fragmentOffset != nil {
dataOffset += header.IPv6FragmentHeaderSize
@@ -627,7 +955,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: 20,
SrcAddr: outerSrcAddr,
- DstAddr: localIpv6Addr,
+ DstAddr: localIPv6Addr,
})
// Create the ICMP header.
@@ -643,8 +971,8 @@ func TestIPv6ReceiveControl(t *testing.T) {
PayloadLength: 100,
NextHeader: 10,
HopLimit: 20,
- SrcAddr: localIpv6Addr,
- DstAddr: remoteIpv6Addr,
+ SrcAddr: localIPv6Addr,
+ DstAddr: remoteIPv6Addr,
})
// Build the fragmentation header if needed.
@@ -666,19 +994,19 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Give packet to IPv6 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIPv6Addr
+ nic.tester.dstAddr = localIPv6Addr
+ nic.tester.contents = view[dataOffset:]
+ nic.tester.typ = c.expectedTyp
+ nic.tester.extra = c.expectedExtra
// Set ICMPv6 checksum.
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.tester.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
}
})
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index d142b4ffa..ee2c23e91 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -10,9 +10,11 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
@@ -26,11 +28,14 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index b5659a36b..a8985ff5d 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,9 @@
package ipv4
import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -39,7 +42,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -105,11 +108,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
// source address MUST be one of its own IP addresses (but not a broadcast
// or multicast address).
localAddr := r.LocalAddress
- if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) {
+ if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
@@ -131,7 +134,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: dataVV,
})
-
+ // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the header
+ // information we will have to change this code to handle the ICMP header
+ // no longer being in the data buffer.
+ replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
// Send out the reply packet.
sent := stats.ICMP.V4PacketsSent
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
@@ -193,3 +199,177 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
received.Invalid.Increment()
}
}
+
+// ======= ICMP Error packet generation =========
+
+// icmpReason is a marker interface for IPv4 specific ICMP errors.
+type icmpReason interface {
+ isICMPReason()
+}
+
+// icmpReasonPortUnreachable is an error where the transport protocol has no
+// listener and no alternative means to inform the sender.
+type icmpReasonPortUnreachable struct{}
+
+func (*icmpReasonPortUnreachable) isICMPReason() {}
+
+// icmpReasonProtoUnreachable is an error where the transport protocol is
+// not supported.
+type icmpReasonProtoUnreachable struct{}
+
+func (*icmpReasonProtoUnreachable) isICMPReason() {}
+
+// returnError takes an error descriptor and generates the appropriate ICMP
+// error packet for IPv4 and sends it back to the remote device that sent
+// the problematic packet. It incorporates as much of that packet as
+// possible as well as any error metadata as is available. returnError
+// expects pkt to hold a valid IPv4 packet as per the wire format.
+func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ sent := r.Stats().ICMP.V4PacketsSent
+ if !r.Stack().AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
+ // We check we are responding only when we are allowed to.
+ // See RFC 1812 section 4.3.2.7 (shown below).
+ //
+ // =========
+ // 4.3.2.7 When Not to Send ICMP Errors
+ //
+ // An ICMP error message MUST NOT be sent as the result of receiving:
+ //
+ // o An ICMP error message, or
+ //
+ // o A packet which fails the IP header validation tests described in
+ // Section [5.2.2] (except where that section specifically permits
+ // the sending of an ICMP error message), or
+ //
+ // o A packet destined to an IP broadcast or IP multicast address, or
+ //
+ // o A packet sent as a Link Layer broadcast or multicast, or
+ //
+ // o Any fragment of a datagram other then the first fragment (i.e., a
+ // packet for which the fragment offset in the IP header is nonzero).
+ //
+ // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in
+ // response to a non-initial fragment, but it currently can not happen.
+
+ if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any {
+ return nil
+ }
+
+ networkHeader := pkt.NetworkHeader().View()
+ transportHeader := pkt.TransportHeader().View()
+
+ // Don't respond to icmp error packets.
+ if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) {
+ // TODO(gvisor.dev/issue/3810):
+ // Unfortunately the current stack pretty much always has ICMPv4 headers
+ // in the Data section of the packet but there is no guarantee that is the
+ // case. If this is the case grab the header to make it like all other
+ // packet types. When this is cleaned up the Consume should be removed.
+ if transportHeader.IsEmpty() {
+ var ok bool
+ transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize)
+ if !ok {
+ return nil
+ }
+ } else if transportHeader.Size() < header.ICMPv4MinimumSize {
+ return nil
+ }
+ // We need to decide to explicitly name the packets we can respond to or
+ // the ones we can not respond to. The decision is somewhat arbitrary and
+ // if problems arise this could be reversed. It was judged less of a breach
+ // of protocol to not respond to unknown non-error packets than to respond
+ // to unknown error packets so we take the first approach.
+ switch header.ICMPv4(transportHeader).Type() {
+ case
+ header.ICMPv4EchoReply,
+ header.ICMPv4Echo,
+ header.ICMPv4Timestamp,
+ header.ICMPv4TimestampReply,
+ header.ICMPv4InfoRequest,
+ header.ICMPv4InfoReply:
+ default:
+ // Assume any type we don't know about may be an error type.
+ return nil
+ }
+ }
+
+ // Now work out how much of the triggering packet we should return.
+ // As per RFC 1812 Section 4.3.2.3
+ //
+ // ICMP datagram SHOULD contain as much of the original
+ // datagram as possible without the length of the ICMP
+ // datagram exceeding 576 bytes.
+ //
+ // NOTE: The above RFC referenced is different from the original
+ // recommendation in RFC 1122 and RFC 792 where it mentioned that at
+ // least 8 bytes of the payload must be included. Today linux and other
+ // systems implement the RFC 1812 definition and not the original
+ // requirement. We treat 8 bytes as the minimum but will try send more.
+ mtu := int(r.MTU())
+ if mtu > header.IPv4MinimumProcessableDatagramSize {
+ mtu = header.IPv4MinimumProcessableDatagramSize
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ available := int(mtu) - headerLen
+
+ if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize {
+ return nil
+ }
+
+ payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+
+ // The buffers used by pkt may be used elsewhere in the system.
+ // For example, an AF_RAW or AF_PACKET socket may use what the transport
+ // protocol considers an unreachable destination. Thus we deep copy pkt to
+ // prevent multiple ownership and SR errors. The new copy is a vectorized
+ // view with the entire incoming IP packet reassembled and truncated as
+ // required. This is now the payload of the new ICMP packet and no longer
+ // considered a packet in its own right.
+ newHeader := append(buffer.View(nil), networkHeader...)
+ newHeader = append(newHeader, transportHeader...)
+ payload := newHeader.ToVectorisedView()
+ payload.AppendView(pkt.Data.ToView())
+ payload.CapLength(payloadLen)
+
+ icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
+ })
+
+ icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
+
+ icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ switch reason.(type) {
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ case *icmpReasonProtoUnreachable:
+ icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
+ counter := sent.DstUnreachable
+
+ if err := r.WritePacket(
+ nil, /* gso */
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv4ProtocolNumber,
+ TTL: r.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ },
+ icmpPkt,
+ ); err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ counter.Increment()
+ return nil
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index fa4ae2012..1f6e14c3f 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -12,20 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ipv4 contains the implementation of the ipv4 network protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv4.NewProtocol() as one of the network
-// protocols when calling stack.New(). Then endpoints can be created by passing
-// ipv4.ProtocolNumber as the network protocol number when calling
-// Stack.NewEndpoint().
+// Package ipv4 contains the implementation of the ipv4 network protocol.
package ipv4
import (
+ "fmt"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -50,22 +48,115 @@ const (
fragmentblockSize = 8
)
+var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
+
+var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- nicID tcpip.NICID
+ nic stack.NetworkInterface
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
protocol *protocol
- stack *stack.Stack
+
+ // enabled is set to 1 when the enpoint is enabled and 0 when it is
+ // disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ mu struct {
+ sync.RWMutex
+
+ addressableEndpointState stack.AddressableEndpointState
+ }
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
- return &endpoint{
- nicID: nicID,
- linkEP: linkEP,
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
+ nic: nic,
+ linkEP: nic.LinkEndpoint(),
dispatcher: dispatcher,
protocol: p,
- stack: st,
+ }
+ e.mu.addressableEndpointState.Init(e)
+ return e
+}
+
+// Enable implements stack.NetworkEndpoint.
+func (e *endpoint) Enable() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // If the NIC is not enabled, the endpoint can't do anything meaningful so
+ // don't enable the endpoint.
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ // If the endpoint is already enabled, there is nothing for it to do.
+ if !e.setEnabled(true) {
+ return nil
+ }
+
+ // Create an endpoint to receive broadcast packets on this interface.
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ if err != nil {
+ return err
+ }
+ // We have no need for the address endpoint.
+ ep.DecRef()
+
+ // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
+ // multicast group. Note, the IANA calls the all-hosts multicast group the
+ // all-systems multicast group.
+ _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems)
+ return err
+}
+
+// Enabled implements stack.NetworkEndpoint.
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
+}
+
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the endpoint.
+//
+// Returns true if the enabled status was updated.
+func (e *endpoint) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&e.enabled, 1) == 0
+ }
+ return atomic.SwapUint32(&e.enabled, 0) == 1
+}
+
+// Disable implements stack.NetworkEndpoint.
+func (e *endpoint) Disable() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.disableLocked()
+}
+
+func (e *endpoint) disableLocked() {
+ if !e.setEnabled(false) {
+ return
+ }
+
+ // The endpoint may have already left the multicast group.
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
+ }
+
+ // The address may have already been removed.
+ if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
}
@@ -80,16 +171,6 @@ func (e *endpoint) MTU() uint32 {
return calculateMTU(e.linkEP.MTU())
}
-// Capabilities implements stack.NetworkEndpoint.Capabilities.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
-}
-
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
@@ -197,6 +278,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
// Send out the fragment.
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(n - i))
return err
}
r.Stats().IP.PacketsSent.Increment()
@@ -231,21 +313,24 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
- // If the packet is manipulated as per NAT Ouput rules, handle packet
- // based on destination address and do not send the packet to link layer.
- // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than
- // only NATted packets, but removing this check short circuits broadcasts
- // before they are sent out to other hosts.
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
@@ -265,6 +350,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
}
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
r.Stats().IP.PacketsSent.Increment()
@@ -285,20 +371,24 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
pkt = pkt.Next()
}
- nicName := e.stack.FindNICNameFromID(e.NICID())
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
return n, err
}
+ r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
- // Slow Path as we are dropping some packets in the batch degrade to
+ // Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
@@ -307,7 +397,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -318,12 +408,16 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
+ // Dropped packets aren't errors, so include them in
+ // the return value.
+ return n + len(dropped), err
}
n++
}
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, nil
+ // Dropped packets aren't errors, so include them in the return value.
+ return n + len(dropped), nil
}
// WriteHeaderIncludedPacket writes a packet already containing a network
@@ -373,25 +467,42 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return nil
}
+ if err := e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
-
- return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ return nil
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
+ return
+ }
+
h := header.IPv4(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
+ // As per RFC 1122 section 3.2.1.3:
+ // When a host sends any datagram, the IP source address MUST
+ // be one of its own IP addresses (but not a broadcast or
+ // multicast address).
+ if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) {
+ r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesInputDropped.Increment()
return
}
@@ -404,11 +515,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
// The packet is a fragment, let's try to reassemble it.
- last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1
- // Drop the packet if the fragmentOffset is incorrect. i.e the
- // combination of fragmentOffset and pkt.Data.size() causes a
- // wrap around resulting in last being less than the offset.
- if last < h.FragmentOffset() {
+ start := h.FragmentOffset()
+ // Drop the fragment if the size of the reassembled payload would exceed the
+ // maximum payload size.
+ //
+ // Note that this addition doesn't overflow even on 32bit architecture
+ // because pkt.Data.Size() should not exceed 65535 (the max IP datagram
+ // size). Otherwise the packet would've been rejected as invalid before
+ // reaching here.
+ if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
@@ -425,8 +540,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ID: uint32(h.ID()),
Protocol: proto,
},
- h.FragmentOffset(),
- last,
+ start,
+ start+uint16(pkt.Data.Size())-1,
h.More(),
proto,
pkt.Data,
@@ -440,27 +555,165 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
}
+
+ r.Stats().IP.PacketsDelivered.Increment()
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
+ // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport
+ // headers, the setting of the transport number here should be
+ // unnecessary and removed.
+ pkt.TransportProtocolNumber = p
e.handleICMP(r, pkt)
return
}
- r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, pkt)
+
+ switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ case stack.TransportPacketHandled:
+ case stack.TransportPacketDestinationPortUnreachable:
+ // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
+ // Unreachable messages with code:
+ // 3 (Port Unreachable), when the designated transport protocol
+ // (e.g., UDP) is unable to demultiplex the datagram but has no
+ // protocol mechanism to inform the sender.
+ _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC: 1122 Section 3.2.2.1
+ // A host SHOULD generate Destination Unreachable messages with code:
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported
+ _ = returnError(r, &icmpReasonProtoUnreachable{}, pkt)
+ default:
+ panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
+ }
}
// Close cleans up resources associated with the endpoint.
-func (e *endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.disableLocked()
+ e.mu.addressableEndpointState.Cleanup()
+}
+
+// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+}
+
+// RemovePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
+}
+
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
+// AcquireAssignedAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ loopback := e.nic.IsLoopback()
+ addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ // IPv4 has a notion of a subnet broadcast address and considers the
+ // loopback interface bound to an address's whole subnet (on linux).
+ return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
+ })
+ if addressEndpoint != nil {
+ return addressEndpoint
+ }
+
+ if !allowTemp {
+ return nil
+ }
+
+ addr := localAddr.WithPrefix()
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB)
+ if err != nil {
+ // AddAddress only returns an error if the address is already assigned,
+ // but we just checked above if the address exists so we expect no error.
+ panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err))
+ }
+ return addressEndpoint
+}
+
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
+}
+
+// PrimaryAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PrimaryAddresses()
+}
+
+// PermanentAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PermanentAddresses()
+}
+
+// JoinGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ if !header.IsV4MulticastAddress(addr) {
+ return false, tcpip.ErrBadAddress
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.JoinGroup(addr)
+}
+
+// LeaveGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.LeaveGroup(addr)
+}
+
+// IsInGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.IsInGroup(addr)
+}
+
+var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
+var _ stack.NetworkProtocol = (*protocol)(nil)
type protocol struct {
- ids []uint32
- hashIV uint32
+ stack *stack.Stack
// defaultTTL is the current default TTL for the protocol. Only the
- // uint8 portion of it is meaningful and it must be accessed
- // atomically.
+ // uint8 portion of it is meaningful.
+ //
+ // Must be accessed using atomic operations.
defaultTTL uint32
+ // forwarding is set to 1 when the protocol has forwarding enabled and 0
+ // when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
+ ids []uint32
+ hashIV uint32
+
fragmentation *fragmentation.Fragmentation
}
@@ -523,37 +776,28 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// Parse implements stack.TransportProtocol.Parse.
+// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
+ if ok := parse.IPv4(pkt); !ok {
return 0, false, false
}
- ipHdr := header.IPv4(hdr)
- // Header may have options, determine the true header length.
- headerLen := int(ipHdr.HeaderLength())
- if headerLen < header.IPv4MinimumSize {
- // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in
- // order for the packet to be valid. Figure out if we want to reject this
- // case.
- headerLen = header.IPv4MinimumSize
- }
- hdr, ok = pkt.NetworkHeader().Consume(headerLen)
- if !ok {
- return 0, false, false
- }
- ipHdr = header.IPv4(hdr)
+ ipHdr := header.IPv4(pkt.NetworkHeader().View())
+ return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
+}
- // If this is a fragment, don't bother parsing the transport header.
- parseTransportHeader := true
- if ipHdr.More() || ipHdr.FragmentOffset() != 0 {
- parseTransportHeader = false
- }
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) Forwarding() bool {
+ return uint8(atomic.LoadUint32(&p.forwarding)) == 1
+}
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
- return ipHdr.TransportProtocol(), parseTransportHeader, true
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) SetForwarding(v bool) {
+ if v {
+ atomic.StoreUint32(&p.forwarding, 1)
+ } else {
+ atomic.StoreUint32(&p.forwarding, 0)
+ }
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
@@ -577,7 +821,7 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui
}
// NewProtocol returns an IPv4 network protocol.
-func NewProtocol() stack.NetworkProtocol {
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -588,9 +832,10 @@ func NewProtocol() stack.NetworkProtocol {
hashIV := r[buckets]
return &protocol{
+ stack: s,
ids: ids,
hashIV: hashIV,
defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()),
}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 197e3bc51..33cd5a3eb 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,18 +17,21 @@ package ipv4_test
import (
"bytes"
"encoding/hex"
- "fmt"
- "math/rand"
+ "math"
+ "net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -36,8 +39,8 @@ import (
func TestExcludeBroadcast(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
const defaultMTU = 65536
@@ -92,29 +95,244 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// makeRandPkt generates a randomize packet. hdrLength indicates how much
-// data should already be in the header before WritePacket. extraLength
-// indicates how much extra space should be in the header. The payload is made
-// from many Views of the sizes listed in viewSizes.
-func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer {
- var views []buffer.View
- totalLength := 0
- for _, s := range viewSizes {
- newView := buffer.NewView(s)
- rand.Read(newView)
- views = append(views, newView)
- totalLength += s
+// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
+// checks the response.
+func TestIPv4Sanity(t *testing.T) {
+ const (
+ defaultMTU = header.IPv6MinimumMTU
+ ttl = 255
+ nicID = 1
+ randomSequence = 123
+ randomIdent = 42
+ )
+ var (
+ ipv4Addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ }
+ remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
+ )
+
+ tests := []struct {
+ name string
+ headerLength uint8
+ minTotalLength uint16
+ transportProtocol uint8
+ TTL uint8
+ shouldFail bool
+ expectICMP bool
+ ICMPType header.ICMPv4Type
+ ICMPCode header.ICMPv4Code
+ }{
+ {
+ name: "valid",
+ headerLength: header.IPv4MinimumSize,
+ minTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: false,
+ },
+ // The TTL tests check that we are not rejecting an incoming packet
+ // with a zero or one TTL, which has been a point of confusion in the
+ // past as RFC 791 says: "If this field contains the value zero, then the
+ // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies
+ // for the case of the destination host, stating as follows.
+ //
+ // A host MUST NOT send a datagram with a Time-to-Live (TTL)
+ // value of zero.
+ //
+ // A host MUST NOT discard a datagram just because it was
+ // received with TTL less than 2.
+ {
+ name: "zero TTL",
+ headerLength: header.IPv4MinimumSize,
+ minTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 0,
+ shouldFail: false,
+ },
+ {
+ name: "one TTL",
+ headerLength: header.IPv4MinimumSize,
+ minTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ shouldFail: false,
+ },
+ {
+ name: "bad header length",
+ headerLength: header.IPv4MinimumSize - 1,
+ minTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (0)",
+ headerLength: header.IPv4MinimumSize,
+ minTotalLength: 0,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip - 1)",
+ headerLength: header.IPv4MinimumSize,
+ minTotalLength: uint16(header.IPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip + icmp - 1)",
+ headerLength: header.IPv4MinimumSize,
+ minTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad protocol",
+ headerLength: header.IPv4MinimumSize,
+ minTotalLength: defaultMTU,
+ transportProtocol: 99,
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: true,
+ ICMPType: header.ICMPv4DstUnreachable,
+ ICMPCode: header.ICMPv4ProtoUnreachable,
+ },
}
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: hdrLength + extraLength,
- Data: buffer.NewVectorisedView(totalLength, views),
- })
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil {
- panic(fmt.Sprintf("rand.Read: %s", err))
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ }
+
+ // Default routes for IPv4 so ICMP can find a route to the remote
+ // node when attempting to send the ICMP Echo Reply.
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ ipHeaderLength := header.IPv4MinimumSize
+ totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(totalLen))
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+
+ // Specify ident/seq to make sure we get the same in the response.
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(^header.Checksum(icmp, 0))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ if test.minTotalLength < totalLen {
+ totalLen = test.minTotalLength
+ }
+ ip.Encode(&header.IPv4Fields{
+ IHL: test.headerLength,
+ TotalLength: totalLen,
+ Protocol: test.transportProtocol,
+ TTL: test.TTL,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: ipv4Addr.Address,
+ })
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+ reply, ok := e.Read()
+ if !ok {
+ if test.shouldFail {
+ if test.expectICMP {
+ t.Fatal("expected ICMP error response missing")
+ }
+ return // Expected silent failure.
+ }
+ t.Fatal("expected ICMP echo reply missing")
+ }
+
+ // Check the route that brought the packet to us.
+ if reply.Route.LocalAddress != ipv4Addr.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
+ }
+ if reply.Route.RemoteAddress != remoteIPv4Addr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
+ }
+
+ // Make sure it's all in one buffer.
+ vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views())
+ replyIPHeader := header.IPv4(vv.ToView())
+
+ // At this stage we only know it's an IP header so verify that much.
+ checker.IPv4(t, replyIPHeader,
+ checker.SrcAddr(ipv4Addr.Address),
+ checker.DstAddr(remoteIPv4Addr),
+ )
+
+ // All expected responses are ICMP packets.
+ if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want {
+ t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want)
+ }
+ replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
+
+ // Sanity check the response.
+ switch replyICMPHeader.Type() {
+ case header.ICMPv4DstUnreachable:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
+ checker.IPv4HeaderLength(ipHeaderLength),
+ checker.ICMPv4(
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+ if !test.shouldFail || !test.expectICMP {
+ t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d",
+ header.ICMPv4DstUnreachable, replyICMPHeader.Code())
+ }
+ return
+ case header.ICMPv4EchoReply:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPv4HeaderLength(ipHeaderLength),
+ checker.IPFullLength(uint16(requestPkt.Size())),
+ checker.ICMPv4(
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Seq(randomSequence),
+ checker.ICMPv4Ident(randomIdent),
+ checker.ICMPv4Checksum(),
+ ),
+ )
+ if test.shouldFail {
+ t.Fatalf("unexpected Echo Reply packet\n")
+ }
+ default:
+ t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d",
+ replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable)
+ }
+ })
}
- return pkt
}
// comparePayloads compared the contents of all the packets against the contents
@@ -186,101 +404,19 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
}
-type errorChannel struct {
- *channel.Endpoint
- Ch chan *stack.PacketBuffer
- packetCollectorErrors []*tcpip.Error
-}
-
-// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
-// will return successive errors from packetCollectorErrors until the list is
-// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
- return &errorChannel{
- Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan *stack.PacketBuffer, size),
- packetCollectorErrors: packetCollectorErrors,
- }
-}
-
-// Drain removes all outbound packets from the channel and counts them.
-func (e *errorChannel) Drain() int {
- c := 0
- for {
- select {
- case <-e.Ch:
- c++
- default:
- return c
- }
- }
-}
-
-// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- select {
- case e.Ch <- pkt:
- default:
- }
-
- nextError := (*tcpip.Error)(nil)
- if len(e.packetCollectorErrors) > 0 {
- nextError = e.packetCollectorErrors[0]
- e.packetCollectorErrors = e.packetCollectorErrors[1:]
- }
- return nextError
-}
-
-type context struct {
- stack.Route
- linkEP *errorChannel
-}
-
-func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
- // Make the packet and write it.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- })
- ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- s.CreateNIC(1, ep)
- const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
- )
- s.AddAddress(1, ipv4.ProtocolNumber, src)
- {
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }})
- }
- r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
- }
- return context{
- Route: r,
- linkEP: ep,
- }
-}
-
func TestFragmentation(t *testing.T) {
var manyPayloadViewsSizes [1000]int
for i := range manyPayloadViewsSizes {
manyPayloadViewsSizes[i] = 7
}
fragTests := []struct {
- description string
- mtu uint32
- gso *stack.GSO
- hdrLength int
- extraLength int
- payloadViewsSizes []int
- expectedFrags int
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transportHeaderLength int
+ extraHeaderReserveLength int
+ payloadViewsSizes []int
+ expectedFrags int
}{
{"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
{"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
@@ -295,36 +431,29 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
source := pkt.Clone()
- c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
}, pkt)
if err != nil {
- t.Errorf("err got %v, want %v", err, nil)
+ t.Errorf("got err = %s, want = nil", err)
}
- var results []*stack.PacketBuffer
- L:
- for {
- select {
- case pi := <-c.linkEP.Ch:
- results = append(results, pi)
- default:
- break L
- }
+ if got := len(ep.WrittenPackets); got != ft.expectedFrags {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags)
}
-
- if got, want := len(results), ft.expectedFrags; got != want {
- t.Errorf("len(result) got %d, want %d", got, want)
+ if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
- if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
}
- compareFragments(t, results, source, ft.mtu)
+ compareFragments(t, ep.WrittenPackets, source, ft.mtu)
})
}
}
@@ -335,152 +464,376 @@ func TestFragmentationErrors(t *testing.T) {
fragTests := []struct {
description string
mtu uint32
- hdrLength int
+ transportHeaderLength int
payloadViewsSizes []int
- packetCollectorErrors []*tcpip.Error
+ err *tcpip.Error
+ allowPackets int
+ fragmentCount int
}{
- {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
- {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ {
+ description: "NoFrag",
+ mtu: 2000,
+ transportHeaderLength: 0,
+ payloadViewsSizes: []int{1000},
+ err: tcpip.ErrAborted,
+ allowPackets: 0,
+ fragmentCount: 1,
+ },
+ {
+ description: "ErrorOnFirstFrag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadViewsSizes: []int{1000},
+ err: tcpip.ErrAborted,
+ allowPackets: 0,
+ fragmentCount: 3,
+ },
+ {
+ description: "ErrorOnSecondFrag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadViewsSizes: []int{1000},
+ err: tcpip.ErrAborted,
+ allowPackets: 1,
+ fragmentCount: 3,
+ },
+ {
+ description: "ErrorOnFirstFragMTUSmallerThanHeader",
+ mtu: 500,
+ transportHeaderLength: 1000,
+ payloadViewsSizes: []int{500},
+ err: tcpip.ErrAborted,
+ allowPackets: 0,
+ fragmentCount: 4,
+ },
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
- c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets)
+ r := buildRoute(t, ep)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
}, pkt)
- for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
- if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
- t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
- }
+ if err != ft.err {
+ t.Errorf("got WritePacket() = %s, want = %s", err, ft.err)
}
- // We only need to check that last error because all the ones before are
- // nil.
- if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
- t.Errorf("err got %v, want %v", got, want)
+ if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
- if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
- t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
+ if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want)
}
})
}
}
func TestInvalidFragments(t *testing.T) {
+ const (
+ nicID = 1
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ addr1 = "\x0a\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x02"
+ tos = 0
+ ident = 1
+ ttl = 48
+ protocol = 6
+ )
+
+ payloadGen := func(payloadLen int) []byte {
+ payload := make([]byte, payloadLen)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = 0x30
+ }
+ return payload
+ }
+
+ type fragmentData struct {
+ ipv4fields header.IPv4Fields
+ payload []byte
+ autoChecksum bool // if true, the Checksum field will be overwritten.
+ }
+
// These packets have both IHL and TotalLength set to 0.
- testCases := []struct {
+ tests := []struct {
name string
- packets [][]byte
+ fragments []fragmentData
wantMalformedIPPackets uint64
wantMalformedFragments uint64
}{
{
- "ihl_totallen_zero_valid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- },
- 1,
- 0,
- },
- {
- "ihl_totallen_zero_invalid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL and TotalLength zero, FragmentOffset non-zero",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: 0,
+ TOS: tos,
+ TotalLength: 0,
+ ID: ident,
+ Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments,
+ FragmentOffset: 59776,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(12),
+ autoChecksum: true,
+ },
},
- 1,
- 0,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 0,
},
{
- // Total Length of 37(20 bytes IP header + 17 bytes of
- // payload)
- // Frag Offset of 0x1ffe = 8190*8 = 65520
- // Leading to the fragment end to be past 65535.
- "ihl_totallen_valid_invalid_frag_offset_1",
- [][]byte{
- {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL and TotalLength zero, FragmentOffset zero",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: 0,
+ TOS: tos,
+ TotalLength: 0,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(12),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 0,
},
- // The following 3 tests were found by running a fuzzer and were
- // triggering a panic in the IPv4 reassembler code.
{
- "ihl_less_than_ipv4_minimum_size_1",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ // Payload 17 octets and Fragment offset 65520
+ // Leading to the fragment end to be past 65536.
+ name: "fragment ends past 65536",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 17,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 65520,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(17),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
{
- "ihl_less_than_ipv4_minimum_size_2",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ // Payload 16 octets and fragment offset 65520
+ // Leading to the fragment end to be exactly 65536.
+ name: "fragment ends exactly at 65536",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 65520,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(16),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 0,
+ wantMalformedFragments: 0,
},
{
- "ihl_less_than_ipv4_minimum_size_3",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL less than IPv4 minimum size",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 12,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 28,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 1944,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 12,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize - 12,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 2,
+ wantMalformedFragments: 0,
},
{
- "fragment_with_short_total_len_extra_payload",
- [][]byte{
- {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "fragment with short TotalLength and extra payload",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize + 4,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 28,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 28816,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize + 4,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 4,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
{
- "multiple_fragments_with_more_fragments_set_to_false",
- [][]byte{
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ name: "multiple fragments with More Fragments flag set to false",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 128,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 8,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
}
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- const nicID tcpip.NICID = 42
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{
- ipv4.NewProtocol(),
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
},
})
+ e := channel.New(0, 1500, linkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ }
+
+ for _, f := range test.fragments {
+ pktSize := header.IPv4MinimumSize + len(f.payload)
+ hdr := buffer.NewPrependable(pktSize)
- var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
- var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
- ep := channel.New(10, 1500, linkAddr)
- s.CreateNIC(nicID, sniffer.New(ep))
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&f.ipv4fields)
+ copy(ip[header.IPv4MinimumSize:], f.payload)
+
+ if f.autoChecksum {
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ }
- for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
+ vv := hdr.View().ToVectorisedView()
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
}))
}
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
}
- if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want {
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
}
})
@@ -534,6 +887,9 @@ func TestReceiveFragments(t *testing.T) {
// the fragment block size of 8 (RFC 791 section 3.1 page 14).
ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2)
udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:]
+ // Used to test the max reassembled payload length (65,535 octets).
+ ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2)
+ udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:]
type fragmentData struct {
srcAddr tcpip.Address
@@ -827,14 +1183,36 @@ func TestReceiveFragments(t *testing.T) {
},
expectedPayloads: nil,
},
+ {
+ name: "Two fragments reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload4Addr1ToAddr2[:65512],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 65512,
+ payload: ipv4Payload4Addr1ToAddr2[65512:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
+ },
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Setup a stack and endpoint.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00"))
if err := s.CreateNIC(nicID, e); err != nil {
@@ -906,3 +1284,189 @@ func TestReceiveFragments(t *testing.T) {
})
}
}
+
+func TestWriteStats(t *testing.T) {
+ const nPackets = 3
+ tests := []struct {
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectDropped int
+ expectWritten int
+ }{
+ {
+ name: "Accept all",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Accept all with error",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectDropped: 0,
+ expectWritten: nPackets - 1,
+ }, {
+ name: "Drop all",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule that matches only 1
+ // of the 3 packets.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectDropped: 1,
+ expectWritten: nPackets,
+ },
+ }
+
+ // Parameterize the tests to run with both WritePacket and WritePackets.
+ writers := []struct {
+ name string
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ }{
+ {
+ name: "WritePacket",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ nWritten := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ return nWritten, err
+ }
+ nWritten++
+ }
+ return nWritten, nil
+ },
+ }, {
+ name: "WritePackets",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ },
+ },
+ }
+
+ for _, writer := range writers {
+ t.Run(writer.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ rt := buildRoute(t, ep)
+
+ var pkts stack.PacketBufferList
+ for i := 0; i < nPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(header.UDPMinimumSize)
+ pkts.PushBack(pkt)
+ }
+
+ test.setup(t, rt.Stack())
+
+ nWritten, _ := writer.writePackets(&rt, pkts)
+
+ if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
+ t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
+ t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ }
+ if nWritten != test.expectWritten {
+ t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ }
+ })
+ }
+ })
+ }
+}
+
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC(1, _) failed: %s", err)
+ }
+ const (
+ src = "\x10\x00\x00\x01"
+ dst = "\x10\x00\x00\x02"
+ )
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ if err != nil {
+ t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err)
+ }
+ return rt
+}
+
+// limitedMatcher is an iptables matcher that matches after a certain number of
+// packets are checked against it.
+type limitedMatcher struct {
+ limit int
+}
+
+// Name implements Matcher.Name.
+func (*limitedMatcher) Name() string {
+ return "limitedMatcher"
+}
+
+// Match implements Matcher.Match.
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+ if lm.limit == 0 {
+ return true, false
+ }
+ lm.limit--
+ return false, false
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index bcc64994e..97adbcbd4 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -5,14 +5,18 @@ package(licenses = ["notice"])
go_library(
name = "ipv6",
srcs = [
+ "dhcpv6configurationfromndpra_string.go",
"icmp.go",
"ipv6.go",
+ "ndp.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/stack",
],
@@ -34,6 +38,7 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go
index d199ded6a..09ba133b1 100644
--- a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go
+++ b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go
@@ -14,7 +14,7 @@
// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT.
-package stack
+package ipv6
import "strconv"
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 2b83c421e..8e9def6b8 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,6 +15,8 @@
package ipv6
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -39,7 +41,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -207,14 +209,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- s := r.Stack()
- if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
- // We will only get an error if the NIC is unrecognized, which should not
- // happen. For now, drop this packet.
- //
- // TODO(b/141002840): Handle this better?
- return
- } else if isTentative {
+ if e.hasTentativeAddr(targetAddr) {
// If the target address is tentative and the source of the packet is a
// unicast (specified) address, then the source of the packet is
// attempting to perform address resolution on the target. In this case,
@@ -227,7 +222,20 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// stack know so it can handle such a scenario and do nothing further with
// the NS.
if r.RemoteAddress == header.IPv6Any {
- s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ // We would get an error if the address no longer exists or the address
+ // is no longer tentative (DAD resolved between the call to
+ // hasTentativeAddr and this point). Both of these are valid scenarios:
+ // 1) An address may be removed at any time.
+ // 2) As per RFC 4862 section 5.4, DAD is not a perfect:
+ // "Note that the method for detecting duplicates
+ // is not completely reliable, and it is possible that duplicate
+ // addresses will still exist"
+ //
+ // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
+ // address is detected for an assigned address.
+ if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
+ }
}
// Do not handle neighbor solicitations targeted to an address that is
@@ -240,7 +248,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// section 5.4.3.
// Is the NS targeting us?
- if s.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
return
}
@@ -275,7 +283,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
} else if e.nud != nil {
e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
}
// ICMPv6 Neighbor Solicit messages are always sent to
@@ -318,6 +326,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()),
})
packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(packet.NDPPayload())
na.SetSolicitedFlag(solicited)
@@ -352,20 +361,26 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// NDP datagrams are very small and ToView() will not incur allocations.
na := header.NDPNeighborAdvert(payload.ToView())
targetAddr := na.TargetAddress()
- s := r.Stack()
-
- if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
- // We will only get an error if the NIC is unrecognized, which should not
- // happen. For now short-circuit this packet.
- //
- // TODO(b/141002840): Handle this better?
- return
- } else if isTentative {
+ if e.hasTentativeAddr(targetAddr) {
// We just got an NA from a node that owns an address we are performing
// DAD on, implying the address is not unique. In this case we let the
// stack know so it can handle such a scenario and do nothing furthur with
// the NDP NA.
- s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ //
+ // We would get an error if the address no longer exists or the address
+ // is no longer tentative (DAD resolved between the call to
+ // hasTentativeAddr and this point). Both of these are valid scenarios:
+ // 1) An address may be removed at any time.
+ // 2) As per RFC 4862 section 5.4, DAD is not a perfect:
+ // "Note that the method for detecting duplicates
+ // is not completely reliable, and it is possible that duplicate
+ // addresses will still exist"
+ //
+ // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
+ // address is detected for an assigned address.
+ if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
+ }
return
}
@@ -395,7 +410,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// address cache with the link address for the target of the message.
if len(targetLinkAddr) != 0 {
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
return
}
@@ -423,7 +438,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
@@ -438,6 +453,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
Data: pkt.Data,
})
packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
copy(packet, icmpHdr)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
@@ -477,7 +493,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
stack := r.Stack()
// Is the networking stack operating as a router?
- if !stack.Forwarding() {
+ if !stack.Forwarding(ProtocolNumber) {
// ... No, silently drop the packet.
received.RouterOnlyPacketsDroppedByHost.Increment()
return
@@ -566,9 +582,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
}
- // Tell the NIC to handle the RA.
- stack := r.Stack()
- stack.HandleNDPRA(e.nicID, routerAddr, ra)
+ e.mu.Lock()
+ e.mu.ndp.handleRA(routerAddr, ra)
+ e.mu.Unlock()
case header.ICMPv6RedirectMsg:
// TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating
@@ -637,6 +653,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd
ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize,
})
icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr.SetType(header.ICMPv6NeighborSolicit)
copy(icmpHdr[icmpV6OptOffset-len(addr):], addr)
icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr
@@ -665,3 +682,159 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
return tcpip.LinkAddress([]byte(nil)), false
}
+
+// ======= ICMP Error packet generation =========
+
+// icmpReason is a marker interface for IPv6 specific ICMP errors.
+type icmpReason interface {
+ isICMPReason()
+}
+
+// icmpReasonParameterProblem is an error during processing of extension headers
+// or the fixed header defined in RFC 4443 section 3.4.
+type icmpReasonParameterProblem struct {
+ code header.ICMPv6Code
+
+ // respondToMulticast indicates that we are sending a packet that falls under
+ // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2:
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are
+ // two exceptions to this rule: (1) the Packet Too Big Message
+ // (Section 3.2) to allow Path MTU discovery to work for IPv6
+ // multicast, and (2) the Parameter Problem Message, Code 2
+ // (Section 3.4) reporting an unrecognized IPv6 option (see
+ // Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ respondToMulticast bool
+
+ // pointer is defined in the RFC 4443 setion 3.4 which reads:
+ //
+ // Pointer Identifies the octet offset within the invoking packet
+ // where the error was detected.
+ //
+ // The pointer will point beyond the end of the ICMPv6
+ // packet if the field in error is beyond what can fit
+ // in the maximum size of an ICMPv6 error message.
+ pointer uint32
+}
+
+func (*icmpReasonParameterProblem) isICMPReason() {}
+
+// icmpReasonPortUnreachable is an error where the transport protocol has no
+// listener and no alternative means to inform the sender.
+type icmpReasonPortUnreachable struct{}
+
+func (*icmpReasonPortUnreachable) isICMPReason() {}
+
+// returnError takes an error descriptor and generates the appropriate ICMP
+// error packet for IPv6 and sends it.
+func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ stats := r.Stats().ICMP
+ sent := stats.V6PacketsSent
+ if !r.Stack().AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
+ // Only send ICMP error if the address is not a multicast v6
+ // address and the source is not the unspecified address.
+ //
+ // There are exceptions to this rule.
+ // See: point e.3) RFC 4443 section-2.4
+ //
+ // (e) An ICMPv6 error message MUST NOT be originated as a result of
+ // receiving the following:
+ //
+ // (e.1) An ICMPv6 error message.
+ //
+ // (e.2) An ICMPv6 redirect message [IPv6-DISC].
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are
+ // two exceptions to this rule: (1) the Packet Too Big Message
+ // (Section 3.2) to allow Path MTU discovery to work for IPv6
+ // multicast, and (2) the Parameter Problem Message, Code 2
+ // (Section 3.4) reporting an unrecognized IPv6 option (see
+ // Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ //
+ var allowResponseToMulticast bool
+ if reason, ok := reason.(*icmpReasonParameterProblem); ok {
+ allowResponseToMulticast = reason.respondToMulticast
+ }
+
+ if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any {
+ return nil
+ }
+
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+
+ if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
+ // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
+ // Unfortunately at this time ICMP Packets do not have a transport
+ // header separated out. It is in the Data part so we need to
+ // separate it out now. We will just pretend it is a minimal length
+ // ICMP packet as we don't really care if any later bits of a
+ // larger ICMP packet are in the header view or in the Data view.
+ transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize)
+ if !ok {
+ return nil
+ }
+ typ := header.ICMPv6(transport).Type()
+ if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg {
+ return nil
+ }
+ }
+
+ // As per RFC 4443 section 2.4
+ //
+ // (c) Every ICMPv6 error message (type < 128) MUST include
+ // as much of the IPv6 offending (invoking) packet (the
+ // packet that caused the error) as possible without making
+ // the error message packet exceed the minimum IPv6 MTU
+ // [IPv6].
+ mtu := int(r.MTU())
+ if mtu > header.IPv6MinimumMTU {
+ mtu = header.IPv6MinimumMTU
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
+ available := int(mtu) - headerLen
+ if available < header.IPv6MinimumSize {
+ return nil
+ }
+ payloadLen := network.Size() + transport.Size() + pkt.Data.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+ payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ payload.CapLength(payloadLen)
+
+ newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
+ })
+ newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+
+ icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
+ var counter *tcpip.StatCounter
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ icmpHdr.SetType(header.ICMPv6ParamProblem)
+ icmpHdr.SetCode(reason.code)
+ icmpHdr.SetTypeSpecific(reason.pointer)
+ counter = sent.ParamProblem
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6PortUnreachable)
+ counter = sent.DstUnreachable
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data))
+ err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt)
+ if err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ counter.Increment()
+ return nil
+}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 8112ed051..31370c1d4 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -75,7 +75,8 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition {
+ return stack.TransportPacketHandled
}
type stubLinkAddressCache struct {
@@ -102,6 +103,30 @@ func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Lin
func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct{}
+
+func (*testInterface) ID() tcpip.NICID {
+ return 0
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
+func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
+ return nil
+}
+
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -120,8 +145,8 @@ func TestICMPCounts(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: test.useNeighborCache,
})
{
@@ -149,9 +174,13 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
@@ -258,8 +287,8 @@ func TestICMPCounts(t *testing.T) {
func TestICMPCountsWithNeighborCache(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: true,
})
{
@@ -287,9 +316,13 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{})
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
@@ -423,12 +456,12 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab
func newTestContext(t *testing.T) *testContext {
c := &testContext{
s0: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
}),
s1: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
}),
}
@@ -723,12 +756,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
UseNeighborCache: test.useNeighborCache,
})
if isRouter {
// Enabling forwarding makes the stack act as a router.
- s.SetForwarding(true)
+ s.SetForwarding(ProtocolNumber, true)
}
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
@@ -919,7 +952,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Run(typ.name, func(t *testing.T) {
e := channel.New(10, 1280, linkAddr0)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
@@ -1097,7 +1130,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Run(typ.name, func(t *testing.T) {
e := channel.New(10, 1280, linkAddr0)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -1203,7 +1236,10 @@ func TestLinkAddressRequest(t *testing.T) {
}
for _, test := range tests {
- p := NewProtocol()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(ProtocolNumber)
linkRes, ok := p.(stack.LinkAddressResolver)
if !ok {
t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index af3cd91c6..c8a3e0b34 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,21 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ipv6 contains the implementation of the ipv6 network protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv6.NewProtocol() as one of the network
-// protocols when calling stack.New(). Then endpoints can be created by passing
-// ipv6.ProtocolNumber as the network protocol number when calling
-// Stack.NewEndpoint().
+// Package ipv6 contains the implementation of the ipv6 network protocol.
package ipv6
import (
"fmt"
+ "sort"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -44,14 +42,302 @@ const (
DefaultTTL = 64
)
+var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+var _ stack.NDPEndpoint = (*endpoint)(nil)
+var _ NDPEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- nicID tcpip.NICID
+ nic stack.NetworkInterface
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
dispatcher stack.TransportDispatcher
protocol *protocol
stack *stack.Stack
+
+ // enabled is set to 1 when the endpoint is enabled and 0 when it is
+ // disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ mu struct {
+ sync.RWMutex
+
+ addressableEndpointState stack.AddressableEndpointState
+ ndp ndpState
+ }
+}
+
+// NICNameFromID is a function that returns a stable name for the specified NIC,
+// even if different NIC IDs are used to refer to the same NIC in different
+// program runs. It is used when generating opaque interface identifiers (IIDs).
+// If the NIC was created with a name, it is passed to NICNameFromID.
+//
+// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
+// generated for the same prefix on differnt NICs.
+type NICNameFromID func(tcpip.NICID, string) string
+
+// OpaqueInterfaceIdentifierOptions holds the options related to the generation
+// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
+type OpaqueInterfaceIdentifierOptions struct {
+ // NICNameFromID is a function that returns a stable name for a specified NIC,
+ // even if the NIC ID changes over time.
+ //
+ // Must be specified to generate the opaque IID.
+ NICNameFromID NICNameFromID
+
+ // SecretKey is a pseudo-random number used as the secret key when generating
+ // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
+ // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
+ // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
+ // change between program runs, unless explicitly changed.
+ //
+ // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
+ // MUST NOT be modified after Stack is created.
+ //
+ // May be nil, but a nil value is highly discouraged to maintain
+ // some level of randomness between nodes.
+ SecretKey []byte
+}
+
+// InvalidateDefaultRouter implements stack.NDPEndpoint.
+func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.mu.ndp.invalidateDefaultRouter(rtr)
+}
+
+// SetNDPConfigurations implements NDPEndpoint.
+func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) {
+ c.validate()
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.mu.ndp.configs = c
+}
+
+// hasTentativeAddr returns true if addr is tentative on e.
+func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
+ e.mu.RLock()
+ addressEndpoint := e.getAddressRLocked(addr)
+ e.mu.RUnlock()
+ return addressEndpoint != nil && addressEndpoint.GetKind() == stack.PermanentTentative
+}
+
+// dupTentativeAddrDetected attempts to inform e that a tentative addr is a
+// duplicate on a link.
+//
+// dupTentativeAddrDetected removes the tentative address if it exists. If the
+// address was generated via SLAAC, an attempt is made to generate a new
+// address.
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil {
+ return tcpip.ErrBadAddress
+ }
+
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
+ // attempt will be made to generate a new address for it.
+ if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ return err
+ }
+
+ prefix := addressEndpoint.AddressWithPrefix().Subnet()
+
+ switch t := addressEndpoint.ConfigType(); t {
+ case stack.AddressConfigStatic:
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.regenerateSLAACAddr(prefix)
+ case stack.AddressConfigSlaacTemp:
+ // Do not reset the generation attempts counter for the prefix as the
+ // temporary address is being regenerated in response to a DAD conflict.
+ e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+ default:
+ panic(fmt.Sprintf("unrecognized address config type = %d", t))
+ }
+
+ return nil
+}
+
+// transitionForwarding transitions the endpoint's forwarding status to
+// forwarding.
+//
+// Must only be called when the forwarding status changes.
+func (e *endpoint) transitionForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.Enabled() {
+ return
+ }
+
+ if forwarding {
+ // When transitioning into an IPv6 router, host-only state (NDP discovered
+ // routers, discovered on-link prefixes, and auto-generated addresses) is
+ // cleaned up/invalidated and NDP router solicitations are stopped.
+ e.mu.ndp.stopSolicitingRouters()
+ e.mu.ndp.cleanupState(true /* hostOnly */)
+ } else {
+ // When transitioning into an IPv6 host, NDP router solicitations are
+ // started.
+ e.mu.ndp.startSolicitingRouters()
+ }
+}
+
+// Enable implements stack.NetworkEndpoint.
+func (e *endpoint) Enable() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // If the NIC is not enabled, the endpoint can't do anything meaningful so
+ // don't enable the endpoint.
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ // If the endpoint is already enabled, there is nothing for it to do.
+ if !e.setEnabled(true) {
+ return nil
+ }
+
+ // Join the IPv6 All-Nodes Multicast group if the stack is configured to
+ // use IPv6. This is required to ensure that this node properly receives
+ // and responds to the various NDP messages that are destined to the
+ // all-nodes multicast address. An example is the Neighbor Advertisement
+ // when we perform Duplicate Address Detection, or Router Advertisement
+ // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
+ // section 4.2 for more information.
+ //
+ // Also auto-generate an IPv6 link-local address based on the endpoint's
+ // link address if it is configured to do so. Note, each interface is
+ // required to have IPv6 link-local unicast address, as per RFC 4291
+ // section 2.1.
+
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
+ if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
+ // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
+ // state.
+ //
+ // Addresses may have aleady completed DAD but in the time since the endpoint
+ // was last enabled, other devices may have acquired the same addresses.
+ var err *tcpip.Error
+ e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if !header.IsV6UnicastAddress(addr) {
+ return true
+ }
+
+ switch addressEndpoint.GetKind() {
+ case stack.Permanent:
+ addressEndpoint.SetKind(stack.PermanentTentative)
+ fallthrough
+ case stack.PermanentTentative:
+ err = e.mu.ndp.startDuplicateAddressDetection(addr, addressEndpoint)
+ return err == nil
+ default:
+ return true
+ }
+ })
+ if err != nil {
+ return err
+ }
+
+ // Do not auto-generate an IPv6 link-local address for loopback devices.
+ if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() {
+ // The valid and preferred lifetime is infinite for the auto-generated
+ // link-local address.
+ e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
+ }
+
+ // If we are operating as a router, then do not solicit routers since we
+ // won't process the RAs anyway.
+ //
+ // Routers do not process Router Advertisements (RA) the same way a host
+ // does. That is, routers do not learn from RAs (e.g. on-link prefixes
+ // and default routers). Therefore, soliciting RAs from other routers on
+ // a link is unnecessary for routers.
+ if !e.protocol.Forwarding() {
+ e.mu.ndp.startSolicitingRouters()
+ }
+
+ return nil
+}
+
+// Enabled implements stack.NetworkEndpoint.
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
+}
+
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the endpoint.
+//
+// Returns true if the enabled status was updated.
+func (e *endpoint) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&e.enabled, 1) == 0
+ }
+ return atomic.SwapUint32(&e.enabled, 0) == 1
+}
+
+// Disable implements stack.NetworkEndpoint.
+func (e *endpoint) Disable() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.disableLocked()
+}
+
+func (e *endpoint) disableLocked() {
+ if !e.setEnabled(false) {
+ return
+ }
+
+ e.mu.ndp.stopSolicitingRouters()
+ e.mu.ndp.cleanupState(false /* hostOnly */)
+ e.stopDADForPermanentAddressesLocked()
+
+ // The endpoint may have already left the multicast group.
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
+ }
+}
+
+// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) stopDADForPermanentAddressesLocked() {
+ // Stop DAD for all the tentative unicast addresses.
+ e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return true
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if header.IsV6UnicastAddress(addr) {
+ e.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+
+ return true
+ })
}
// DefaultTTL is the default hop limit for this endpoint.
@@ -65,16 +351,6 @@ func (e *endpoint) MTU() uint32 {
return calculateMTU(e.linkEP.MTU())
}
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
-}
-
-// Capabilities implements stack.NetworkEndpoint.Capabilities.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
@@ -107,6 +383,32 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
+ if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ // iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesOutputDropped.Increment()
+ return nil
+ }
+
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
+ if pkt.NatDone {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ ep.HandlePacket(&route, pkt)
+ return nil
+ }
+ }
+
if r.Loop&stack.PacketLoop != 0 {
loopedR := r.MakeLoopedRoute()
@@ -121,8 +423,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt)
+ return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -138,9 +444,54 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
e.addIPHeader(r, pb, params)
}
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
+ dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ if len(dropped) == 0 && len(natPkts) == 0 {
+ // Fast path: If no packets are to be dropped then we can just invoke the
+ // faster WritePackets API directly.
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
+ return n, err
+ }
+ r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+
+ // Slow path as we are dropping some packets in the batch degrade to
+ // emitting one packet at a time.
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if _, ok := dropped[pkt]; ok {
+ continue
+ }
+ if _, ok := natPkts[pkt]; ok {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+ ep.HandlePacket(&route, pkt)
+ n++
+ continue
+ }
+ }
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
+ // Dropped packets aren't errors, so include them in
+ // the return value.
+ return n + len(dropped), err
+ }
+ n++
+ }
+
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ // Dropped packets aren't errors, so include them in the return value.
+ return n + len(dropped), nil
}
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
@@ -153,12 +504,24 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
+ return
+ }
+
h := header.IPv6(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
+ // As per RFC 4291 section 2.7:
+ // Multicast addresses must not be used as source addresses in IPv6
+ // packets or appear in any Routing header.
+ if header.IsV6MulticastAddress(r.RemoteAddress) {
+ r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
@@ -169,7 +532,19 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
hasFragmentHeader := false
- for firstHeader := true; ; firstHeader = false {
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and need not be forwarded.
+ ipt := e.protocol.stack.IPTables()
+ if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ // iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesInputDropped.Increment()
+ return
+ }
+
+ for {
+ // Keep track of the start of the previous header so we can report the
+ // special case of a Hop by Hop at a location other than at the start.
+ previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -183,11 +558,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6HopByHopOptionsExtHdr:
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
- // (unrecognized next header) error in response to an extension header's
- // Next Header field with the Hop By Hop extension header identifier.
- if !firstHeader {
+ if previousHeaderStart != 0 {
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: previousHeaderStart,
+ }, pkt)
return
}
@@ -209,13 +584,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
@@ -226,16 +613,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// As per RFC 8200 section 4.4, if a node encounters a routing header with
// an unrecognized routing type value, with a non-zero Segments Left
// value, the node must discard the packet and send an ICMP Parameter
- // Problem, Code 0. If the Segments Left is 0, the node must ignore the
- // Routing extension header and process the next header in the packet.
+ // Problem, Code 0 to the packet's Source Address, pointing to the
+ // unrecognized Routing Type.
+ //
+ // If the Segments Left is 0, the node must ignore the Routing extension
+ // header and process the next header in the packet.
//
// Note, the stack does not yet handle any type of routing extension
// header, so we just make sure Segments Left is zero before processing
// the next extension header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
- // unrecognized routing types with a non-zero Segments Left value.
if extHdr.SegmentsLeft() != 0 {
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
return
}
@@ -268,7 +659,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
it, done, err := it.Next()
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
if done {
@@ -311,12 +702,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// The packet is a fragment, let's try to reassemble it.
start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
- last := start + uint16(fragmentPayloadLen) - 1
- // Drop the packet if the fragmentOffset is incorrect. i.e the
- // combination of fragmentOffset and pkt.Data.size() causes a
- // wrap around resulting in last being less than the offset.
- if last < start {
+ // Drop the fragment if the size of the reassembled payload would exceed
+ // the maximum payload size.
+ if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
@@ -333,7 +722,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ID: extHdr.ID(),
},
start,
- last,
+ start+uint16(fragmentPayloadLen)-1,
extHdr.More(),
uint8(rawPayload.Identifier),
rawPayload.Buf,
@@ -372,13 +761,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
@@ -397,21 +798,55 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
pkt.Data = extHdr.Buf
+ r.Stats().IP.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
+ pkt.TransportProtocolNumber = p
e.handleICMP(r, pkt, hasFragmentHeader)
} else {
r.Stats().IP.PacketsDelivered.Increment()
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
- e.dispatcher.DeliverTransportPacket(r, p, pkt)
+ switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ case stack.TransportPacketHandled:
+ case stack.TransportPacketDestinationPortUnreachable:
+ // As per RFC 4443 section 3.1:
+ // A destination node SHOULD originate a Destination Unreachable
+ // message with Code 4 in response to a packet for which the
+ // transport protocol (e.g., UDP) has no listener, if that transport
+ // protocol has no alternative means to inform the sender.
+ _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC 8200 section 4. (page 7):
+ // Extension headers are numbered from IANA IP Protocol Numbers
+ // [IANA-PN], the same values used for IPv4 and IPv6. When
+ // processing a sequence of Next Header values in a packet, the
+ // first one that is not an extension header [IANA-EH] indicates
+ // that the next item in the packet is the corresponding upper-layer
+ // header.
+ // With more related information on page 8:
+ // If, as a result of processing a header, the destination node is
+ // required to proceed to the next header but the Next Header value
+ // in the current header is unrecognized by the node, it should
+ // discard the packet and send an ICMP Parameter Problem message to
+ // the source of the packet, with an ICMP Code value of 1
+ // ("unrecognized Next Header type encountered") and the ICMP
+ // Pointer field containing the offset of the unrecognized value
+ // within the original packet.
+ //
+ // Which when taken together indicate that an unknown protocol should
+ // be treated as an unrecognized next header value.
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
+ default:
+ panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
+ }
}
default:
- // If we receive a packet for an extension header we do not yet handle,
- // drop the packet for now.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
r.Stats().UnknownProtocolRcvdPackets.Increment()
return
}
@@ -419,19 +854,340 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
// Close cleans up resources associated with the endpoint.
-func (*endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.disableLocked()
+ e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */)
+ e.stopDADForPermanentAddressesLocked()
+ e.mu.addressableEndpointState.Cleanup()
+ e.mu.Unlock()
+
+ e.protocol.forgetEndpoint(e)
+}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
+// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ // TODO(b/169350103): add checks here after making sure we no longer receive
+ // an empty address.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
+}
+
+// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but
+// with locking requirements.
+//
+// addAndAcquirePermanentAddressLocked also joins the passed address's
+// solicited-node multicast group and start duplicate address detection.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ if err != nil {
+ return nil, err
+ }
+
+ if !header.IsV6UnicastAddress(addr.Address) {
+ return addressEndpoint, nil
+ }
+
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil {
+ return nil, err
+ }
+
+ addressEndpoint.SetKind(stack.PermanentTentative)
+
+ if e.Enabled() {
+ if err := e.mu.ndp.startDuplicateAddressDetection(addr.Address, addressEndpoint); err != nil {
+ return nil, err
+ }
+ }
+
+ return addressEndpoint, nil
+}
+
+// RemovePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ return e.removePermanentEndpointLocked(addressEndpoint, true)
+}
+
+// removePermanentEndpointLocked is like removePermanentAddressLocked except
+// it works with a stack.AddressEndpoint.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
+ addr := addressEndpoint.AddressWithPrefix()
+ unicast := header.IsV6UnicastAddress(addr.Address)
+ if unicast {
+ e.mu.ndp.stopDuplicateAddressDetection(addr.Address)
+
+ // If we are removing an address generated via SLAAC, cleanup
+ // its SLAAC resources and notify the integrator.
+ switch addressEndpoint.ConfigType() {
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ case stack.AddressConfigSlaacTemp:
+ e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ }
+ }
+
+ if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil {
+ return err
+ }
+
+ if !unicast {
+ return nil
+ }
+
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+
+ return nil
+}
+
+// hasPermanentAddressLocked returns true if the endpoint has a permanent
+// address equal to the passed address.
+//
+// Precondition: e.mu must be read or write locked.
+func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool {
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil {
+ return false
+ }
+ return addressEndpoint.GetKind().IsPermanent()
+}
+
+// getAddressRLocked returns the endpoint for the passed address.
+//
+// Precondition: e.mu must be read or write locked.
+func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint {
+ return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
+}
+
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
+// AcquireAssignedAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB)
+}
+
+// acquireAddressOrCreateTempLocked is like AcquireAssignedAddress but with
+// locking requirements.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) acquireAddressOrCreateTempLocked(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ return e.mu.addressableEndpointState.AcquireAssignedAddress(localAddr, allowTemp, tempPEB)
+}
+
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
+}
+
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ // addrCandidate is a candidate for Source Address Selection, as per
+ // RFC 6724 section 5.
+ type addrCandidate struct {
+ addressEndpoint stack.AddressEndpoint
+ scope header.IPv6AddressScope
+ }
+
+ if len(remoteAddr) == 0 {
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
+ }
+
+ // Create a candidate set of available addresses we can potentially use as a
+ // source address.
+ var cs []addrCandidate
+ e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ // If r is not valid for outgoing connections, it is not a valid endpoint.
+ if !addressEndpoint.IsAssigned(allowExpired) {
+ return
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ scope, err := header.ScopeForIPv6Address(addr)
+ if err != nil {
+ // Should never happen as we got r from the primary IPv6 endpoint list and
+ // ScopeForIPv6Address only returns an error if addr is not an IPv6
+ // address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
+ }
+
+ cs = append(cs, addrCandidate{
+ addressEndpoint: addressEndpoint,
+ scope: scope,
+ })
+ })
+
+ remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
+ if err != nil {
+ // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
+ }
+
+ // Sort the addresses as per RFC 6724 section 5 rules 1-3.
+ //
+ // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ sort.Slice(cs, func(i, j int) bool {
+ sa := cs[i]
+ sb := cs[j]
+
+ // Prefer same address as per RFC 6724 section 5 rule 1.
+ if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ return true
+ }
+ if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ return false
+ }
+
+ // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
+ if sa.scope < sb.scope {
+ return sa.scope >= remoteScope
+ } else if sb.scope < sa.scope {
+ return sb.scope < remoteScope
+ }
+
+ // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
+ if saDep, sbDep := sa.addressEndpoint.Deprecated(), sb.addressEndpoint.Deprecated(); saDep != sbDep {
+ // If sa is not deprecated, it is preferred over sb.
+ return sbDep
+ }
+
+ // Prefer temporary addresses as per RFC 6724 section 5 rule 7.
+ if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp {
+ return saTemp
+ }
+
+ // sa and sb are equal, return the endpoint that is closest to the front of
+ // the primary endpoint list.
+ return i < j
+ })
+
+ // Return the most preferred address that can have its reference count
+ // incremented.
+ for _, c := range cs {
+ if c.addressEndpoint.IncRef() {
+ return c.addressEndpoint
+ }
+ }
+
+ return nil
+}
+
+// PrimaryAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PrimaryAddresses()
+}
+
+// PermanentAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PermanentAddresses()
+}
+
+// JoinGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ if !header.IsV6MulticastAddress(addr) {
+ return false, tcpip.ErrBadAddress
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.JoinGroup(addr)
+}
+
+// LeaveGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.LeaveGroup(addr)
+}
+
+// IsInGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.IsInGroup(addr)
+}
+
+var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
+var _ stack.NetworkProtocol = (*protocol)(nil)
+
type protocol struct {
+ stack *stack.Stack
+
+ mu struct {
+ sync.RWMutex
+
+ eps map[*endpoint]struct{}
+ }
+
// defaultTTL is the current default TTL for the protocol. Only the
- // uint8 portion of it is meaningful and it must be accessed
- // atomically.
- defaultTTL uint32
+ // uint8 portion of it is meaningful.
+ //
+ // Must be accessed using atomic operations.
+ defaultTTL uint32
+
+ // forwarding is set to 1 when the protocol has forwarding enabled and 0
+ // when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
fragmentation *fragmentation.Fragmentation
+
+ // ndpDisp is the NDP event dispatcher that is used to send the netstack
+ // integrator NDP related events.
+ ndpDisp NDPDispatcher
+
+ // ndpConfigs is the default NDP configurations used by an IPv6 endpoint.
+ ndpConfigs NDPConfigurations
+
+ // opaqueIIDOpts hold the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ opaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // tempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ tempIIDSeed []byte
+
+ // autoGenIPv6LinkLocal determines whether or not the stack attempts to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
+ autoGenIPv6LinkLocal bool
}
// Number returns the ipv6 protocol number.
@@ -456,16 +1212,36 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
- return &endpoint{
- nicID: nicID,
- linkEP: linkEP,
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
+ nic: nic,
+ linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
dispatcher: dispatcher,
protocol: p,
- stack: st,
}
+ e.mu.addressableEndpointState.Init(e)
+ e.mu.ndp = ndpState{
+ ep: e,
+ configs: p.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
+ }
+ e.mu.ndp.initializeTempAddrState()
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.mu.eps[e] = struct{}{}
+ return e
+}
+
+func (p *protocol) forgetEndpoint(e *endpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ delete(p.mu.eps, e)
}
// SetOption implements NetworkProtocol.SetOption.
@@ -506,75 +1282,43 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// Parse implements stack.TransportProtocol.Parse.
+// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt)
if !ok {
return 0, false, false
}
- ipHdr := header.IPv6(hdr)
- // dataClone consists of:
- // - Any IPv6 header bytes after the first 40 (i.e. extensions).
- // - The transport header, if present.
- // - Any other payload data.
- views := [8]buffer.View{}
- dataClone := pkt.Data.Clone(views[:])
- dataClone.TrimFront(header.IPv6MinimumSize)
- it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
+ return proto, !fragMore && fragOffset == 0, true
+}
- // Iterate over the IPv6 extensions to find their length.
- //
- // Parsing occurs again in HandlePacket because we don't track the
- // extensions in PacketBuffer. Unfortunately, that means HandlePacket
- // has to do the parsing work again.
- var nextHdr tcpip.TransportProtocolNumber
- foundNext := true
- extensionsSize := 0
-traverseExtensions:
- for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() {
- if err != nil {
- break
- }
- // If we exhaust the extension list, the entire packet is the IPv6 header
- // and (possibly) extensions.
- if done {
- extensionsSize = dataClone.Size()
- foundNext = false
- break
- }
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) Forwarding() bool {
+ return uint8(atomic.LoadUint32(&p.forwarding)) == 1
+}
- switch extHdr := extHdr.(type) {
- case header.IPv6FragmentExtHdr:
- // If this is an atomic fragment, we don't have to treat it specially.
- if !extHdr.More() && extHdr.FragmentOffset() == 0 {
- continue
- }
- // This is a non-atomic fragment and has to be re-assembled before we can
- // examine the payload for a transport header.
- foundNext = false
+// setForwarding sets the forwarding status for the protocol.
+//
+// Returns true if the forwarding status was updated.
+func (p *protocol) setForwarding(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&p.forwarding, 1) == 0
+ }
+ return atomic.SwapUint32(&p.forwarding, 0) == 1
+}
- case header.IPv6RawPayloadHeader:
- // We've found the payload after any extensions.
- extensionsSize = dataClone.Size() - extHdr.Buf.Size()
- nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
- break traverseExtensions
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) SetForwarding(v bool) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
- default:
- // Any other extension is a no-op, keep looping until we find the payload.
- }
+ if !p.setForwarding(v) {
+ return
}
- // Put the IPv6 header with extensions in pkt.NetworkHeader().
- hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize)
- if !ok {
- panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
+ for ep := range p.mu.eps {
+ ep.transitionForwarding(v)
}
- ipHdr = header.IPv6(hdr)
- pkt.Data.CapLength(int(ipHdr.PayloadLength()))
- pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
-
- return nextHdr, foundNext, true
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
@@ -587,10 +1331,69 @@ func calculateMTU(mtu uint32) uint32 {
return maxPayloadSize
}
-// NewProtocol returns an IPv6 network protocol.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{
- defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+// Options holds options to configure a new protocol.
+type Options struct {
+ // NDPConfigs is the default NDP configurations used by interfaces.
+ NDPConfigs NDPConfigurations
+
+ // AutoGenIPv6LinkLocal determines whether or not the stack attempts to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs.
+ //
+ // Note, setting this to true does not mean that a link-local address is
+ // assigned right away, or at all. If Duplicate Address Detection is enabled,
+ // an address is only assigned if it successfully resolves. If it fails, no
+ // further attempts are made to auto-generate an IPv6 link-local adddress.
+ //
+ // The generated link-local address follows RFC 4291 Appendix A guidelines.
+ AutoGenIPv6LinkLocal bool
+
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
+ // OpaqueIIDOpts hold the options for generating opaque interface
+ // identifiers (IIDs) as outlined by RFC 7217.
+ OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // TempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ //
+ // Temporary SLAAC adresses are short-lived addresses which are unpredictable
+ // and random from the perspective of other nodes on the network. It is
+ // recommended that the seed be a random byte buffer of at least
+ // header.IIDSize bytes to make sure that temporary SLAAC addresses are
+ // sufficiently random. It should follow minimum randomness requirements for
+ // security as outlined by RFC 4086.
+ //
+ // Note: using a nil value, the same seed across netstack program runs, or a
+ // seed that is too small would reduce randomness and increase predictability,
+ // defeating the purpose of temporary SLAAC addresses.
+ TempIIDSeed []byte
+}
+
+// NewProtocolWithOptions returns an IPv6 network protocol.
+func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
+ opts.NDPConfigs.validate()
+
+ return func(s *stack.Stack) stack.NetworkProtocol {
+ p := &protocol{
+ stack: s,
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()),
+
+ ndpDisp: opts.NDPDisp,
+ ndpConfigs: opts.NDPConfigs,
+ opaqueIIDOpts: opts.OpaqueIIDOpts,
+ tempIIDSeed: opts.TempIIDSeed,
+ autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ }
+ p.mu.eps = make(map[*endpoint]struct{})
+ p.SetDefaultTTL(DefaultTTL)
+ return p
}
}
+
+// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options.
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return NewProtocolWithOptions(Options{})(s)
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 54787198f..d7f82973b 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,13 +15,16 @@
package ipv6
import (
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -139,18 +142,18 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
tests := []struct {
name string
- protocolFactory stack.TransportProtocol
+ protocolFactory stack.TransportProtocolFactory
rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
}{
- {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
- {"UDP", udp.NewProtocol(), testReceiveUDP},
+ {"ICMP", icmp.NewProtocol6, testReceiveICMP},
+ {"UDP", udp.NewProtocol, testReceiveUDP},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
e := channel.New(10, 1280, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
@@ -172,11 +175,11 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
tests := []struct {
name string
- protocolFactory stack.TransportProtocol
+ protocolFactory stack.TransportProtocolFactory
rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
}{
- {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
- {"UDP", udp.NewProtocol(), testReceiveUDP},
+ {"ICMP", icmp.NewProtocol6, testReceiveICMP},
+ {"UDP", udp.NewProtocol, testReceiveUDP},
}
snmc := header.SolicitedNodeAddr(addr2)
@@ -184,8 +187,8 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
e := channel.New(1, 1280, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
@@ -271,7 +274,7 @@ func TestAddIpv6Address(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -299,11 +302,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
name string
extHdr func(nextHdr uint8) ([]byte, uint8)
shouldAccept bool
+ // Should we expect an ICMP response and if so, with what contents?
+ expectICMP bool
+ ICMPType header.ICMPv6Type
+ ICMPCode header.ICMPv6Code
+ pointer uint32
+ multicast bool
}{
{
name: "None",
extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
shouldAccept: true,
+ expectICMP: false,
},
{
name: "hopbyhop with unknown option skippable action",
@@ -334,9 +344,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "hopbyhop with unknown option discard and send icmp action",
+ name: "hopbyhop with unknown option discard and send icmp action (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -346,12 +357,58 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ name: "hopbyhop with unknown option discard and send icmp action (multicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ multicast: true,
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -362,39 +419,77 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
+ multicast: true,
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "routing with zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: true,
},
{
- name: "routing with non-zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 1, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6ErroneousHeader,
+ pointer: header.IPv6FixedHeaderSize + 2,
},
{
- name: "atomic fragment with zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 0, 0, 0, 0,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
},
{
- name: "atomic fragment with non-zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
+ expectICMP: false,
},
{
- name: "fragment",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{},
+ noNextHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
name: "destination with unknown option skippable action",
@@ -410,6 +505,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: true,
+ expectICMP: false,
},
{
name: "destination with unknown option discard action",
@@ -425,9 +521,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action (unicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "destination with unknown option discard and send icmp action",
+ name: "destination with unknown option discard and send icmp action (muilticast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -437,12 +554,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
}, destinationExtHdrID
},
+ multicast: true,
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "destination with unknown option discard and send icmp action unless multicast dest",
+ name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -453,22 +576,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "routing - atomic fragment",
+ name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ nextHdr, 1,
- // Fragment extension header.
- nextHdr, 0, 0, 0, 1, 2, 3, 4,
- }, routingExtHdrID
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
+ }, destinationExtHdrID
},
- shouldAccept: true,
+ shouldAccept: false,
+ expectICMP: false,
+ multicast: true,
},
{
name: "atomic fragment - routing",
@@ -502,12 +636,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
return []byte{
// Routing extension header.
hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
// Hop By Hop extension header with skippable unknown option.
nextHdr, 0, 62, 4, 1, 2, 3, 4,
}, routingExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
+ },
+ {
+ name: "routing - hop by hop (with send icmp unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
+
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
},
{
name: "No next header",
@@ -551,6 +715,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
@@ -571,16 +736,17 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(0, 1280, linkAddr1)
+ e := channel.New(1, 1280, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -588,6 +754,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
+ // Add a default route so that a return packet knows where to go.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
@@ -629,12 +803,16 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ dstAddr := tcpip.Address(addr2)
+ if test.multicast {
+ dstAddr = header.IPv6AllNodesMulticastAddress
+ }
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: ipv6NextHdr,
HopLimit: 255,
SrcAddr: addr1,
- DstAddr: addr2,
+ DstAddr: dstAddr,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -648,6 +826,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = 0", got)
}
+ if !test.expectICMP {
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpected packet received: %#v", p)
+ }
+ return
+ }
+
+ // ICMP required.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected packet wasn't written out")
+ }
+
+ // Pack the output packet into a single buffer.View as the checkers
+ // assume that.
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
+ if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want {
+ t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want)
+ }
+
+ ipHdr := header.IPv6(pkt)
+ checker.IPv6(t, ipHdr, checker.ICMPv6(
+ checker.ICMPv6Type(test.ICMPType),
+ checker.ICMPv6Code(test.ICMPCode)))
+
+ // We know we are looking at no extension headers in the error ICMP
+ // packets.
+ icmpPkt := header.ICMPv6(ipHdr.Payload())
+ // We know we sent small packets that won't be truncated when reflected
+ // back to us.
+ originalPacket := icmpPkt.Payload()
+ if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want {
+ t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want)
+ }
+ if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" {
+ t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff)
+ }
return
}
@@ -687,6 +903,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Used to test cases where the fragment blocks are not a multiple of
// the fragment block size of 8 (RFC 8200 section 4.5).
udpPayload3Length = 127
+ udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
fragmentExtHdrLen = 8
// Note, not all routing extension headers will be 8 bytes but this test
// uses 8 byte routing extension headers for most sub tests.
@@ -731,6 +948,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:]
ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2)
+ var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte
+ udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:]
+ ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2)
+
tests := []struct {
name string
expectedPayload []byte
@@ -1020,6 +1241,44 @@ func TestReceiveIPv6Fragments(t *testing.T) {
expectedPayloads: nil,
},
{
+ name: "Two fragments reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+65520,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[:65520],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8190, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[65520:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
+ },
+ {
name: "Two fragments with per-fragment routing header with zero segments left",
fragments: []fragmentData{
{
@@ -1504,8 +1763,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
e := channel.New(0, 1280, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
@@ -1572,3 +1831,308 @@ func TestReceiveIPv6Fragments(t *testing.T) {
})
}
}
+
+func TestInvalidIPv6Fragments(t *testing.T) {
+ const (
+ nicID = 1
+ fragmentExtHdrLen = 8
+ )
+
+ payloadGen := func(payloadLen int) []byte {
+ payload := make([]byte, payloadLen)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = 0x30
+ }
+ return payload
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ wantMalformedIPPackets uint64
+ wantMalformedFragments uint64
+ }{
+ {
+ name: "fragments reassembled into a payload exceeding the max IPv6 payload size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16,
+ []buffer.View{
+ // Fragment extension header.
+ // Fragment offset = 8190, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8,
+ ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8,
+ 0, 0, 0, 1}),
+ // Payload length = 16
+ payloadGen(16),
+ },
+ ),
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ NewProtocol,
+ },
+ })
+ e := channel.New(0, 1500, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+
+ // Serialize IPv6 fixed header.
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(f.data.Size()),
+ NextHeader: f.nextHdr,
+ HopLimit: 255,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
+ })
+
+ vv := hdr.View().ToVectorisedView()
+ vv.Append(f.data)
+
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
+ t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want)
+ }
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
+ t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
+func TestWriteStats(t *testing.T) {
+ const nPackets = 3
+ tests := []struct {
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectDropped int
+ expectWritten int
+ }{
+ {
+ name: "Accept all",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Accept all with error",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectDropped: 0,
+ expectWritten: nPackets - 1,
+ }, {
+ name: "Drop all",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule that matches only 1
+ // of the 3 packets.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectDropped: 1,
+ expectWritten: nPackets,
+ },
+ }
+
+ writers := []struct {
+ name string
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ }{
+ {
+ name: "WritePacket",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ nWritten := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ return nWritten, err
+ }
+ nWritten++
+ }
+ return nWritten, nil
+ },
+ }, {
+ name: "WritePackets",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ },
+ },
+ }
+
+ for _, writer := range writers {
+ t.Run(writer.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ rt := buildRoute(t, ep)
+
+ var pkts stack.PacketBufferList
+ for i := 0; i < nPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(header.UDPMinimumSize)
+ pkts.PushBack(pkt)
+ }
+
+ test.setup(t, rt.Stack())
+
+ nWritten, _ := writer.writePackets(&rt, pkts)
+
+ if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
+ t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
+ t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ }
+ if nWritten != test.expectWritten {
+ t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ }
+ })
+ }
+ })
+ }
+}
+
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC(1, _) failed: %s", err)
+ }
+ const (
+ src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ )
+ if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
+ if err != nil {
+ t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err)
+ }
+ return rt
+}
+
+// limitedMatcher is an iptables matcher that matches after a certain number of
+// packets are checked against it.
+type limitedMatcher struct {
+ limit int
+}
+
+// Name implements Matcher.Name.
+func (*limitedMatcher) Name() string {
+ return "limitedMatcher"
+}
+
+// Match implements Matcher.Match.
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+ if lm.limit == 0 {
+ return true, false
+ }
+ lm.limit--
+ return false, false
+}
+
+func TestClearEndpointFromProtocolOnClose(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint)
+ {
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[ep]
+ proto.mu.Unlock()
+ if !hasEP {
+ t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep)
+ }
+ }
+
+ ep.Close()
+
+ {
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[ep]
+ proto.mu.Unlock()
+ if hasEP {
+ t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index b0873d1af..48a4c65e3 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package stack
+package ipv6
import (
"fmt"
@@ -23,9 +23,27 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending reachability probes.
+ //
+ // Default taken from RETRANS_TIMER of RFC 4861 section 10.
+ defaultRetransmitTimer = time.Second
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending reachability probes.
+ //
+ // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here
+ // to make sure the messages are not sent all at once. We also come to this
+ // value because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1. Note, the
+ // unit of the RetransmitTimer field in the Router Advertisement is
+ // milliseconds.
+ minimumRetransmitTimer = time.Millisecond
+
// defaultDupAddrDetectTransmits is the default number of NDP Neighbor
// Solicitation messages to send when doing Duplicate Address Detection
// for a tentative address.
@@ -34,7 +52,7 @@ const (
defaultDupAddrDetectTransmits = 1
// defaultMaxRtrSolicitations is the default number of Router
- // Solicitation messages to send when a NIC becomes enabled.
+ // Solicitation messages to send when an IPv6 endpoint becomes enabled.
//
// Default = 3 (from RFC 4861 section 10).
defaultMaxRtrSolicitations = 3
@@ -131,7 +149,7 @@ const (
minRegenAdvanceDuration = time.Duration(0)
// maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt
- // SLAAC address regenerations in response to a NIC-local conflict.
+ // SLAAC address regenerations in response to an IPv6 endpoint-local conflict.
maxSLAACAddrLocalRegenAttempts = 10
)
@@ -163,7 +181,7 @@ var (
// This is exported as a variable (instead of a constant) so tests
// can update it to a smaller value.
//
- // This value guarantees that a temporary address will be preferred for at
+ // This value guarantees that a temporary address is preferred for at
// least 1hr if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour
@@ -173,11 +191,17 @@ var (
// This is exported as a variable (instead of a constant) so tests
// can update it to a smaller value.
//
- // This value guarantees that a temporary address will be valid for at least
+ // This value guarantees that a temporary address is valid for at least
// 2hrs if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrValidLifetime = 2 * time.Hour
)
+// NDPEndpoint is an endpoint that supports NDP.
+type NDPEndpoint interface {
+ // SetNDPConfigurations sets the NDP configurations.
+ SetNDPConfigurations(NDPConfigurations)
+}
+
// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an
// NDP Router Advertisement informed the Stack about.
type DHCPv6ConfigurationFromNDPRA int
@@ -192,7 +216,7 @@ const (
// DHCPv6ManagedAddress indicates that addresses are available via DHCPv6.
//
// DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6
- // will return all available configuration information.
+ // returns all available configuration information when serving addresses.
DHCPv6ManagedAddress
// DHCPv6OtherConfigurations indicates that other configuration information is
@@ -207,19 +231,18 @@ const (
// NDPDispatcher is the interface integrators of netstack must implement to
// receive and handle NDP related events.
type NDPDispatcher interface {
- // OnDuplicateAddressDetectionStatus will be called when the DAD process
- // for an address (addr) on a NIC (with ID nicID) completes. resolved
- // will be set to true if DAD completed successfully (no duplicate addr
- // detected); false otherwise (addr was detected to be a duplicate on
- // the link the NIC is a part of, or it was stopped for some other
- // reason, such as the address being removed). If an error occured
- // during DAD, err will be set and resolved must be ignored.
+ // OnDuplicateAddressDetectionStatus is called when the DAD process for an
+ // address (addr) on a NIC (with ID nicID) completes. resolved is set to true
+ // if DAD completed successfully (no duplicate addr detected); false otherwise
+ // (addr was detected to be a duplicate on the link the NIC is a part of, or
+ // it was stopped for some other reason, such as the address being removed).
+ // If an error occured during DAD, err is set and resolved must be ignored.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
- // OnDefaultRouterDiscovered will be called when a new default router is
+ // OnDefaultRouterDiscovered is called when a new default router is
// discovered. Implementations must return true if the newly discovered
// router should be remembered.
//
@@ -227,56 +250,55 @@ type NDPDispatcher interface {
// is also not permitted to call into the stack.
OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool
- // OnDefaultRouterInvalidated will be called when a discovered default
- // router that was remembered is invalidated.
+ // OnDefaultRouterInvalidated is called when a discovered default router that
+ // was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address)
- // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is
- // discovered. Implementations must return true if the newly discovered
- // on-link prefix should be remembered.
+ // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered.
+ // Implementations must return true if the newly discovered on-link prefix
+ // should be remembered.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool
- // OnOnLinkPrefixInvalidated will be called when a discovered on-link
- // prefix that was remembered is invalidated.
+ // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that
+ // was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet)
- // OnAutoGenAddress will be called when a new prefix with its
- // autonomous address-configuration flag set has been received and SLAAC
- // has been performed. Implementations may prevent the stack from
- // assigning the address to the NIC by returning false.
+ // OnAutoGenAddress is called when a new prefix with its autonomous address-
+ // configuration flag set is received and SLAAC was performed. Implementations
+ // may prevent the stack from assigning the address to the NIC by returning
+ // false.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
- // OnAutoGenAddressDeprecated will be called when an auto-generated
- // address (as part of SLAAC) has been deprecated, but is still
- // considered valid. Note, if an address is invalidated at the same
- // time it is deprecated, the deprecation event MAY be omitted.
+ // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC)
+ // is deprecated, but is still considered valid. Note, if an address is
+ // invalidated at the same ime it is deprecated, the deprecation event may not
+ // be received.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix)
- // OnAutoGenAddressInvalidated will be called when an auto-generated
- // address (as part of SLAAC) has been invalidated.
+ // OnAutoGenAddressInvalidated is called when an auto-generated address
+ // (SLAAC) is invalidated.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix)
- // OnRecursiveDNSServerOption will be called when an NDP option with
- // recursive DNS servers has been received. Note, addrs may contain
- // link-local addresses.
+ // OnRecursiveDNSServerOption is called when the stack learns of DNS servers
+ // through NDP. Note, the addresses may contain link-local addresses.
//
// It is up to the caller to use the DNS Servers only for their valid
// lifetime. OnRecursiveDNSServerOption may be called for new or
@@ -288,8 +310,8 @@ type NDPDispatcher interface {
// call functions on the stack itself.
OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
- // OnDNSSearchListOption will be called when an NDP option with a DNS
- // search list has been received.
+ // OnDNSSearchListOption is called when the stack learns of DNS search lists
+ // through NDP.
//
// It is up to the caller to use the domain names in the search list
// for only their valid lifetime. OnDNSSearchListOption may be called
@@ -298,8 +320,8 @@ type NDPDispatcher interface {
// be increased, decreased or completely invalidated when lifetime = 0.
OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration)
- // OnDHCPv6Configuration will be called with an updated configuration that is
- // available via DHCPv6 for a specified NIC.
+ // OnDHCPv6Configuration is called with an updated configuration that is
+ // available via DHCPv6 for the passed NIC.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
@@ -320,7 +342,7 @@ type NDPConfigurations struct {
// Must be greater than or equal to 1ms.
RetransmitTimer time.Duration
- // The number of Router Solicitation messages to send when the NIC
+ // The number of Router Solicitation messages to send when the IPv6 endpoint
// becomes enabled.
MaxRtrSolicitations uint8
@@ -335,24 +357,22 @@ type NDPConfigurations struct {
// Must be greater than or equal to 0s.
MaxRtrSolicitationDelay time.Duration
- // HandleRAs determines whether or not Router Advertisements will be
- // processed.
+ // HandleRAs determines whether or not Router Advertisements are processed.
HandleRAs bool
- // DiscoverDefaultRouters determines whether or not default routers will
- // be discovered from Router Advertisements. This configuration is
- // ignored if HandleRAs is false.
+ // DiscoverDefaultRouters determines whether or not default routers are
+ // discovered from Router Advertisements, as per RFC 4861 section 6. This
+ // configuration is ignored if HandleRAs is false.
DiscoverDefaultRouters bool
- // DiscoverOnLinkPrefixes determines whether or not on-link prefixes
- // will be discovered from Router Advertisements' Prefix Information
- // option. This configuration is ignored if HandleRAs is false.
+ // DiscoverOnLinkPrefixes determines whether or not on-link prefixes are
+ // discovered from Router Advertisements' Prefix Information option, as per
+ // RFC 4861 section 6. This configuration is ignored if HandleRAs is false.
DiscoverOnLinkPrefixes bool
- // AutoGenGlobalAddresses determines whether or not global IPv6
- // addresses will be generated for a NIC in response to receiving a new
- // Prefix Information option with its Autonomous Address
- // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC).
+ // AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs
+ // SLAAC to auto-generate global SLAAC addresses in response to Prefix
+ // Information options, as per RFC 4862.
//
// Note, if an address was already generated for some unique prefix, as
// part of SLAAC, this option does not affect whether or not the
@@ -366,12 +386,12 @@ type NDPConfigurations struct {
//
// If the method used to generate the address does not support creating
// alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's
- // MAC address), then no attempt will be made to resolve the conflict.
+ // MAC address), then no attempt is made to resolve the conflict.
AutoGenAddressConflictRetries uint8
// AutoGenTempGlobalAddresses determines whether or not temporary SLAAC
- // addresses will be generated for a NIC as part of SLAAC privacy extensions,
- // RFC 4941.
+ // addresses are generated for an IPv6 endpoint as part of SLAAC privacy
+ // extensions, as per RFC 4941.
//
// Ignored if AutoGenGlobalAddresses is false.
AutoGenTempGlobalAddresses bool
@@ -410,7 +430,7 @@ func DefaultNDPConfigurations() NDPConfigurations {
}
// validate modifies an NDPConfigurations with valid values. If invalid values
-// are present in c, the corresponding default values will be used instead.
+// are present in c, the corresponding default values are used instead.
func (c *NDPConfigurations) validate() {
if c.RetransmitTimer < minimumRetransmitTimer {
c.RetransmitTimer = defaultRetransmitTimer
@@ -439,8 +459,8 @@ func (c *NDPConfigurations) validate() {
// ndpState is the per-interface NDP state.
type ndpState struct {
- // The NIC this ndpState is for.
- nic *NIC
+ // The IPv6 endpoint this ndpState is for.
+ ep *endpoint
// configs is the per-interface NDP configurations.
configs NDPConfigurations
@@ -458,8 +478,8 @@ type ndpState struct {
// Used to let the Router Solicitation timer know that it has been stopped.
//
// Must only be read from or written to while protected by the lock of
- // the NIC this ndpState is associated with. MUST be set when the timer is
- // set.
+ // the IPv6 endpoint this ndpState is associated with. MUST be set when the
+ // timer is set.
done *bool
}
@@ -492,7 +512,7 @@ type dadState struct {
// Used to let the DAD timer know that it has been stopped.
//
// Must only be read from or written to while protected by the lock of
- // the NIC this dadState is associated with.
+ // the IPv6 endpoint this dadState is associated with.
done *bool
}
@@ -537,7 +557,7 @@ type tempSLAACAddrState struct {
// The address's endpoint.
//
// Must not be nil.
- ref *referencedNetworkEndpoint
+ addressEndpoint stack.AddressEndpoint
// Has a new temporary SLAAC address already been regenerated?
regenerated bool
@@ -567,10 +587,10 @@ type slaacPrefixState struct {
//
// May only be nil when the address is being (re-)generated. Otherwise,
// must not be nil as all SLAAC prefixes must have a stable address.
- ref *referencedNetworkEndpoint
+ addressEndpoint stack.AddressEndpoint
- // The number of times an address has been generated locally where the NIC
- // already had the generated address.
+ // The number of times an address has been generated locally where the IPv6
+ // endpoint already had the generated address.
localGenerationFailures uint8
}
@@ -578,11 +598,12 @@ type slaacPrefixState struct {
tempAddrs map[tcpip.Address]tempSLAACAddrState
// The next two fields are used by both stable and temporary addresses
- // generated for a SLAAC prefix. This is safe as only 1 address will be
- // in the generation and DAD process at any time. That is, no two addresses
- // will be generated at the same time for a given SLAAC prefix.
+ // generated for a SLAAC prefix. This is safe as only 1 address is in the
+ // generation and DAD process at any time. That is, no two addresses are
+ // generated at the same time for a given SLAAC prefix.
- // The number of times an address has been generated and added to the NIC.
+ // The number of times an address has been generated and added to the IPv6
+ // endpoint.
//
// Addresses may be regenerated in reseponse to a DAD conflicts.
generationAttempts uint8
@@ -597,16 +618,16 @@ type slaacPrefixState struct {
// This function must only be called by IPv6 addresses that are currently
// tentative.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
// addr must be a valid unicast IPv6 address.
if !header.IsV6UnicastAddress(addr) {
return tcpip.ErrAddressFamilyNotSupported
}
- if ref.getKind() != permanentTentative {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should be marked as tentative since we are starting DAD.
- panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
// Should not attempt to perform DAD on an address that is currently in the
@@ -617,18 +638,18 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// existed, we would get an error since we attempted to add a duplicate
// address, or its reference count would have been increased without doing
// the work that would have been done for an address that was brand new.
- // See NIC.addAddressLocked.
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
+ // See endpoint.addAddressLocked.
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID()))
}
remaining := ndp.configs.DupAddrDetectTransmits
if remaining == 0 {
- ref.setKind(permanent)
+ addressEndpoint.SetKind(stack.Permanent)
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
return nil
@@ -637,25 +658,25 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
var done bool
var timer tcpip.Timer
// We initially start a timer to fire immediately because some of the DAD work
- // cannot be done while holding the NIC's lock. This is effectively the same
- // as starting a goroutine but we use a timer that fires immediately so we can
- // reset it for the next DAD iteration.
- timer = ndp.nic.stack.Clock().AfterFunc(0, func() {
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
+ // cannot be done while holding the IPv6 endpoint's lock. This is effectively
+ // the same as starting a goroutine but we use a timer that fires immediately
+ // so we can reset it for the next DAD iteration.
+ timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
+ ndp.ep.mu.Lock()
+ defer ndp.ep.mu.Unlock()
if done {
// If we reach this point, it means that the DAD timer fired after
- // another goroutine already obtained the NIC lock and stopped DAD
- // before this function obtained the NIC lock. Simply return here and do
- // nothing further.
+ // another goroutine already obtained the IPv6 endpoint lock and stopped
+ // DAD before this function obtained the NIC lock. Simply return here and
+ // do nothing further.
return
}
- if ref.getKind() != permanentTentative {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should still be marked as tentative since we are still
// performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
dadDone := remaining == 0
@@ -663,33 +684,34 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
var err *tcpip.Error
if !dadDone {
// Use the unspecified address as the source address when performing DAD.
- ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+ addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
// Do not hold the lock when sending packets which may be a long running
// task or may block link address resolution. We know this is safe
// because immediately after obtaining the lock again, we check if DAD
- // has been stopped before doing any work with the NIC. Note, DAD would be
- // stopped if the NIC was disabled or removed, or if the address was
- // removed.
- ndp.nic.mu.Unlock()
- err = ndp.sendDADPacket(addr, ref)
- ndp.nic.mu.Lock()
+ // has been stopped before doing any work with the IPv6 endpoint. Note,
+ // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if
+ // the address was removed.
+ ndp.ep.mu.Unlock()
+ err = ndp.sendDADPacket(addr, addressEndpoint)
+ ndp.ep.mu.Lock()
+ addressEndpoint.DecRef()
}
if done {
// If we reach this point, it means that DAD was stopped after we released
- // the NIC's read lock and before we obtained the write lock.
+ // the IPv6 endpoint's read lock and before we obtained the write lock.
return
}
if dadDone {
// DAD has resolved.
- ref.setKind(permanent)
+ addressEndpoint.SetKind(stack.Permanent)
} else if err == nil {
// DAD is not done and we had no errors when sending the last NDP NS,
// schedule the next DAD timer.
remaining--
- timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
+ timer.Reset(ndp.configs.RetransmitTimer)
return
}
@@ -698,16 +720,16 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// integrator know DAD has completed.
delete(ndp.dad, addr)
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
}
// If DAD resolved for a stable SLAAC address, attempt generation of a
// temporary SLAAC address.
- if dadDone && ref.configType == slaac {
+ if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
// Reset the generation attempts counter as we are starting the generation
// of a new address for the SLAAC prefix.
- ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */)
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
}
})
@@ -722,28 +744,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns
// addr.
//
-// addr must be a tentative IPv6 address on ndp's NIC.
+// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
//
-// The NIC ndp belongs to MUST NOT be locked.
-func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+// The IPv6 endpoint that ndp belongs to MUST NOT be locked.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r := makeRoute(header.IPv6ProtocolNumber, ref.address(), snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
defer r.Release()
// Route should resolve immediately since snmc is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
// Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the NIC is not locked,
- // the NIC could have been disabled or removed by another goroutine.
+ // Since this method is required to be called when the IPv6 endpoint is not
+ // locked, the NIC could have been disabled or removed by another goroutine.
if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
return err
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
}
icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
@@ -752,17 +777,16 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd
ns.SetTargetAddress(addr)
icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- pkt := NewPacketBuffer(PacketBufferOptions{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: buffer.View(icmpData).ToVectorisedView(),
})
sent := r.Stats().ICMP.V6PacketsSent
if err := r.WritePacket(nil,
- NetworkHeaderParams{
+ stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- TOS: DefaultTOS,
}, pkt,
); err != nil {
sent.Dropped.Increment()
@@ -778,11 +802,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd
// such a state forever, unless some other external event resolves the DAD
// process (receiving an NA from the true owner of addr, or an NS for addr
// (implying another node is attempting to use addr)). It is up to the caller
-// of this function to handle such a scenario. Normally, addr will be removed
-// from n right after this function returns or the address successfully
-// resolved.
+// of this function to handle such a scenario.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
dad, ok := ndp.dad[addr]
if !ok {
@@ -801,30 +823,30 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
}
}
// handleRA handles a Router Advertisement message that arrived on the NIC
// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
- // Is the NIC configured to handle RAs at all?
+ // Is the IPv6 endpoint configured to handle RAs at all?
//
// Currently, the stack does not determine router interface status on a
- // per-interface basis; it is a stack-wide configuration, so we check
- // stack's forwarding flag to determine if the NIC is a routing
- // interface.
- if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding {
+ // per-interface basis; it is a protocol-wide configuration, so we check the
+ // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
+ // packets.
+ if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() {
return
}
// Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
// only inform the dispatcher on configuration changes. We do nothing else
// with the information.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
var configuration DHCPv6ConfigurationFromNDPRA
switch {
case ra.ManagedAddrConfFlag():
@@ -839,11 +861,11 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
if ndp.dhcpv6Configuration != configuration {
ndp.dhcpv6Configuration = configuration
- ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration)
+ ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration)
}
}
- // Is the NIC configured to discover default routers?
+ // Is the IPv6 endpoint configured to discover default routers?
if ndp.configs.DiscoverDefaultRouters {
rtr, ok := ndp.defaultRouters[ip]
rl := ra.RouterLifetime()
@@ -881,20 +903,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
switch opt := opt.(type) {
case header.NDPRecursiveDNSServer:
- if ndp.nic.stack.ndpDisp == nil {
+ if ndp.ep.protocol.ndpDisp == nil {
continue
}
addrs, _ := opt.Addresses()
- ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
case header.NDPDNSSearchList:
- if ndp.nic.stack.ndpDisp == nil {
+ if ndp.ep.protocol.ndpDisp == nil {
continue
}
domainNames, _ := opt.DomainNames()
- ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -928,7 +950,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// invalidateDefaultRouter invalidates a discovered default router.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
rtr, ok := ndp.defaultRouters[ip]
@@ -942,32 +964,32 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
}
}
// rememberDefaultRouter remembers a newly discovered default router with IPv6
// link-local address ip with lifetime rl.
//
-// The router identified by ip MUST NOT already be known by the NIC.
+// The router identified by ip MUST NOT already be known by the IPv6 endpoint.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
- ndpDisp := ndp.nic.stack.ndpDisp
+ ndpDisp := ndp.ep.protocol.ndpDisp
if ndpDisp == nil {
return
}
// Inform the integrator when we discovered a default router.
- if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) {
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
// Informed by the integrator to not remember the router, do
// nothing further.
return
}
state := defaultRouterState{
- invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
@@ -982,22 +1004,22 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
//
// The prefix identified by prefix MUST NOT already be known.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
- ndpDisp := ndp.nic.stack.ndpDisp
+ ndpDisp := ndp.ep.protocol.ndpDisp
if ndpDisp == nil {
return
}
// Inform the integrator when we discovered an on-link prefix.
- if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) {
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
// Informed by the integrator to not remember the prefix, do
// nothing further.
return
}
state := onLinkPrefixState{
- invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
@@ -1011,7 +1033,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
// invalidateOnLinkPrefix invalidates a discovered on-link prefix.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
s, ok := ndp.onLinkPrefixes[prefix]
@@ -1025,8 +1047,8 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
}
}
@@ -1036,7 +1058,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
// handleOnLinkPrefixInformation assumes that the prefix this pi is for is
// not the link-local prefix and the on-link flag is set.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) {
prefix := pi.Subnet()
prefixState, ok := ndp.onLinkPrefixes[prefix]
@@ -1089,7 +1111,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
// not the link-local prefix and the autonomous flag is set.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
vl := pi.ValidLifetime()
pl := pi.PreferredLifetime()
@@ -1125,7 +1147,7 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// If we do not already have an address for this prefix and the valid
// lifetime is 0, no need to do anything further, as per RFC 4862
@@ -1142,15 +1164,15 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
}
- ndp.deprecateSLAACAddress(state.stableAddr.ref)
+ ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint)
}),
- invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1189,7 +1211,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
// If the address is assigned (DAD resolved), generate a temporary address.
- if state.stableAddr.ref.getKind() == permanent {
+ if state.stableAddr.addressEndpoint.GetKind() == stack.Permanent {
// Reset the generation attempts counter as we are starting the generation
// of a new address for the SLAAC prefix.
ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */)
@@ -1198,32 +1220,27 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.slaacPrefixes[prefix] = state
}
-// addSLAACAddr adds a SLAAC address to the NIC.
+// addAndAcquireSLAACAddr adds a SLAAC address to the IPv6 endpoint.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint {
// Inform the integrator that we have a new SLAAC address.
- ndpDisp := ndp.nic.stack.ndpDisp
+ ndpDisp := ndp.ep.protocol.ndpDisp
if ndpDisp == nil {
return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) {
+ if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
// Informed by the integrator not to add the address.
return nil
}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr,
- }
-
- ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated)
+ addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
if err != nil {
- panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err))
+ panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
- return ref
+ return addressEndpoint
}
// generateSLAACAddr generates a SLAAC address for prefix.
@@ -1232,10 +1249,10 @@ func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType netwo
//
// Panics if the prefix is not a SLAAC prefix or it already has an address.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool {
- if r := state.stableAddr.ref; r != nil {
- panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix()))
+ if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil {
+ panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, addressEndpoint.AddressWithPrefix()))
}
// If we have already reached the maximum address generation attempts for the
@@ -1255,11 +1272,11 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
}
dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures
- if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil {
addrBytes = header.AppendOpaqueInterfaceIdentifier(
addrBytes[:header.IIDOffsetInIPv6Address],
prefix,
- oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name),
+ oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()),
dadCounter,
oIID.SecretKey,
)
@@ -1272,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
//
// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
// LinkEndpoint.LinkAddress) before reaching this point.
- linkAddr := ndp.nic.linkEP.LinkAddress()
+ linkAddr := ndp.ep.linkEP.LinkAddress()
if !header.IsValidUnicastEthernetAddress(linkAddr) {
return false
}
@@ -1291,15 +1308,15 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
PrefixLen: validPrefixLenForAutoGen,
}
- if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) {
+ if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) {
break
}
state.stableAddr.localGenerationFailures++
}
- if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil {
- state.stableAddr.ref = ref
+ if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil {
+ state.stableAddr.addressEndpoint = addressEndpoint
state.generationAttempts++
return true
}
@@ -1309,10 +1326,9 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
// regenerateSLAACAddr regenerates an address for a SLAAC prefix.
//
-// If generating a new address for the prefix fails, the prefix will be
-// invalidated.
+// If generating a new address for the prefix fails, the prefix is invalidated.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
@@ -1332,7 +1348,7 @@ func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
// generateTempSLAACAddr generates a new temporary SLAAC address.
//
-// If resetGenAttempts is true, the prefix's generation counter will be reset.
+// If resetGenAttempts is true, the prefix's generation counter is reset.
//
// Returns true if a new address was generated.
func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool {
@@ -1353,7 +1369,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
return false
}
- stableAddr := prefixState.stableAddr.ref.address()
+ stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address
now := time.Now()
// As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
@@ -1392,7 +1408,8 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
return false
}
- // Attempt to generate a new address that is not already assigned to the NIC.
+ // Attempt to generate a new address that is not already assigned to the IPv6
+ // endpoint.
var generatedAddr tcpip.AddressWithPrefix
for i := 0; ; i++ {
// If we were unable to generate an address after the maximum SLAAC address
@@ -1402,7 +1419,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr)
- if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) {
+ if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) {
break
}
}
@@ -1410,13 +1427,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary
// address with a zero preferred lifetime. The checks above ensure this
// so we know the address is not deprecated.
- ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */)
- if ref == nil {
+ addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaacTemp, false /* deprecated */)
+ if addressEndpoint == nil {
return false
}
state := tempSLAACAddrState{
- deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1427,9 +1444,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr))
}
- ndp.deprecateSLAACAddress(tempAddrState.ref)
+ ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
}),
- invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1442,7 +1459,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1465,8 +1482,8 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
prefixState.tempAddrs[generatedAddr.Address] = tempAddrState
ndp.slaacPrefixes[prefix] = prefixState
}),
- createdAt: now,
- ref: ref,
+ createdAt: now,
+ addressEndpoint: addressEndpoint,
}
state.deprecationJob.Schedule(pl)
@@ -1481,7 +1498,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
@@ -1496,14 +1513,14 @@ func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttemp
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) {
// If the preferred lifetime is zero, then the prefix should be deprecated.
deprecated := pl == 0
if deprecated {
- ndp.deprecateSLAACAddress(prefixState.stableAddr.ref)
+ ndp.deprecateSLAACAddress(prefixState.stableAddr.addressEndpoint)
} else {
- prefixState.stableAddr.ref.deprecated = false
+ prefixState.stableAddr.addressEndpoint.SetDeprecated(false)
}
// If prefix was preferred for some finite lifetime before, cancel the
@@ -1565,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// If DAD is not yet complete on the stable address, there is no need to do
// work with temporary addresses.
- if prefixState.stableAddr.ref.getKind() != permanent {
+ if prefixState.stableAddr.addressEndpoint.GetKind() != stack.Permanent {
return
}
@@ -1608,9 +1625,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
newPreferredLifetime := preferredUntil.Sub(now)
tempAddrState.deprecationJob.Cancel()
if newPreferredLifetime <= 0 {
- ndp.deprecateSLAACAddress(tempAddrState.ref)
+ ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
} else {
- tempAddrState.ref.deprecated = false
+ tempAddrState.addressEndpoint.SetDeprecated(false)
tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
}
@@ -1635,8 +1652,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// due to an update in preferred lifetime.
//
// If each temporay address has already been regenerated, no new temporary
- // address will be generated. To ensure continuation of temporary SLAAC
- // addresses, we manually try to regenerate an address here.
+ // address is generated. To ensure continuation of temporary SLAAC addresses,
+ // we manually try to regenerate an address here.
if len(regenForAddr) != 0 || allAddressesRegenerated {
// Reset the generation attempts counter as we are starting the generation
// of a new address for the SLAAC prefix.
@@ -1647,57 +1664,58 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
}
-// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP
-// dispatcher that ref has been deprecated.
+// deprecateSLAACAddress marks the address as deprecated and notifies the NDP
+// dispatcher that address has been deprecated.
//
-// deprecateSLAACAddress does nothing if ref is already deprecated.
+// deprecateSLAACAddress does nothing if the address is already deprecated.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) {
- if ref.deprecated {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint) {
+ if addressEndpoint.Deprecated() {
return
}
- ref.deprecated = true
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix())
+ addressEndpoint.SetDeprecated(true)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
}
// invalidateSLAACPrefix invalidates a SLAAC prefix.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) {
- if r := state.stableAddr.ref; r != nil {
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+
+ if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil {
// Since we are already invalidating the prefix, do not invalidate the
// prefix when removing the address.
- if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil {
- panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err))
+ if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err))
}
}
-
- ndp.cleanupSLAACPrefixResources(prefix, state)
}
// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's
// resources.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
prefix := addr.Subnet()
state, ok := ndp.slaacPrefixes[prefix]
- if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.address() {
+ if !ok || state.stableAddr.addressEndpoint == nil || addr.Address != state.stableAddr.addressEndpoint.AddressWithPrefix().Address {
return
}
if !invalidatePrefix {
// If the prefix is not being invalidated, disassociate the address from the
// prefix and do nothing further.
- state.stableAddr.ref = nil
+ state.stableAddr.addressEndpoint.DecRef()
+ state.stableAddr.addressEndpoint = nil
ndp.slaacPrefixes[prefix] = state
return
}
@@ -1709,14 +1727,17 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
//
// Panics if the SLAAC prefix is not known.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) {
// Invalidate all temporary addresses.
for tempAddr, tempAddrState := range state.tempAddrs {
ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState)
}
- state.stableAddr.ref = nil
+ if state.stableAddr.addressEndpoint != nil {
+ state.stableAddr.addressEndpoint.DecRef()
+ state.stableAddr.addressEndpoint = nil
+ }
state.deprecationJob.Cancel()
state.invalidationJob.Cancel()
delete(ndp.slaacPrefixes, prefix)
@@ -1724,12 +1745,12 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
// invalidateTempSLAACAddr invalidates a temporary SLAAC address.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
// Since we are already invalidating the address, do not invalidate the
// address when removing the address.
- if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil {
- panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err))
+ if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err))
}
ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState)
@@ -1738,10 +1759,10 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA
// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary
// SLAAC address's resources from ndp.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
if !invalidateAddr {
@@ -1765,35 +1786,29 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi
// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's
// jobs and entry.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ tempAddrState.addressEndpoint.DecRef()
+ tempAddrState.addressEndpoint = nil
tempAddrState.deprecationJob.Cancel()
tempAddrState.invalidationJob.Cancel()
tempAddrState.regenJob.Cancel()
delete(tempAddrs, tempAddr)
}
-// cleanupState cleans up ndp's state.
-//
-// If hostOnly is true, then only host-specific state will be cleaned up.
+// removeSLAACAddresses removes all SLAAC addresses.
//
-// cleanupState MUST be called with hostOnly set to true when ndp's NIC is
-// transitioning from a host to a router. This function will invalidate all
-// discovered on-link prefixes, discovered routers, and auto-generated
-// addresses.
-//
-// If hostOnly is true, then the link-local auto-generated address will not be
-// invalidated as routers are also expected to generate a link-local address.
+// If keepLinkLocal is false, the SLAAC generated link-local address is removed.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupState(hostOnly bool) {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) {
linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
- linkLocalPrefixes := 0
+ var linkLocalPrefixes int
for prefix, state := range ndp.slaacPrefixes {
// RFC 4862 section 5 states that routers are also expected to generate a
// link-local address so we do not invalidate them if we are cleaning up
// host-only state.
- if hostOnly && prefix == linkLocalSubnet {
+ if keepLinkLocal && prefix == linkLocalSubnet {
linkLocalPrefixes++
continue
}
@@ -1804,6 +1819,21 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
}
+}
+
+// cleanupState cleans up ndp's state.
+//
+// If hostOnly is true, then only host-specific state is cleaned up.
+//
+// This function invalidates all discovered on-link prefixes, discovered
+// routers, and auto-generated addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address aren't
+// invalidated as routers are also expected to generate a link-local address.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupState(hostOnly bool) {
+ ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */)
for prefix := range ndp.onLinkPrefixes {
ndp.invalidateOnLinkPrefix(prefix)
@@ -1827,7 +1857,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
// 6.3.7. If routers are already being solicited, this function does nothing.
//
-// The NIC ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
if ndp.rtrSolicit.timer != nil {
// We are already soliciting routers.
@@ -1848,27 +1878,37 @@ func (ndp *ndpState) startSolicitingRouters() {
var done bool
ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() {
- ndp.nic.mu.Lock()
+ ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
+ ndp.ep.mu.Lock()
if done {
// If we reach this point, it means that the RS timer fired after another
- // goroutine already obtained the NIC lock and stopped solicitations.
- // Simply return here and do nothing further.
- ndp.nic.mu.Unlock()
+ // goroutine already obtained the IPv6 endpoint lock and stopped
+ // solicitations. Simply return here and do nothing further.
+ ndp.ep.mu.Unlock()
return
}
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress)
- if ref == nil {
- ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+ addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
+ if addressEndpoint == nil {
+ // Incase this ends up creating a new temporary address, we need to hold
+ // onto the endpoint until a route is obtained. If we decrement the
+ // reference count before obtaing a route, the address's resources would
+ // be released and attempting to obtain a route after would fail. Once a
+ // route is obtainted, it is safe to decrement the reference count since
+ // obtaining a route increments the address's reference count.
+ addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
}
- ndp.nic.mu.Unlock()
+ ndp.ep.mu.Unlock()
- localAddr := ref.address()
- r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ localAddr := addressEndpoint.AddressWithPrefix().Address
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
+ addressEndpoint.DecRef()
+ if err != nil {
+ return
+ }
defer r.Release()
// Route should resolve immediately since
@@ -1876,15 +1916,16 @@ func (ndp *ndpState) startSolicitingRouters() {
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
// Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the NIC is not locked,
- // the NIC could have been disabled or removed by another goroutine.
+ // Since this method is required to be called when the IPv6 endpoint is
+ // not locked, the IPv6 endpoint could have been disabled or removed by
+ // another goroutine.
if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
return
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
@@ -1907,21 +1948,20 @@ func (ndp *ndpState) startSolicitingRouters() {
rs.Options().Serialize(optsSerializer)
icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- pkt := NewPacketBuffer(PacketBufferOptions{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: buffer.View(icmpData).ToVectorisedView(),
})
sent := r.Stats().ICMP.V6PacketsSent
if err := r.WritePacket(nil,
- NetworkHeaderParams{
+ stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- TOS: DefaultTOS,
}, pkt,
); err != nil {
sent.Dropped.Increment()
- log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
remaining = 0
} else {
@@ -1929,19 +1969,19 @@ func (ndp *ndpState) startSolicitingRouters() {
remaining--
}
- ndp.nic.mu.Lock()
+ ndp.ep.mu.Lock()
if done || remaining == 0 {
ndp.rtrSolicit.timer = nil
ndp.rtrSolicit.done = nil
} else if ndp.rtrSolicit.timer != nil {
// Note, we need to explicitly check to make sure that
// the timer field is not nil because if it was nil but
- // we still reached this point, then we know the NIC
+ // we still reached this point, then we know the IPv6 endpoint
// was requested to stop soliciting routers so we don't
// need to send the next Router Solicitation message.
ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
}
- ndp.nic.mu.Unlock()
+ ndp.ep.mu.Unlock()
})
}
@@ -1949,7 +1989,7 @@ func (ndp *ndpState) startSolicitingRouters() {
// stopSolicitingRouters stops soliciting routers. If routers are not currently
// being solicited, this function does nothing.
//
-// The NIC ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
if ndp.rtrSolicit.timer == nil {
// Nothing to do.
@@ -1965,7 +2005,7 @@ func (ndp *ndpState) stopSolicitingRouters() {
// initializeTempAddrState initializes state related to temporary SLAAC
// addresses.
func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID())
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 480c495fa..25464a03a 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -17,6 +17,7 @@ package ipv6
import (
"strings"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -35,8 +36,8 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
t.Helper()
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: useNeighborCache,
})
@@ -65,10 +66,94 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+ t.Cleanup(ep.Close)
+
return s, ep
}
+var _ NDPDispatcher = (*testNDPDispatcher)(nil)
+
+// testNDPDispatcher is an NDPDispatcher only allows default router discovery.
+type testNDPDispatcher struct {
+ addr tcpip.Address
+}
+
+func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+}
+
+func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
+ t.addr = addr
+ return true
+}
+
+func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) {
+ t.addr = addr
+}
+
+func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
+ return false
+}
+
+func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {
+}
+
+func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
+ return false
+}
+
+func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {
+}
+
+func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {
+}
+
+func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {
+}
+
+func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {
+}
+
+func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) {
+}
+
+func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
+ var ndpDisp testNDPDispatcher
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
+ NDPDisp: &ndpDisp,
+ })},
+ })
+
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err)
+ }
+
+ ipv6EP := ep.(*endpoint)
+ ipv6EP.mu.Lock()
+ ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour)
+ ipv6EP.mu.Unlock()
+
+ if ndpDisp.addr != lladdr1 {
+ t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
+ }
+
+ ndpDisp.addr = ""
+ ndpEP := ep.(stack.NDPEndpoint)
+ ndpEP.InvalidateDefaultRouter(lladdr1)
+ if ndpDisp.addr != lladdr1 {
+ t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
+ }
+}
+
// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
@@ -98,7 +183,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
e := channel.New(0, 1280, linkAddr0)
if err := s.CreateNIC(nicID, e); err != nil {
@@ -202,7 +287,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
UseNeighborCache: true,
})
e := channel.New(0, 1280, linkAddr0)
@@ -475,7 +560,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
UseNeighborCache: stackTyp.useNeighborCache,
})
e := channel.New(1, 1280, nicLinkAddr)
@@ -596,7 +681,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
e := channel.New(0, 1280, linkAddr0)
if err := s.CreateNIC(nicID, e); err != nil {
@@ -707,7 +792,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
UseNeighborCache: true,
})
e := channel.New(0, 1280, linkAddr0)
@@ -958,7 +1043,7 @@ func TestNDPValidation(t *testing.T) {
if isRouter {
// Enabling forwarding makes the stack act as a router.
- s.SetForwarding(true)
+ s.SetForwarding(ProtocolNumber, true)
}
stats := s.Stats().ICMP.V6PacketsReceived
@@ -1172,7 +1257,7 @@ func TestRouterAdvertValidation(t *testing.T) {
e := channel.New(10, 1280, linkAddr1)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
UseNeighborCache: stackTyp.useNeighborCache,
})
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
new file mode 100644
index 000000000..c9e57dc0d
--- /dev/null
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ srcs = [
+ "testutil.go",
+ ],
+ visibility = [
+ "//pkg/tcpip/network/ipv4:__pkg__",
+ "//pkg/tcpip/network/ipv6:__pkg__",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
new file mode 100644
index 000000000..7cc52985e
--- /dev/null
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -0,0 +1,144 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil defines types and functions used to test Network Layer
+// functionality such as IP fragmentation.
+package testutil
+
+import (
+ "fmt"
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// MockLinkEndpoint is an endpoint used for testing, it stores packets written
+// to it and can mock errors.
+type MockLinkEndpoint struct {
+ // WrittenPackets is where packets written to the endpoint are stored.
+ WrittenPackets []*stack.PacketBuffer
+
+ mtu uint32
+ err *tcpip.Error
+ allowPackets int
+}
+
+// NewMockLinkEndpoint creates a new MockLinkEndpoint.
+//
+// err is the error that will be returned once allowPackets packets are written
+// to the endpoint.
+func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint {
+ return &MockLinkEndpoint{
+ mtu: mtu,
+ err: err,
+ allowPackets: allowPackets,
+ }
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu }
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *MockLinkEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if ep.allowPackets == 0 {
+ return ep.err
+ }
+ ep.allowPackets--
+ ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ var n int
+
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := ep.WritePacket(r, gso, protocol, pkt); err != nil {
+ return n, err
+ }
+ n++
+ }
+
+ return n, nil
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ if ep.allowPackets == 0 {
+ return ep.err
+ }
+ ep.allowPackets--
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+ ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+
+ return nil
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*MockLinkEndpoint) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*MockLinkEndpoint) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
+// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
+// how many random bytes will be copied in the Transport Header.
+// extraHeaderReserveLength indicates how much extra space will be reserved for
+// the other headers. The payload is made from Views of the sizes listed in
+// viewSizes.
+func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer {
+ var views buffer.VectorisedView
+
+ for _, s := range viewSizes {
+ newView := buffer.NewView(s)
+ if _, err := rand.Read(newView); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ views.AppendView(newView)
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength,
+ Data: views,
+ })
+ pkt.NetworkProtocolNumber = proto
+ if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ return pkt
+}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 91fc26722..51d428049 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -127,8 +127,8 @@ func main() {
// Create the stack with ipv4 and tcp protocols, then add a tun-based
// NIC and ipv4 address.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
mtu, err := rawfile.GetMTU(tunName)
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 9e37cab18..8e0ee1cd7 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -112,8 +112,8 @@ func main() {
// Create the stack with ip and tcp protocols, then add a tun-based
// NIC and address.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
mtu, err := rawfile.GetMTU(tunName)
@@ -188,7 +188,7 @@ func main() {
defer wq.EventUnregister(&waitEntry)
for {
- n, wq, err := ep.Accept()
+ n, wq, err := ep.Accept(nil)
if err != nil {
if err == tcpip.ErrWouldBlock {
<-notifyCh
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 900938dd1..2eaeab779 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -54,8 +54,8 @@ go_template_instance(
go_library(
name = "stack",
srcs = [
+ "addressable_endpoint_state.go",
"conntrack.go",
- "dhcpv6configurationfromndpra_string.go",
"forwarder.go",
"headertype_string.go",
"icmp_rate_limit.go",
@@ -65,7 +65,6 @@ go_library(
"iptables_types.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
- "ndp.go",
"neighbor_cache.go",
"neighbor_entry.go",
"neighbor_entry_list.go",
@@ -106,6 +105,7 @@ go_test(
name = "stack_x_test",
size = "medium",
srcs = [
+ "addressable_endpoint_state_test.go",
"ndp_test.go",
"nud_test.go",
"stack_test.go",
@@ -116,6 +116,7 @@ go_test(
deps = [
":stack",
"//pkg/rand",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
@@ -138,7 +139,6 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
- "fake_time_test.go",
"forwarder_test.go",
"linkaddrcache_test.go",
"neighbor_cache_test.go",
@@ -152,8 +152,8 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
- "@com_github_dpjacques_clockwork//:go_default_library",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
new file mode 100644
index 000000000..db8ac1c2b
--- /dev/null
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -0,0 +1,758 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil)
+var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
+
+// AddressableEndpointState is an implementation of an AddressableEndpoint.
+type AddressableEndpointState struct {
+ networkEndpoint NetworkEndpoint
+
+ // Lock ordering (from outer to inner lock ordering):
+ //
+ // AddressableEndpointState.mu
+ // addressState.mu
+ mu struct {
+ sync.RWMutex
+
+ endpoints map[tcpip.Address]*addressState
+ primary []*addressState
+
+ // groups holds the mapping between group addresses and the number of times
+ // they have been joined.
+ groups map[tcpip.Address]uint32
+ }
+}
+
+// Init initializes the AddressableEndpointState with networkEndpoint.
+//
+// Must be called before calling any other function on m.
+func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) {
+ a.networkEndpoint = networkEndpoint
+
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.mu.endpoints = make(map[tcpip.Address]*addressState)
+ a.mu.groups = make(map[tcpip.Address]uint32)
+}
+
+// ReadOnlyAddressableEndpointState provides read-only access to an
+// AddressableEndpointState.
+type ReadOnlyAddressableEndpointState struct {
+ inner *AddressableEndpointState
+}
+
+// AddrOrMatching returns an endpoint for the passed address that is consisdered
+// bound to the wrapped AddressableEndpointState.
+//
+// If addr is an exact match with an existing address, that address is returned.
+// Otherwise, f is called with each address and the address that f returns true
+// for is returned.
+//
+// Returns nil of no address matches.
+func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+
+ if ep, ok := m.inner.mu.endpoints[addr]; ok {
+ if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() {
+ return ep
+ }
+ }
+
+ for _, ep := range m.inner.mu.endpoints {
+ if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() {
+ return ep
+ }
+ }
+
+ return nil
+}
+
+// Lookup returns the AddressEndpoint for the passed address.
+//
+// Returns nil if the passed address is not associated with the
+// AddressableEndpointState.
+func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+
+ ep, ok := m.inner.mu.endpoints[addr]
+ if !ok {
+ return nil
+ }
+ return ep
+}
+
+// ForEach calls f for each address pair.
+//
+// If f returns false, f is no longer be called.
+func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+
+ for _, ep := range m.inner.mu.endpoints {
+ if !f(ep) {
+ return
+ }
+ }
+}
+
+// ForEachPrimaryEndpoint calls f for each primary address.
+//
+// If f returns false, f is no longer be called.
+func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+ for _, ep := range m.inner.mu.primary {
+ f(ep)
+ }
+}
+
+// ReadOnly returns a readonly reference to a.
+func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState {
+ return ReadOnlyAddressableEndpointState{inner: a}
+}
+
+func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.releaseAddressStateLocked(addrState)
+}
+
+// releaseAddressState removes addrState from s's address state (primary and endpoints list).
+//
+// Preconditions: a.mu must be write locked.
+func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressState) {
+ oldPrimary := a.mu.primary
+ for i, s := range a.mu.primary {
+ if s == addrState {
+ a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
+ oldPrimary[len(oldPrimary)-1] = nil
+ break
+ }
+ }
+ delete(a.mu.endpoints, addrState.addr.Address)
+}
+
+// AddAndAcquirePermanentAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */)
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil, err
+ }
+ return ep, err
+}
+
+// AddAndAcquireTemporaryAddress adds a temporary address.
+//
+// Returns tcpip.ErrDuplicateAddress if the address exists.
+//
+// The temporary address's endpoint is acquired and returned.
+func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil, err
+ }
+ return ep, err
+}
+
+// addAndAcquireAddressLocked adds, acquires and returns a permanent or
+// temporary address.
+//
+// If the addressable endpoint already has the address in a non-permanent state,
+// and addAndAcquireAddressLocked is adding a permanent address, that address is
+// promoted in place and its properties set to the properties provided. If the
+// address already exists in any other state, then tcpip.ErrDuplicateAddress is
+// returned, regardless the kind of address that is being added.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) {
+ // attemptAddToPrimary is false when the address is already in the primary
+ // address list.
+ attemptAddToPrimary := true
+ addrState, ok := a.mu.endpoints[addr.Address]
+ if ok {
+ if !permanent {
+ // We are adding a non-permanent address but the address exists. No need
+ // to go any further since we can only promote existing temporary/expired
+ // addresses to permanent.
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
+ addrState.mu.Lock()
+ if addrState.mu.kind.IsPermanent() {
+ addrState.mu.Unlock()
+ // We are adding a permanent address but a permanent address already
+ // exists.
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
+ if addrState.mu.refs == 0 {
+ panic(fmt.Sprintf("found an address that should have been released (ref count == 0); address = %s", addrState.addr))
+ }
+
+ // We now promote the address.
+ for i, s := range a.mu.primary {
+ if s == addrState {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ // The address is already in the primary address list.
+ attemptAddToPrimary = false
+ case FirstPrimaryEndpoint:
+ if i == 0 {
+ // The address is already first in the primary address list.
+ attemptAddToPrimary = false
+ } else {
+ a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
+ }
+ case NeverPrimaryEndpoint:
+ a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
+ default:
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ }
+ break
+ }
+ }
+ }
+
+ if addrState == nil {
+ addrState = &addressState{
+ addressableEndpointState: a,
+ addr: addr,
+ }
+ a.mu.endpoints[addr.Address] = addrState
+ addrState.mu.Lock()
+ // We never promote an address to temporary - it can only be added as such.
+ // If we are actaully adding a permanent address, it is promoted below.
+ addrState.mu.kind = Temporary
+ }
+
+ // At this point we have an address we are either promoting from an expired or
+ // temporary address to permanent, promoting an expired address to temporary,
+ // or we are adding a new temporary or permanent address.
+ //
+ // The address MUST be write locked at this point.
+ defer addrState.mu.Unlock()
+
+ if permanent {
+ if addrState.mu.kind.IsPermanent() {
+ panic(fmt.Sprintf("only non-permanent addresses should be promoted to permanent; address = %s", addrState.addr))
+ }
+
+ // Primary addresses are biased by 1.
+ addrState.mu.refs++
+ addrState.mu.kind = Permanent
+ }
+ // Acquire the address before returning it.
+ addrState.mu.refs++
+ addrState.mu.deprecated = deprecated
+ addrState.mu.configType = configType
+
+ if attemptAddToPrimary {
+ switch peb {
+ case NeverPrimaryEndpoint:
+ case CanBePrimaryEndpoint:
+ a.mu.primary = append(a.mu.primary, addrState)
+ case FirstPrimaryEndpoint:
+ if cap(a.mu.primary) == len(a.mu.primary) {
+ a.mu.primary = append([]*addressState{addrState}, a.mu.primary...)
+ } else {
+ // Shift all the endpoints by 1 to make room for the new address at the
+ // front. We could have just created a new slice but this saves
+ // allocations when the slice has capacity for the new address.
+ primaryCount := len(a.mu.primary)
+ a.mu.primary = append(a.mu.primary, nil)
+ if n := copy(a.mu.primary[1:], a.mu.primary); n != primaryCount {
+ panic(fmt.Sprintf("copied %d elements; expected = %d elements", n, primaryCount))
+ }
+ a.mu.primary[0] = addrState
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ }
+ }
+
+ return addrState, nil
+}
+
+// RemovePermanentAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if _, ok := a.mu.groups[addr]; ok {
+ panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr))
+ }
+
+ return a.removePermanentAddressLocked(addr)
+}
+
+// removePermanentAddressLocked is like RemovePermanentAddress but with locking
+// requirements.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
+ addrState, ok := a.mu.endpoints[addr]
+ if !ok {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ return a.removePermanentEndpointLocked(addrState)
+}
+
+// RemovePermanentEndpoint removes the passed endpoint if it is associated with
+// a and permanent.
+func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error {
+ addrState, ok := ep.(*addressState)
+ if !ok || addrState.addressableEndpointState != a {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ return a.removePermanentEndpointLocked(addrState)
+}
+
+// removePermanentAddressLocked is like RemovePermanentAddress but with locking
+// requirements.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error {
+ if !addrState.GetKind().IsPermanent() {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ addrState.SetKind(PermanentExpired)
+ a.decAddressRefLocked(addrState)
+ return nil
+}
+
+// decAddressRef decrements the address's reference count and releases it once
+// the reference count hits 0.
+func (a *AddressableEndpointState) decAddressRef(addrState *addressState) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.decAddressRefLocked(addrState)
+}
+
+// decAddressRefLocked is like decAddressRef but with locking requirements.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState) {
+ addrState.mu.Lock()
+ defer addrState.mu.Unlock()
+
+ if addrState.mu.refs == 0 {
+ panic(fmt.Sprintf("attempted to decrease ref count for AddressEndpoint w/ addr = %s when it is already released", addrState.addr))
+ }
+
+ addrState.mu.refs--
+
+ if addrState.mu.refs != 0 {
+ return
+ }
+
+ // A non-expired permanent address must not have its reference count dropped
+ // to 0.
+ if addrState.mu.kind.IsPermanent() {
+ panic(fmt.Sprintf("permanent addresses should be removed through the AddressableEndpoint: addr = %s, kind = %d", addrState.addr, addrState.mu.kind))
+ }
+
+ a.releaseAddressStateLocked(addrState)
+}
+
+// MainAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.GetKind() == Permanent
+ })
+ if ep == nil {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ addr := ep.AddressWithPrefix()
+ a.decAddressRefLocked(ep)
+ return addr
+}
+
+// acquirePrimaryAddressRLocked returns an acquired primary address that is
+// valid according to isValid.
+//
+// Precondition: e.mu must be read locked
+func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState {
+ var deprecatedEndpoint *addressState
+ for _, ep := range a.mu.primary {
+ if !isValid(ep) {
+ continue
+ }
+
+ if !ep.Deprecated() {
+ if ep.IncRef() {
+ // ep is not deprecated, so return it immediately.
+ //
+ // If we kept track of a deprecated endpoint, decrement its reference
+ // count since it was incremented when we decided to keep track of it.
+ if deprecatedEndpoint != nil {
+ a.decAddressRefLocked(deprecatedEndpoint)
+ deprecatedEndpoint = nil
+ }
+
+ return ep
+ }
+ } else if deprecatedEndpoint == nil && ep.IncRef() {
+ // We prefer an endpoint that is not deprecated, but we keep track of
+ // ep in case a doesn't have any non-deprecated endpoints.
+ //
+ // If we end up finding a more preferred endpoint, ep's reference count
+ // will be decremented.
+ deprecatedEndpoint = ep
+ }
+ }
+
+ return deprecatedEndpoint
+}
+
+// AcquireAssignedAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if addrState, ok := a.mu.endpoints[localAddr]; ok {
+ if !addrState.IsAssigned(allowTemp) {
+ return nil
+ }
+
+ if !addrState.IncRef() {
+ panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr))
+ }
+
+ return addrState
+ }
+
+ if !allowTemp {
+ return nil
+ }
+
+ addr := localAddr.WithPrefix()
+ ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ if err != nil {
+ // addAndAcquireAddressLocked only returns an error if the address is
+ // already assigned but we just checked above if the address exists so we
+ // expect no error.
+ panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err))
+ }
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil
+ }
+ return ep
+}
+
+// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.IsAssigned(allowExpired)
+ })
+
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil
+ }
+
+ return ep
+}
+
+// PrimaryAddresses implements AddressableEndpoint.
+func (a *AddressableEndpointState) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ var addrs []tcpip.AddressWithPrefix
+ for _, ep := range a.mu.primary {
+ // Don't include tentative, expired or temporary endpoints
+ // to avoid confusion and prevent the caller from using
+ // those.
+ switch ep.GetKind() {
+ case PermanentTentative, PermanentExpired, Temporary:
+ continue
+ }
+
+ addrs = append(addrs, ep.AddressWithPrefix())
+ }
+
+ return addrs
+}
+
+// PermanentAddresses implements AddressableEndpoint.
+func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ var addrs []tcpip.AddressWithPrefix
+ for _, ep := range a.mu.endpoints {
+ if !ep.GetKind().IsPermanent() {
+ continue
+ }
+
+ addrs = append(addrs, ep.AddressWithPrefix())
+ }
+
+ return addrs
+}
+
+// JoinGroup implements GroupAddressableEndpoint.
+func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ joins, ok := a.mu.groups[group]
+ if !ok {
+ ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */)
+ if err != nil {
+ return false, err
+ }
+ // We have no need for the address endpoint.
+ a.decAddressRefLocked(ep)
+ }
+
+ a.mu.groups[group] = joins + 1
+ return !ok, nil
+}
+
+// LeaveGroup implements GroupAddressableEndpoint.
+func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ joins, ok := a.mu.groups[group]
+ if !ok {
+ return false, tcpip.ErrBadLocalAddress
+ }
+
+ if joins == 1 {
+ a.removeGroupAddressLocked(group)
+ delete(a.mu.groups, group)
+ return true, nil
+ }
+
+ a.mu.groups[group] = joins - 1
+ return false, nil
+}
+
+// IsInGroup implements GroupAddressableEndpoint.
+func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ _, ok := a.mu.groups[group]
+ return ok
+}
+
+func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) {
+ if err := a.removePermanentAddressLocked(group); err != nil {
+ // removePermanentEndpointLocked would only return an error if group is
+ // not bound to the addressable endpoint, but we know it MUST be assigned
+ // since we have group in our map of groups.
+ panic(fmt.Sprintf("error removing group address = %s: %s", group, err))
+ }
+}
+
+// Cleanup forcefully leaves all groups and removes all permanent addresses.
+func (a *AddressableEndpointState) Cleanup() {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ for group := range a.mu.groups {
+ a.removeGroupAddressLocked(group)
+ }
+ a.mu.groups = make(map[tcpip.Address]uint32)
+
+ for _, ep := range a.mu.endpoints {
+ // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
+ // not a permanent address.
+ if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err))
+ }
+ }
+}
+
+var _ AddressEndpoint = (*addressState)(nil)
+
+// addressState holds state for an address.
+type addressState struct {
+ addressableEndpointState *AddressableEndpointState
+ addr tcpip.AddressWithPrefix
+
+ // Lock ordering (from outer to inner lock ordering):
+ //
+ // AddressableEndpointState.mu
+ // addressState.mu
+ mu struct {
+ sync.RWMutex
+
+ refs uint32
+ kind AddressKind
+ configType AddressConfigType
+ deprecated bool
+ }
+}
+
+// NetworkEndpoint implements AddressEndpoint.
+func (a *addressState) NetworkEndpoint() NetworkEndpoint {
+ return a.addressableEndpointState.networkEndpoint
+}
+
+// AddressWithPrefix implements AddressEndpoint.
+func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
+ return a.addr
+}
+
+// GetKind implements AddressEndpoint.
+func (a *addressState) GetKind() AddressKind {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ return a.mu.kind
+}
+
+// SetKind implements AddressEndpoint.
+func (a *addressState) SetKind(kind AddressKind) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.mu.kind = kind
+}
+
+// IsAssigned implements AddressEndpoint.
+func (a *addressState) IsAssigned(allowExpired bool) bool {
+ if !a.addressableEndpointState.networkEndpoint.Enabled() {
+ return false
+ }
+
+ switch a.GetKind() {
+ case PermanentTentative:
+ return false
+ case PermanentExpired:
+ return allowExpired
+ default:
+ return true
+ }
+}
+
+// IncRef implements AddressEndpoint.
+func (a *addressState) IncRef() bool {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ if a.mu.refs == 0 {
+ return false
+ }
+
+ a.mu.refs++
+ return true
+}
+
+// DecRef implements AddressEndpoint.
+func (a *addressState) DecRef() {
+ a.addressableEndpointState.decAddressRef(a)
+}
+
+// ConfigType implements AddressEndpoint.
+func (a *addressState) ConfigType() AddressConfigType {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ return a.mu.configType
+}
+
+// SetDeprecated implements AddressEndpoint.
+func (a *addressState) SetDeprecated(d bool) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.mu.deprecated = d
+}
+
+// Deprecated implements AddressEndpoint.
+func (a *addressState) Deprecated() bool {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ return a.mu.deprecated
+}
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
new file mode 100644
index 000000000..26787d0a3
--- /dev/null
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -0,0 +1,77 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// TestAddressableEndpointStateCleanup tests that cleaning up an addressable
+// endpoint state removes permanent addresses and leaves groups.
+func TestAddressableEndpointStateCleanup(t *testing.T) {
+ var ep fakeNetworkEndpoint
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ var s stack.AddressableEndpointState
+ s.Init(&ep)
+
+ addr := tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: 8,
+ }
+
+ {
+ ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ if err != nil {
+ t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err)
+ }
+ // We don't need the address endpoint.
+ ep.DecRef()
+ }
+ {
+ ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
+ if ep == nil {
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address)
+ }
+ ep.DecRef()
+ }
+
+ group := tcpip.Address("\x02")
+ if added, err := s.JoinGroup(group); err != nil {
+ t.Fatalf("s.JoinGroup(%s): %s", group, err)
+ } else if !added {
+ t.Fatalf("got s.JoinGroup(%s) = false, want = true", group)
+ }
+ if !s.IsInGroup(group) {
+ t.Fatalf("got s.IsInGroup(%s) = false, want = true", group)
+ }
+
+ s.Cleanup()
+ {
+ ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
+ if ep != nil {
+ ep.DecRef()
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
+ }
+ }
+ if s.IsInGroup(group) {
+ t.Fatalf("got s.IsInGroup(%s) = true, want = false", group)
+ }
+}
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 836682ea0..0cd1da11f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -196,13 +196,14 @@ type bucket struct {
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
+//
+// Preconditions: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ netHeader := pkt.Network()
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
return tupleID{}, tcpip.ErrUnknownProtocol
}
+
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
return tupleID{}, tcpip.ErrUnknownProtocol
@@ -214,7 +215,7 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
dstAddr: netHeader.DestinationAddress(),
dstPort: tcpHeader.DestinationPort(),
transProto: netHeader.TransportProtocol(),
- netProto: header.IPv4ProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
}, nil
}
@@ -268,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
return nil, dirOriginal
}
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -281,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt Redirec
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
- replyTID.srcAddr = rt.MinIP
- replyTID.srcPort = rt.MinPort
+ replyTID.srcAddr = rt.Addr
+ replyTID.srcPort = rt.Port
var manip manipType
switch hook {
case Prerouting:
@@ -344,7 +345,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For prerouting redirection, packets going in the original direction
@@ -366,8 +367,12 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
// support cases when they are validated, e.g. when we can't offload
// receive checksumming.
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacketOutput manipulates ports for packets in Output hook.
@@ -377,7 +382,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For output redirection, packets going in the original direction
@@ -396,7 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
@@ -405,8 +410,11 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacket will manipulate the port and address of the packet if the
@@ -422,7 +430,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
// TODO(gvisor.dev/issue/170): Support other transport protocols.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return false
}
@@ -473,7 +481,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
}
// We only track TCP connections.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return
}
@@ -609,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -618,7 +626,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint1
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
transProto: header.TCPProtocolNumber,
- netProto: header.IPv4ProtocolNumber,
+ netProto: netProto,
}
conn, _ := ct.connForTID(tid)
if conn == nil {
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 54759091a..4e4b00a92 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -20,6 +20,7 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -45,6 +46,8 @@ const (
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fwdTestNetworkEndpoint struct {
+ AddressableEndpointState
+
nicID tcpip.NICID
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
@@ -53,12 +56,18 @@ type fwdTestNetworkEndpoint struct {
var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
-func (f *fwdTestNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error {
+ return nil
+}
+
+func (*fwdTestNetworkEndpoint) Enabled() bool {
+ return true
}
-func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicID
+func (*fwdTestNetworkEndpoint) Disable() {}
+
+func (f *fwdTestNetworkEndpoint) MTU() uint32 {
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
@@ -78,10 +87,6 @@ func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportPr
return 0
}
-func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities {
- return f.ep.Capabilities()
-}
-
func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -106,7 +111,9 @@ func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBu
return tcpip.ErrNotSupported
}
-func (*fwdTestNetworkEndpoint) Close() {}
+func (f *fwdTestNetworkEndpoint) Close() {
+ f.AddressableEndpointState.Cleanup()
+}
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
@@ -116,6 +123,11 @@ type fwdTestNetworkProtocol struct {
addrResolveDelay time.Duration
onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
+
+ mu struct {
+ sync.RWMutex
+ forwarding bool
+ }
}
var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
@@ -145,13 +157,15 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
}
-func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
- return &fwdTestNetworkEndpoint{
- nicID: nicID,
+func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint {
+ e := &fwdTestNetworkEndpoint{
+ nicID: nic.ID(),
proto: f,
dispatcher: dispatcher,
- ep: ep,
+ ep: nic.LinkEndpoint(),
}
+ e.AddressableEndpointState.Init(e)
+ return e
}
func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
@@ -186,6 +200,21 @@ func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber
return fwdTestNetNumber
}
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (f *fwdTestNetworkProtocol) Forwarding() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.forwarding
+
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (f *fwdTestNetworkProtocol) SetForwarding(v bool) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.forwarding = v
+}
+
// fwdTestPacketInfo holds all the information about an outbound packet.
type fwdTestPacketInfo struct {
RemoteLinkAddress tcpip.LinkAddress
@@ -307,7 +336,7 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
- NetworkProtocols: []NetworkProtocol{proto},
+ NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }},
UseNeighborCache: useNeighborCache,
})
@@ -316,7 +345,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
}
// Enable forwarding.
- s.SetForwarding(true)
+ s.SetForwarding(proto.Number(), true)
// NIC 1 has the link address "a", and added the network address 1.
ep1 = &fwdTestLinkEndpoint{
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 0e33cbe92..8d6d9a7f1 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -57,14 +57,14 @@ const reaperDelay = 5 * time.Second
// all packets.
func DefaultTables() *IPTables {
return &IPTables{
- tables: [numTables]Table{
+ v4Tables: [numTables]Table{
natID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -83,9 +83,9 @@ func DefaultTables() *IPTables {
},
mangleID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -101,10 +101,75 @@ func DefaultTables() *IPTables {
},
filterID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
+ },
+ },
+ },
+ v6Tables: [numTables]Table{
+ natID: Table{
+ Rules: []Rule{
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Input: 1,
+ Forward: HookUnset,
+ Output: 2,
+ Postrouting: 3,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: 1,
+ Forward: HookUnset,
+ Output: 2,
+ Postrouting: 3,
+ },
+ },
+ mangleID: Table{
+ Rules: []Rule{
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Output: 1,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
+ },
+ },
+ filterID: Table{
+ Rules: []Rule{
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
@@ -166,25 +231,20 @@ func EmptyNATTable() Table {
// GetTable returns a table by name.
func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) {
- // TODO(gvisor.dev/issue/3549): Enable IPv6.
- if ipv6 {
- return Table{}, false
- }
id, ok := nameToID[name]
if !ok {
return Table{}, false
}
it.mu.RLock()
defer it.mu.RUnlock()
- return it.tables[id], true
+ if ipv6 {
+ return it.v6Tables[id], true
+ }
+ return it.v4Tables[id], true
}
// ReplaceTable replaces or inserts table by name.
func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error {
- // TODO(gvisor.dev/issue/3549): Enable IPv6.
- if ipv6 {
- return tcpip.ErrInvalidOptionValue
- }
id, ok := nameToID[name]
if !ok {
return tcpip.ErrInvalidOptionValue
@@ -198,7 +258,11 @@ func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Err
it.startReaper(reaperDelay)
}
it.modified = true
- it.tables[id] = table
+ if ipv6 {
+ it.v6Tables[id] = table
+ } else {
+ it.v4Tables[id] = table
+ }
return nil
}
@@ -221,8 +285,15 @@ const (
// should continue traversing the network stack and false when it should be
// dropped.
//
+// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from
+// which address and nicName can be gathered. Currently, address is only
+// needed for prerouting and nicName is only needed for output.
+//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool {
+ if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
+ return true
+ }
// Many users never configure iptables. Spare them the cost of rule
// traversal if rules have never been set.
it.mu.RLock()
@@ -243,9 +314,14 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
if tableID == natID && pkt.NatDone {
continue
}
- table := it.tables[tableID]
+ var table Table
+ if pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber {
+ table = it.v6Tables[tableID]
+ } else {
+ table = it.v4Tables[tableID]
+ }
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -256,7 +332,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v {
+ switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v {
case RuleAccept:
continue
case RuleDrop:
@@ -351,11 +427,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
case RuleAccept:
return chainAccept
@@ -372,7 +448,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -398,11 +474,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
- if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) {
+ if !rule.Filter.match(pkt, hook, nicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -421,16 +497,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, gso, r, address)
+ return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr)
}
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, tcpip.ErrNotConnected
}
- return it.connections.originalDst(epID)
+ return it.connections.originalDst(epID, netProto)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 5f1b2af64..538c4625d 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -21,78 +21,139 @@ import (
)
// AcceptTarget accepts packets.
-type AcceptTarget struct{}
+type AcceptTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (at *AcceptTarget) ID() TargetID {
+ return TargetID{
+ NetworkProtocol: at.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
// DropTarget drops packets.
-type DropTarget struct{}
+type DropTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (dt *DropTarget) ID() TargetID {
+ return TargetID{
+ NetworkProtocol: dt.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
+// ErrorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const ErrorTargetName = "ERROR"
+
// ErrorTarget logs an error and drops the packet. It represents a target that
// should be unreachable.
-type ErrorTarget struct{}
+type ErrorTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (et *ErrorTarget) ID() TargetID {
+ return TargetID{
+ Name: ErrorTargetName,
+ NetworkProtocol: et.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
// UserChainTarget marks a rule as the beginning of a user chain.
type UserChainTarget struct {
+ // Name is the chain name.
Name string
+
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (uc *UserChainTarget) ID() TargetID {
+ return TargetID{
+ Name: ErrorTargetName,
+ NetworkProtocol: uc.NetworkProtocol,
+ }
}
// Action implements Target.Action.
-func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
// ReturnTarget returns from the current chain. If the chain is a built-in, the
// hook's underflow should be called.
-type ReturnTarget struct{}
+type ReturnTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (rt *ReturnTarget) ID() TargetID {
+ return TargetID{
+ NetworkProtocol: rt.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
+// RedirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port/destination IP for packets.
+const RedirectTargetName = "REDIRECT"
+
// RedirectTarget redirects the packet by modifying the destination port/IP.
-// Min and Max values for IP and Ports in the struct indicate the range of
-// values which can be used to redirect.
+// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
+// them.
type RedirectTarget struct {
- // TODO(gvisor.dev/issue/170): Other flags need to be added after
- // we support them.
- // RangeProtoSpecified flag indicates single port is specified to
- // redirect.
- RangeProtoSpecified bool
+ // Addr indicates address used to redirect.
+ Addr tcpip.Address
- // MinIP indicates address used to redirect.
- MinIP tcpip.Address
+ // Port indicates port used to redirect.
+ Port uint16
- // MaxIP indicates address used to redirect.
- MaxIP tcpip.Address
-
- // MinPort indicates port used to redirect.
- MinPort uint16
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
- // MaxPort indicates port used to redirect.
- MaxPort uint16
+// ID implements Target.ID.
+func (rt *RedirectTarget) ID() TargetID {
+ return TargetID{
+ Name: RedirectTargetName,
+ NetworkProtocol: rt.NetworkProtocol,
+ }
}
// Action implements Target.Action.
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for PREROUTING and calls pkt.Clone(), neither
// of which should be the case.
-func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -103,34 +164,35 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
return RuleDrop, 0
}
- // Change the address to localhost (127.0.0.1) in Output and
- // to primary address of the incoming interface in Prerouting.
+ // Change the address to localhost (127.0.0.1 or ::1) in Output and to
+ // the primary address of the incoming interface in Prerouting.
switch hook {
case Output:
- rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1})
- rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1})
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ } else {
+ rt.Addr = header.IPv6Loopback
+ }
case Prerouting:
- rt.MinIP = address
- rt.MaxIP = address
+ rt.Addr = address
default:
panic("redirect target is supported only on output and prerouting hooks")
}
// TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
// we need to change dest address (for OUTPUT chain) or ports.
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- switch protocol := netHeader.TransportProtocol(); protocol {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetDestinationPort(rt.MinPort)
+ udpHeader.SetDestinationPort(rt.Port)
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(protocol, length)
for _, v := range pkt.Data.Views() {
xsum = header.Checksum(v, xsum)
@@ -139,10 +201,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
- // Change destination address.
- netHeader.SetDestinationAddress(rt.MinIP)
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+
+ pkt.Network().SetDestinationAddress(rt.Addr)
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index fbbd2f50f..7b3f3e88b 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -15,6 +15,7 @@
package stack
import (
+ "fmt"
"strings"
"sync"
@@ -81,31 +82,42 @@ const (
//
// +stateify savable
type IPTables struct {
- // mu protects tables, priorities, and modified.
+ // mu protects v4Tables, v6Tables, and modified.
mu sync.RWMutex
-
- // tables maps tableIDs to tables. Holds builtin tables only, not user
- // tables. mu must be locked for accessing.
- tables [numTables]Table
-
- // priorities maps each hook to a list of table names. The order of the
- // list is the order in which each table should be visited for that
- // hook. mu needs to be locked for accessing.
- priorities [NumHooks][]tableID
-
+ // v4Tables and v6tables map tableIDs to tables. They hold builtin
+ // tables only, not user tables. mu must be locked for accessing.
+ v4Tables [numTables]Table
+ v6Tables [numTables]Table
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
// don't utilize iptables.
modified bool
+ // priorities maps each hook to a list of table names. The order of the
+ // list is the order in which each table should be visited for that
+ // hook. It is immutable.
+ priorities [NumHooks][]tableID
+
connections ConnTrack
- // reaperDone can be signalled to stop the reaper goroutine.
+ // reaperDone can be signaled to stop the reaper goroutine.
reaperDone chan struct{}
}
-// A Table defines a set of chains and hooks into the network stack. It is
-// really just a list of rules.
+// A Table defines a set of chains and hooks into the network stack.
+//
+// It is a list of Rules, entry points (BuiltinChains), and error handlers
+// (Underflows). As packets traverse netstack, they hit hooks. When a packet
+// hits a hook, iptables compares it to Rules starting from that hook's entry
+// point. So if a packet hits the Input hook, we look up the corresponding
+// entry point in BuiltinChains and jump to that point.
+//
+// If the Rule doesn't match the packet, iptables continues to the next Rule.
+// If a Rule does match, it can issue a verdict on the packet (e.g. RuleAccept
+// or RuleDrop) that causes the packet to stop traversing iptables. It can also
+// jump to other rules or perform custom actions based on Rule.Target.
+//
+// Underflow Rules are invoked when a chain returns without reaching a verdict.
//
// +stateify savable
type Table struct {
@@ -148,7 +160,7 @@ type Rule struct {
Target Target
}
-// IPHeaderFilter holds basic IP filtering data common to every rule.
+// IPHeaderFilter performs basic IP header matching common to every rule.
//
// +stateify savable
type IPHeaderFilter struct {
@@ -196,16 +208,43 @@ type IPHeaderFilter struct {
OutputInterfaceInvert bool
}
-// match returns whether hdr matches the filter.
-func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool {
- // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+// match returns whether pkt matches the filter.
+//
+// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4
+// or IPv6 header length.
+func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool {
+ // Extract header fields.
+ var (
+ // TODO(gvisor.dev/issue/170): Support other filter fields.
+ transProto tcpip.TransportProtocolNumber
+ dstAddr tcpip.Address
+ srcAddr tcpip.Address
+ )
+ switch proto := pkt.NetworkProtocolNumber; proto {
+ case header.IPv4ProtocolNumber:
+ hdr := header.IPv4(pkt.NetworkHeader().View())
+ transProto = hdr.TransportProtocol()
+ dstAddr = hdr.DestinationAddress()
+ srcAddr = hdr.SourceAddress()
+
+ case header.IPv6ProtocolNumber:
+ hdr := header.IPv6(pkt.NetworkHeader().View())
+ transProto = hdr.TransportProtocol()
+ dstAddr = hdr.DestinationAddress()
+ srcAddr = hdr.SourceAddress()
+
+ default:
+ panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto))
+ }
+
// Check the transport protocol.
- if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() {
+ if fl.CheckProtocol && fl.Protocol != transProto {
return false
}
- // Check the source and destination IPs.
- if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) {
+ // Check the addresses.
+ if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) ||
+ !filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) {
return false
}
@@ -233,6 +272,18 @@ func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool
return true
}
+// NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header
+// applies.
+func (fl IPHeaderFilter) NetworkProtocol() tcpip.NetworkProtocolNumber {
+ switch len(fl.Src) {
+ case header.IPv4AddressSize:
+ return header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ return header.IPv6ProtocolNumber
+ }
+ panic(fmt.Sprintf("invalid address in IPHeaderFilter: %s", fl.Src))
+}
+
// filterAddress returns whether addr matches the filter.
func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool {
matches := true
@@ -258,8 +309,23 @@ type Matcher interface {
Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
+// A TargetID uniquely identifies a target.
+type TargetID struct {
+ // Name is the target name as stored in the xt_entry_target struct.
+ Name string
+
+ // NetworkProtocol is the protocol to which the target applies.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+
+ // Revision is the version of the target.
+ Revision uint8
+}
+
// A Target is the interface for taking an action for a packet.
type Target interface {
+ // ID uniquely identifies the Target.
+ ID() TargetID
+
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 14fb4239b..33806340e 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -16,6 +16,7 @@ package stack
import (
"fmt"
+ "math"
"sync/atomic"
"testing"
"time"
@@ -191,7 +192,13 @@ func TestCacheReplace(t *testing.T) {
}
func TestCacheResolution(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
+ // There is a race condition causing this test to fail when the executor
+ // takes longer than the resolution timeout to call linkAddrCache.get. This
+ // is especially common when this test is run with gotsan.
+ //
+ // Using a large resolution timeout decreases the probability of experiencing
+ // this race condition and does not affect how long this test takes to run.
+ c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1)
linkRes := &testLinkAddressResolver{cache: c}
for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 67dc5364f..73a01c2dd 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -150,10 +150,10 @@ type ndpDNSSLEvent struct {
type ndpDHCPv6Event struct {
nicID tcpip.NICID
- configuration stack.DHCPv6ConfigurationFromNDPRA
+ configuration ipv6.DHCPv6ConfigurationFromNDPRA
}
-var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
+var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
// related events happen for test purposes.
@@ -170,7 +170,7 @@ type ndpDispatcher struct {
dhcpv6ConfigurationC chan ndpDHCPv6Event
}
-// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
+// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus.
func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
if n.dadC != nil {
n.dadC <- ndpDADEvent{
@@ -182,7 +182,7 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, add
}
}
-// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered.
+// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered.
func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
if c := n.routerC; c != nil {
c <- ndpRouterEvent{
@@ -195,7 +195,7 @@ func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.
return n.rememberRouter
}
-// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated.
+// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated.
func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
if c := n.routerC; c != nil {
c <- ndpRouterEvent{
@@ -206,7 +206,7 @@ func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip
}
}
-// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered.
+// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered.
func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
if c := n.prefixC; c != nil {
c <- ndpPrefixEvent{
@@ -219,7 +219,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip
return n.rememberPrefix
}
-// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated.
+// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated.
func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) {
if c := n.prefixC; c != nil {
c <- ndpPrefixEvent{
@@ -261,7 +261,7 @@ func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpi
}
}
-// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption.
+// Implements ipv6.NDPDispatcher.OnRecursiveDNSServerOption.
func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) {
if c := n.rdnssC; c != nil {
c <- ndpRDNSSEvent{
@@ -274,7 +274,7 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc
}
}
-// Implements stack.NDPDispatcher.OnDNSSearchListOption.
+// Implements ipv6.NDPDispatcher.OnDNSSearchListOption.
func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) {
if n.dnsslC != nil {
n.dnsslC <- ndpDNSSLEvent{
@@ -285,8 +285,8 @@ func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []s
}
}
-// Implements stack.NDPDispatcher.OnDHCPv6Configuration.
-func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) {
+// Implements ipv6.NDPDispatcher.OnDHCPv6Configuration.
+func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration ipv6.DHCPv6ConfigurationFromNDPRA) {
if c := n.dhcpv6ConfigurationC; c != nil {
c <- ndpDHCPv6Event{
nicID,
@@ -319,13 +319,12 @@ func TestDADDisabled(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
}
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- }
-
e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ })},
+ })
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -413,19 +412,21 @@ func TestDADResolve(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- }
- opts.NDPConfigs.RetransmitTimer = test.retransTimer
- opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
e := channelLinkWithHeaderLength{
Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1),
headerLength: test.linkHeaderLen,
}
e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ipv6.NDPConfigurations{
+ RetransmitTimer: test.retransTimer,
+ DupAddrDetectTransmits: test.dupAddrDetectTransmits,
+ },
+ })},
+ })
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -558,6 +559,26 @@ func TestDADResolve(t *testing.T) {
}
}
+func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(tgt)
+ snmc := header.SolicitedNodeAddr(tgt)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
+ })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
+}
+
// TestDADFail tests to make sure that the DAD process fails if another node is
// detected to be performing DAD on the same address (receive an NS message from
// a node doing DAD for the same address), or if another node is detected to own
@@ -567,39 +588,19 @@ func TestDADFail(t *testing.T) {
tests := []struct {
name string
- makeBuf func(tgt tcpip.Address) buffer.Prependable
+ rxPkt func(e *channel.Endpoint, tgt tcpip.Address)
getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
}{
{
- "RxSolicit",
- func(tgt tcpip.Address) buffer.Prependable {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(tgt)
- snmc := header.SolicitedNodeAddr(tgt)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
- })
-
- return hdr
-
- },
- func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RxSolicit",
+ rxPkt: rxNDPSolicit,
+ getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return s.NeighborSolicit
},
},
{
- "RxAdvert",
- func(tgt tcpip.Address) buffer.Prependable {
+ name: "RxAdvert",
+ rxPkt: func(e *channel.Endpoint, tgt tcpip.Address) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
@@ -621,11 +622,9 @@ func TestDADFail(t *testing.T) {
SrcAddr: tgt,
DstAddr: header.IPv6AllNodesMulticastAddress,
})
-
- return hdr
-
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
},
- func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return s.NeighborAdvert
},
},
@@ -636,16 +635,16 @@ func TestDADFail(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
}
- ndpConfigs := stack.DefaultNDPConfigurations()
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- }
- opts.NDPConfigs.RetransmitTimer = time.Second * 2
+ ndpConfigs := ipv6.DefaultNDPConfigurations()
+ ndpConfigs.RetransmitTimer = time.Second * 2
e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ })},
+ })
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -664,13 +663,8 @@ func TestDADFail(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
- // Receive a packet to simulate multiple nodes owning or
- // attempting to own the same address.
- hdr := test.makeBuf(addr1)
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- })
- e.InjectInbound(header.IPv6ProtocolNumber, pkt)
+ // Receive a packet to simulate an address conflict.
+ test.rxPkt(e, addr1)
stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
if got := stat.Value(); got != 1 {
@@ -754,18 +748,19 @@ func TestDADStop(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
}
- ndpConfigs := stack.NDPConfigurations{
+
+ ndpConfigs := ipv6.NDPConfigurations{
RetransmitTimer: time.Second,
DupAddrDetectTransmits: 2,
}
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
- }
e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ })},
+ })
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
@@ -815,19 +810,6 @@ func TestDADStop(t *testing.T) {
}
}
-// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if
-// we attempt to update NDP configurations using an invalid NICID.
-func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- })
-
- // No NIC with ID 1 yet.
- if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID {
- t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID)
- }
-}
-
// TestSetNDPConfigurations tests that we can update and use per-interface NDP
// configurations without affecting the default NDP configurations or other
// interfaces' configurations.
@@ -863,8 +845,9 @@ func TestSetNDPConfigurations(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ })},
})
expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
@@ -892,12 +875,15 @@ func TestSetNDPConfigurations(t *testing.T) {
}
// Update the NDP configurations on NIC(1) to use DAD.
- configs := stack.NDPConfigurations{
+ configs := ipv6.NDPConfigurations{
DupAddrDetectTransmits: test.dupAddrDetectTransmits,
RetransmitTimer: test.retransmitTimer,
}
- if err := s.SetNDPConfigurations(nicID1, configs); err != nil {
- t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err)
+ if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err)
+ } else {
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(configs)
}
// Created after updating NIC(1)'s NDP configurations
@@ -1113,14 +1099,15 @@ func TestNoRouterDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- DiscoverDefaultRouters: discover,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverDefaultRouters: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -1151,12 +1138,13 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1192,12 +1180,13 @@ func TestRouterDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
expectRouterEvent := func(addr tcpip.Address, discovered bool) {
@@ -1285,7 +1274,7 @@ func TestRouterDiscovery(t *testing.T) {
}
// TestRouterDiscoveryMaxRouters tests that only
-// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
+// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered.
func TestRouterDiscoveryMaxRouters(t *testing.T) {
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
@@ -1293,12 +1282,13 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1306,14 +1296,14 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
}
// Receive an RA from 2 more than the max number of discovered routers.
- for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ {
+ for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ {
linkAddr := []byte{2, 2, 3, 4, 5, 0}
linkAddr[5] = byte(i)
llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
- if i <= stack.MaxDiscoveredDefaultRouters {
+ if i <= ipv6.MaxDiscoveredDefaultRouters {
select {
case e := <-ndpDisp.routerC:
if diff := checkRouterEvent(e, llAddr, true); diff != "" {
@@ -1358,14 +1348,15 @@ func TestNoPrefixDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- DiscoverOnLinkPrefixes: discover,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverOnLinkPrefixes: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -1399,13 +1390,14 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: false,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1445,12 +1437,13 @@ func TestPrefixDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1545,12 +1538,13 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1621,33 +1615,34 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}
// TestPrefixDiscoveryMaxRouters tests that only
-// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
+// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
+ prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
rememberPrefix: true,
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: false,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
}
- optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2)
- prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
+ optSer := make(header.NDPOptionsSerializer, ipv6.MaxDiscoveredOnLinkPrefixes+2)
+ prefixes := [ipv6.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
// Receive an RA with 2 more than the max number of discovered on-link
// prefixes.
- for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ {
prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0}
prefixAddr[7] = byte(i)
prefix := tcpip.AddressWithPrefix{
@@ -1665,8 +1660,8 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
}
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
- for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
- if i < stack.MaxDiscoveredOnLinkPrefixes {
+ for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ if i < ipv6.MaxDiscoveredOnLinkPrefixes {
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" {
@@ -1716,14 +1711,15 @@ func TestNoAutoGenAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- AutoGenGlobalAddresses: autogen,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handle,
+ AutoGenGlobalAddresses: autogen,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -1749,14 +1745,14 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix,
// TestAutoGenAddr tests that an address is properly generated and invalidated
// when configured to do so.
-func TestAutoGenAddr(t *testing.T) {
+func TestAutoGenAddr2(t *testing.T) {
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -1766,12 +1762,13 @@ func TestAutoGenAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1876,14 +1873,14 @@ func TestAutoGenTempAddr(t *testing.T) {
newMinVLDuration = newMinVL * time.Second
)
- savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate
- savedMaxDesync := stack.MaxDesyncFactor
+ savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate
+ savedMaxDesync := ipv6.MaxDesyncFactor
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
- stack.MaxDesyncFactor = savedMaxDesync
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
+ ipv6.MaxDesyncFactor = savedMaxDesync
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- stack.MaxDesyncFactor = time.Nanosecond
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ ipv6.MaxDesyncFactor = time.Nanosecond
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -1931,16 +1928,17 @@ func TestAutoGenTempAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: test.dupAddrTransmits,
- RetransmitTimer: test.retransmitTimer,
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- TempIIDSeed: seed,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: test.dupAddrTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ TempIIDSeed: seed,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2119,11 +2117,11 @@ func TestAutoGenTempAddr(t *testing.T) {
func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
const nicID = 1
- savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
}()
- stack.MaxDesyncFactor = time.Nanosecond
+ ipv6.MaxDesyncFactor = time.Nanosecond
tests := []struct {
name string
@@ -2160,12 +2158,13 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- AutoGenIPv6LinkLocal: true,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ AutoGenIPv6LinkLocal: true,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2211,11 +2210,11 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
retransmitTimer = 2 * time.Second
)
- savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
}()
- stack.MaxDesyncFactor = 0
+ ipv6.MaxDesyncFactor = 0
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
@@ -2228,15 +2227,16 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2294,17 +2294,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
newMinVLDuration = newMinVL * time.Second
)
- savedMaxDesyncFactor := stack.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
- stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
}()
- stack.MaxDesyncFactor = 0
- stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+ ipv6.MaxDesyncFactor = 0
+ ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
@@ -2317,16 +2317,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := stack.NDPConfigurations{
+ ndpConfigs := ipv6.NDPConfigurations{
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2382,8 +2383,11 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
// Stop generating temporary addresses
ndpConfigs.AutoGenTempGlobalAddresses = false
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else {
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(ndpConfigs)
}
// Wait for all the temporary addresses to get invalidated.
@@ -2439,17 +2443,17 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
newMinVLDuration = newMinVL * time.Second
)
- savedMaxDesyncFactor := stack.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
- stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
}()
- stack.MaxDesyncFactor = 0
- stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+ ipv6.MaxDesyncFactor = 0
+ ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
@@ -2462,16 +2466,17 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := stack.NDPConfigurations{
+ ndpConfigs := ipv6.NDPConfigurations{
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2545,9 +2550,12 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
// as paased.
ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second
ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
}
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(ndpConfigs)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
select {
case e := <-ndpDisp.autoGenAddrC:
@@ -2565,9 +2573,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
- }
+ ndpEP.SetNDPConfigurations(ndpConfigs)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
}
@@ -2655,20 +2661,21 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := stack.NDPConfigurations{
+ ndpConfigs := ipv6.NDPConfigurations{
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: test.tempAddrs,
AutoGenAddressConflictRetries: 1,
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: test.nicNameFromID,
- },
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: test.nicNameFromID,
+ },
+ })},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
s.SetRouteTable([]tcpip.Route{{
@@ -2739,8 +2746,11 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
ndpDisp.dadC = make(chan ndpDADEvent, 2)
ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits
ndpConfigs.RetransmitTimer = retransmitTimer
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else {
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(ndpConfigs)
}
// Do SLAAC for prefix.
@@ -2754,9 +2764,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
// DAD failure to restart the local generation process.
addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1]
expectAutoGenAddrAsyncEvent(addr, newAddr)
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
@@ -2794,14 +2802,15 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: ndpDisp,
- UseNeighborCache: useNeighborCache,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: ndpDisp,
+ })},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ UseNeighborCache: useNeighborCache,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -3036,11 +3045,11 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
for _, stackTyp := range stacks {
t.Run(stackTyp.name, func(t *testing.T) {
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -3258,12 +3267,12 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
const infiniteVLSeconds = 2
const minVLSeconds = 1
savedIL := header.NDPInfiniteLifetime
- savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate
+ savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
header.NDPInfiniteLifetime = savedIL
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3307,12 +3316,13 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3357,11 +3367,11 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
const infiniteVL = 4294967295
const newMinVL = 4
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3449,12 +3459,13 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
}
e := channel.New(10, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3515,12 +3526,13 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3700,12 +3712,13 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3781,18 +3794,19 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
},
- SecretKey: secretKey,
- },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })},
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -3856,11 +3870,11 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
const lifetimeSeconds = 10
// Needed for the temporary address sub test.
- savedMaxDesync := stack.MaxDesyncFactor
+ savedMaxDesync := ipv6.MaxDesyncFactor
defer func() {
- stack.MaxDesyncFactor = savedMaxDesync
+ ipv6.MaxDesyncFactor = savedMaxDesync
}()
- stack.MaxDesyncFactor = time.Nanosecond
+ ipv6.MaxDesyncFactor = time.Nanosecond
var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
secretKey := secretKeyBuf[:]
@@ -3938,14 +3952,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
addrTypes := []struct {
name string
- ndpConfigs stack.NDPConfigurations
+ ndpConfigs ipv6.NDPConfigurations
autoGenLinkLocal bool
prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix
}{
{
name: "Global address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
HandleRAs: true,
@@ -3963,7 +3977,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
},
{
name: "LinkLocal address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
},
@@ -3977,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
},
{
name: "Temporary address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
HandleRAs: true,
@@ -4029,16 +4043,17 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
ndpConfigs := addrType.ndpConfigs
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
},
- SecretKey: secretKey,
- },
+ })},
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -4059,9 +4074,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
// Simulate a DAD conflict.
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
expectDADEvent(t, &ndpDisp, addr.Address, false)
@@ -4119,14 +4132,14 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
addrTypes := []struct {
name string
- ndpConfigs stack.NDPConfigurations
+ ndpConfigs ipv6.NDPConfigurations
autoGenLinkLocal bool
subnet tcpip.Subnet
triggerSLAACFn func(e *channel.Endpoint)
}{
{
name: "Global address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
HandleRAs: true,
@@ -4142,7 +4155,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
},
{
name: "LinkLocal address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
AutoGenAddressConflictRetries: maxRetries,
@@ -4165,10 +4178,11 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: addrType.ndpConfigs,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -4198,9 +4212,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
expectAutoGenAddrEvent(addr, newAddr)
// Simulate a DAD conflict.
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
case e := <-ndpDisp.dadC:
@@ -4250,21 +4262,22 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenAddressConflictRetries: maxRetries,
- },
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
},
- SecretKey: secretKey,
- },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })},
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -4296,9 +4309,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
// Simulate a DAD conflict after some time has passed.
time.Sleep(failureTimer)
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
case e := <-ndpDisp.dadC:
@@ -4459,11 +4470,12 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -4509,11 +4521,12 @@ func TestNDPDNSSearchListDispatch(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -4640,7 +4653,7 @@ func TestCleanupNDPState(t *testing.T) {
name: "Enable forwarding",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
- s.SetForwarding(true)
+ s.SetForwarding(ipv6.ProtocolNumber, true)
},
keepAutoGenLinkLocal: true,
maxAutoGenAddrEvents: 4,
@@ -4694,15 +4707,16 @@ func TestCleanupNDPState(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: true,
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- DiscoverOnLinkPrefixes: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: true,
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
expectRouterEvent := func() (bool, ndpRouterEvent) {
@@ -4967,18 +4981,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ expectDHCPv6Event := func(configuration ipv6.DHCPv6ConfigurationFromNDPRA) {
t.Helper()
select {
case e := <-ndpDisp.dhcpv6ConfigurationC:
@@ -5002,7 +5017,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Even if the first RA reports no DHCPv6 configurations are available, the
// dispatcher should get an event.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ expectDHCPv6Event(ipv6.DHCPv6NoConfiguration)
// Receiving the same update again should not result in an event to the
// dispatcher.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
@@ -5011,19 +5026,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Receive an RA that updates the DHCPv6 configuration to Other
// Configurations.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
expectNoDHCPv6Event()
// Receive an RA that updates the DHCPv6 configuration to Managed Address.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
- expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ expectDHCPv6Event(ipv6.DHCPv6ManagedAddress)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
expectNoDHCPv6Event()
// Receive an RA that updates the DHCPv6 configuration to none.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ expectDHCPv6Event(ipv6.DHCPv6NoConfiguration)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
expectNoDHCPv6Event()
@@ -5031,7 +5046,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
//
// Note, when the M flag is set, the O flag is redundant.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
- expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ expectDHCPv6Event(ipv6.DHCPv6ManagedAddress)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
expectNoDHCPv6Event()
// Even though the DHCPv6 flags are different, the effective configuration is
@@ -5044,7 +5059,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Receive an RA that updates the DHCPv6 configuration to Other
// Configurations.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
expectNoDHCPv6Event()
@@ -5059,7 +5074,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Receive an RA that updates the DHCPv6 configuration to Other
// Configurations.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
expectNoDHCPv6Event()
}
@@ -5217,12 +5232,13 @@ func TestRouterSolicitation(t *testing.T) {
}
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
})
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -5286,11 +5302,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
name: "Enable and disable forwarding",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
- s.SetForwarding(false)
+ s.SetForwarding(ipv6.ProtocolNumber, false)
},
stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
- s.SetForwarding(true)
+ s.SetForwarding(ipv6.ProtocolNumber, true)
},
},
@@ -5357,12 +5373,13 @@ func TestStopStartSolicitingRouters(t *testing.T) {
checker.NDPRS())
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: interval,
- MaxRtrSolicitationDelay: delay,
- },
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index b4fa69e3e..a0b7da5cd 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -30,6 +30,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
const (
@@ -239,7 +240,7 @@ type entryEvent struct {
func TestNeighborCacheGetConfig(t *testing.T) {
nudDisp := testNUDDispatcher{}
c := DefaultNUDConfigurations()
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, c, clock)
if got, want := neigh.config(), c; got != want {
@@ -257,7 +258,7 @@ func TestNeighborCacheGetConfig(t *testing.T) {
func TestNeighborCacheSetConfig(t *testing.T) {
nudDisp := testNUDDispatcher{}
c := DefaultNUDConfigurations()
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, c, clock)
c.MinRandomFactor = 1
@@ -279,7 +280,7 @@ func TestNeighborCacheSetConfig(t *testing.T) {
func TestNeighborCacheEntry(t *testing.T) {
c := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, c, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -298,7 +299,7 @@ func TestNeighborCacheEntry(t *testing.T) {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
@@ -339,7 +340,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -358,7 +359,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
@@ -409,7 +410,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
type testContext struct {
- clock *fakeClock
+ clock *faketime.ManualClock
neigh *neighborCache
store *testEntryStore
linkRes *testNeighborResolver
@@ -418,7 +419,7 @@ type testContext struct {
func newTestContext(c NUDConfigurations) testContext {
nudDisp := &testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(nudDisp, c, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -454,7 +455,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock {
return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(c.neigh.config().RetransmitTimer)
+ c.clock.Advance(c.neigh.config().RetransmitTimer)
var wantEvents []testEntryEventInfo
@@ -567,7 +568,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(c.neigh.config().RetransmitTimer)
+ c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -803,7 +804,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(typicalLatency)
+ c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -876,7 +877,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -902,7 +903,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
if doneCh == nil {
t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
@@ -944,7 +945,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -974,7 +975,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
// Remove the waker before the neighbor cache has the opportunity to send a
// notification.
neigh.removeWaker(entry.Addr, &w)
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
@@ -1073,7 +1074,7 @@ func TestNeighborCacheClear(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1092,7 +1093,7 @@ func TestNeighborCacheClear(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
@@ -1188,7 +1189,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(typicalLatency)
+ c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1249,7 +1250,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
config.MaxRandomFactor = 1
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1277,7 +1278,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
@@ -1325,7 +1326,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
@@ -1412,7 +1413,7 @@ func TestNeighborCacheConcurrent(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1440,7 +1441,7 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wg.Wait()
// Process all the requests for a single entry concurrently
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
}
// All goroutines add in the same order and add more values than can fit in
@@ -1472,7 +1473,7 @@ func TestNeighborCacheReplace(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1491,7 +1492,7 @@ func TestNeighborCacheReplace(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
@@ -1541,7 +1542,7 @@ func TestNeighborCacheReplace(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(config.DelayFirstProbeTime + typicalLatency)
+ clock.Advance(config.DelayFirstProbeTime + typicalLatency)
select {
case <-doneCh:
default:
@@ -1552,7 +1553,7 @@ func TestNeighborCacheReplace(t *testing.T) {
// Verify the entry's new link address
{
e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
if err != nil {
t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
}
@@ -1572,7 +1573,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
@@ -1595,7 +1596,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
if err != nil {
t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
@@ -1618,7 +1619,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
}
@@ -1636,7 +1637,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
config := DefaultNUDConfigurations()
config.RetransmitTimer = time.Millisecond // small enough to cause timeout
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(nil, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1654,7 +1655,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
}
@@ -1664,7 +1665,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
// resolved immediately and don't send resolution requests.
func TestNeighborCacheStaticResolution(t *testing.T) {
config := DefaultNUDConfigurations()
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(nil, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 0068cacb8..9a72bec79 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
// NeighborEntry describes a neighboring device in the local network.
@@ -73,8 +74,7 @@ const (
type neighborEntry struct {
neighborEntryEntry
- nic *NIC
- protocol tcpip.NetworkProtocolNumber
+ nic *NIC
// linkRes provides the functionality to send reachability probes, used in
// Neighbor Unreachability Detection.
@@ -440,7 +440,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
e.notifyWakersLocked()
}
- if e.isRouter && !flags.IsRouter {
+ if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) {
// "In those cases where the IsRouter flag changes from TRUE to FALSE as
// a result of this update, the node MUST remove that router from the
// Default Router List and update the Destination Cache entries for all
@@ -448,9 +448,17 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
// 7.3.3. This is needed to detect when a node that is used as a router
// stops forwarding packets due to being configured as a host."
// - RFC 4861 section 7.2.5
- e.nic.mu.Lock()
- e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr)
- e.nic.mu.Unlock()
+ //
+ // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6
+ // here.
+ ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber]
+ if !ok {
+ panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint"))
+ }
+
+ if ndpEP, ok := ep.(NDPEndpoint); ok {
+ ndpEP.InvalidateDefaultRouter(e.neigh.Addr)
+ }
}
e.isRouter = flags.IsRouter
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index b769fb2fa..a265fff0a 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -27,6 +27,8 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
@@ -221,8 +223,8 @@ func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
return entryTestNetNumber
}
-func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) {
- clock := newFakeClock()
+func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) {
+ clock := faketime.NewManualClock()
disp := testNUDDispatcher{}
nic := NIC{
id: entryTestNICID,
@@ -232,18 +234,15 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
nudDisp: &disp,
},
}
+ nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
+ header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil),
+ }
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
nudState := NewNUDState(c, rng)
linkRes := entryTestLinkResolver{}
entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes)
- // Stub out ndpState to verify modification of default routers.
- nic.mu.ndp = ndpState{
- nic: &nic,
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- }
-
// Stub out the neighbor cache to verify deletion from the cache.
nic.neigh = &neighborCache{
nic: &nic,
@@ -267,7 +266,7 @@ func TestEntryInitiallyUnknown(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
// No probes should have been sent.
linkRes.mu.Lock()
@@ -300,7 +299,7 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(time.Hour)
+ clock.Advance(time.Hour)
// No probes should have been sent.
linkRes.mu.Lock()
@@ -410,7 +409,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
updatedAt := e.neigh.UpdatedAt
e.mu.Unlock()
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
// UpdatedAt should remain the same during address resolution.
wantProbes := []entryTestProbeInfo{
@@ -439,7 +438,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
// UpdatedAt should change after failing address resolution. Timing out after
// sending the last probe transitions the entry to Failed.
@@ -459,7 +458,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
}
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
wantEvents := []testEntryEventInfo{
{
@@ -748,7 +747,7 @@ func TestEntryIncompleteToFailed(t *testing.T) {
e.mu.Unlock()
waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
wantProbes := []entryTestProbeInfo{
// The Incomplete-to-Incomplete state transition is tested here by
@@ -816,6 +815,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, _ := entryTestSetup(c)
+ ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
+
e.mu.Lock()
e.handlePacketQueuedLocked()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
@@ -829,9 +830,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
if got, want := e.isRouter, true; got != want {
t.Errorf("got e.isRouter = %t, want = %t", got, want)
}
- e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{
- invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}),
- }
+
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -840,8 +839,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
if got, want := e.isRouter, false; got != want {
t.Errorf("got e.isRouter = %t, want = %t", got, want)
}
- if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok {
- t.Errorf("unexpected defaultRouter for %s", entryTestAddr1)
+ if ipv6EP.invalidatedRtr != e.neigh.Addr {
+ t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.neigh.Addr)
}
e.mu.Unlock()
@@ -983,7 +982,7 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -1612,7 +1611,7 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -1706,7 +1705,7 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -1989,7 +1988,7 @@ func TestEntryDelayToProbe(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2069,7 +2068,7 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2166,7 +2165,7 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2267,7 +2266,7 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2364,7 +2363,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// Probe caused by the Delay-to-Probe transition
@@ -2398,7 +2397,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -2463,7 +2462,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2503,7 +2502,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -2575,7 +2574,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2612,7 +2611,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
}
e.mu.Unlock()
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -2682,7 +2681,7 @@ func TestEntryProbeToFailed(t *testing.T) {
e.mu.Unlock()
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2787,7 +2786,7 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
e.mu.Unlock()
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
- clock.advance(waitFor)
+ clock.Advance(waitFor)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 863ef6bee..06824843a 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -18,7 +18,6 @@ import (
"fmt"
"math/rand"
"reflect"
- "sort"
"sync/atomic"
"gvisor.dev/gvisor/pkg/sleep"
@@ -28,13 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-var ipv4BroadcastAddr = tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: header.IPv4Broadcast,
- PrefixLen: 8 * header.IPv4AddressSize,
- },
-}
+var _ NetworkInterface = (*NIC)(nil)
// NIC represents a "network interface card" to which the networking stack is
// attached.
@@ -49,18 +42,18 @@ type NIC struct {
neigh *neighborCache
networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
+ // enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
mu struct {
sync.RWMutex
- enabled bool
spoofing bool
promiscuous bool
- primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
- endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
- mcastJoins map[NetworkEndpointID]uint32
// packetEPs is protected by mu, but the contained PacketEndpoint
// values are not.
packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
- ndp ndpState
}
}
@@ -84,25 +77,6 @@ type DirectionStats struct {
Bytes *tcpip.StatCounter
}
-// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior.
-type PrimaryEndpointBehavior int
-
-const (
- // CanBePrimaryEndpoint indicates the endpoint can be used as a primary
- // endpoint for new connections with no local address. This is the
- // default when calling NIC.AddAddress.
- CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
-
- // FirstPrimaryEndpoint indicates the endpoint should be the first
- // primary endpoint considered. If there are multiple endpoints with
- // this behavior, the most recently-added one will be first.
- FirstPrimaryEndpoint
-
- // NeverPrimaryEndpoint indicates the endpoint should never be a
- // primary endpoint.
- NeverPrimaryEndpoint
-)
-
// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
@@ -122,19 +96,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
- nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
- nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
- nic.mu.ndp = ndpState{
- nic: nic,
- configs: stack.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
- }
- nic.mu.ndp.initializeTempAddrState()
// Check for Neighbor Unreachability Detection support.
var nud NUDHandler
@@ -162,7 +124,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
nic.mu.packetEPs[netNum] = nil
- nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nud, nic, ep, stack)
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
nic.linkEP.Attach(nic)
@@ -170,29 +132,28 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
return nic
}
-// enabled returns true if n is enabled.
-func (n *NIC) enabled() bool {
- n.mu.RLock()
- enabled := n.mu.enabled
- n.mu.RUnlock()
- return enabled
+// Enabled implements NetworkInterface.
+func (n *NIC) Enabled() bool {
+ return atomic.LoadUint32(&n.enabled) == 1
}
-// disable disables n.
+// setEnabled sets the enabled status for the NIC.
//
-// It undoes the work done by enable.
-func (n *NIC) disable() *tcpip.Error {
- n.mu.RLock()
- enabled := n.mu.enabled
- n.mu.RUnlock()
- if !enabled {
- return nil
+// Returns true if the enabled status was updated.
+func (n *NIC) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&n.enabled, 1) == 0
}
+ return atomic.SwapUint32(&n.enabled, 0) == 1
+}
+// disable disables n.
+//
+// It undoes the work done by enable.
+func (n *NIC) disable() {
n.mu.Lock()
- err := n.disableLocked()
+ n.disableLocked()
n.mu.Unlock()
- return err
}
// disableLocked disables n.
@@ -200,9 +161,9 @@ func (n *NIC) disable() *tcpip.Error {
// It undoes the work done by enable.
//
// n MUST be locked.
-func (n *NIC) disableLocked() *tcpip.Error {
- if !n.mu.enabled {
- return nil
+func (n *NIC) disableLocked() {
+ if !n.setEnabled(false) {
+ return
}
// TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be
@@ -210,38 +171,9 @@ func (n *NIC) disableLocked() *tcpip.Error {
// again, and applications may not know that the underlying NIC was ever
// disabled.
- if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
- n.mu.ndp.stopSolicitingRouters()
- n.mu.ndp.cleanupState(false /* hostOnly */)
-
- // Stop DAD for all the unicast IPv6 endpoints that are in the
- // permanentTentative state.
- for _, r := range n.mu.endpoints {
- if addr := r.address(); r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
- n.mu.ndp.stopDuplicateAddressDetection(addr)
- }
- }
-
- // The NIC may have already left the multicast group.
- if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
- }
- }
-
- if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- // The NIC may have already left the multicast group.
- if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
- }
-
- // The address may have already been removed.
- if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
- }
+ for _, ep := range n.networkEndpoints {
+ ep.Disable()
}
-
- n.mu.enabled = false
- return nil
}
// enable enables n.
@@ -251,162 +183,39 @@ func (n *NIC) disableLocked() *tcpip.Error {
// routers if the stack is not operating as a router. If the stack is also
// configured to auto-generate a link-local address, one will be generated.
func (n *NIC) enable() *tcpip.Error {
- n.mu.RLock()
- enabled := n.mu.enabled
- n.mu.RUnlock()
- if enabled {
- return nil
- }
-
n.mu.Lock()
defer n.mu.Unlock()
- if n.mu.enabled {
- return nil
- }
-
- n.mu.enabled = true
-
- // Create an endpoint to receive broadcast packets on this interface.
- if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
- return err
- }
-
- // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
- // multicast group. Note, the IANA calls the all-hosts multicast group the
- // all-systems multicast group.
- if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil {
- return err
- }
- }
-
- // Join the IPv6 All-Nodes Multicast group if the stack is configured to
- // use IPv6. This is required to ensure that this node properly receives
- // and responds to the various NDP messages that are destined to the
- // all-nodes multicast address. An example is the Neighbor Advertisement
- // when we perform Duplicate Address Detection, or Router Advertisement
- // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
- // section 4.2 for more information.
- //
- // Also auto-generate an IPv6 link-local address based on the NIC's
- // link address if it is configured to do so. Note, each interface is
- // required to have IPv6 link-local unicast address, as per RFC 4291
- // section 2.1.
- _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]
- if !ok {
+ if !n.setEnabled(true) {
return nil
}
- // Join the All-Nodes multicast group before starting DAD as responses to DAD
- // (NDP NS) messages may be sent to the All-Nodes multicast group if the
- // source address of the NDP NS is the unspecified address, as per RFC 4861
- // section 7.2.4.
- if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
- return err
- }
-
- // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
- // state.
- //
- // Addresses may have aleady completed DAD but in the time since the NIC was
- // last enabled, other devices may have acquired the same addresses.
- for _, r := range n.mu.endpoints {
- addr := r.address()
- if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) {
- continue
- }
-
- r.setKind(permanentTentative)
- if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil {
+ for _, ep := range n.networkEndpoints {
+ if err := ep.Enable(); err != nil {
return err
}
}
- // Do not auto-generate an IPv6 link-local address for loopback devices.
- if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
- // The valid and preferred lifetime is infinite for the auto-generated
- // link-local address.
- n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
- }
-
- // If we are operating as a router, then do not solicit routers since we
- // won't process the RAs anyways.
- //
- // Routers do not process Router Advertisements (RA) the same way a host
- // does. That is, routers do not learn from RAs (e.g. on-link prefixes
- // and default routers). Therefore, soliciting RAs from other routers on
- // a link is unnecessary for routers.
- if !n.stack.forwarding {
- n.mu.ndp.startSolicitingRouters()
- }
-
return nil
}
-// remove detaches NIC from the link endpoint, and marks existing referenced
-// network endpoints expired. This guarantees no packets between this NIC and
-// the network stack.
+// remove detaches NIC from the link endpoint and releases network endpoint
+// resources. This guarantees no packets between this NIC and the network
+// stack.
func (n *NIC) remove() *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
n.disableLocked()
- // TODO(b/151378115): come up with a better way to pick an error than the
- // first one.
- var err *tcpip.Error
-
- // Forcefully leave multicast groups.
- for nid := range n.mu.mcastJoins {
- if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil {
- err = tempErr
- }
- }
-
- // Remove permanent and permanentTentative addresses, so no packet goes out.
- for nid, ref := range n.mu.endpoints {
- switch ref.getKind() {
- case permanentTentative, permanent:
- if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil {
- err = tempErr
- }
- }
- }
-
- // Release any resources the network endpoint may hold.
for _, ep := range n.networkEndpoints {
ep.Close()
}
+ n.networkEndpoints = nil
// Detach from link endpoint, so no packet comes in.
n.linkEP.Attach(nil)
-
- return err
-}
-
-// becomeIPv6Router transitions n into an IPv6 router.
-//
-// When transitioning into an IPv6 router, host-only state (NDP discovered
-// routers, discovered on-link prefixes, and auto-generated addresses) will
-// be cleaned up/invalidated and NDP router solicitations will be stopped.
-func (n *NIC) becomeIPv6Router() {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- n.mu.ndp.cleanupState(true /* hostOnly */)
- n.mu.ndp.stopSolicitingRouters()
-}
-
-// becomeIPv6Host transitions n into an IPv6 host.
-//
-// When transitioning into an IPv6 host, NDP router solicitations will be
-// started.
-func (n *NIC) becomeIPv6Host() {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- n.mu.ndp.startSolicitingRouters()
+ return nil
}
// setPromiscuousMode enables or disables promiscuous mode.
@@ -423,7 +232,8 @@ func (n *NIC) isPromiscuousMode() bool {
return rv
}
-func (n *NIC) isLoopback() bool {
+// IsLoopback implements NetworkInterface.
+func (n *NIC) IsLoopback() bool {
return n.linkEP.Capabilities()&CapabilityLoopback != 0
}
@@ -434,206 +244,44 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-// primaryEndpoint will return the first non-deprecated endpoint if such an
-// endpoint exists for the given protocol and remoteAddr. If no non-deprecated
-// endpoint exists, the first deprecated endpoint will be returned.
-//
-// If an IPv6 primary endpoint is requested, Source Address Selection (as
-// defined by RFC 6724 section 5) will be performed.
-func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint {
- if protocol == header.IPv6ProtocolNumber && len(remoteAddr) != 0 {
- return n.primaryIPv6Endpoint(remoteAddr)
- }
-
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- var deprecatedEndpoint *referencedNetworkEndpoint
- for _, r := range n.mu.primary[protocol] {
- if !r.isValidForOutgoingRLocked() {
- continue
- }
-
- if !r.deprecated {
- if r.tryIncRef() {
- // r is not deprecated, so return it immediately.
- //
- // If we kept track of a deprecated endpoint, decrement its reference
- // count since it was incremented when we decided to keep track of it.
- if deprecatedEndpoint != nil {
- deprecatedEndpoint.decRefLocked()
- deprecatedEndpoint = nil
- }
-
- return r
- }
- } else if deprecatedEndpoint == nil && r.tryIncRef() {
- // We prefer an endpoint that is not deprecated, but we keep track of r in
- // case n doesn't have any non-deprecated endpoints.
- //
- // If we end up finding a more preferred endpoint, r's reference count
- // will be decremented when such an endpoint is found.
- deprecatedEndpoint = r
- }
- }
-
- // n doesn't have any valid non-deprecated endpoints, so return
- // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated
- // endpoints either).
- return deprecatedEndpoint
-}
-
-// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC
-// 6724 section 5).
-type ipv6AddrCandidate struct {
- ref *referencedNetworkEndpoint
- scope header.IPv6AddressScope
-}
-
-// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address
-// Selection (RFC 6724 section 5).
-//
-// Note, only rules 1-3 and 7 are followed.
-//
-// remoteAddr must be a valid IPv6 address.
-func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
+// primaryAddress returns an address that can be used to communicate with
+// remoteAddr.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
n.mu.RLock()
- ref := n.primaryIPv6EndpointRLocked(remoteAddr)
+ spoofing := n.mu.spoofing
n.mu.RUnlock()
- return ref
-}
-
-// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address
-// Selection (RFC 6724 section 5).
-//
-// Note, only rules 1-3 and 7 are followed.
-//
-// remoteAddr must be a valid IPv6 address.
-//
-// n.mu MUST be read locked.
-func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
- primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber]
-
- if len(primaryAddrs) == 0 {
- return nil
- }
-
- // Create a candidate set of available addresses we can potentially use as a
- // source address.
- cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
- for _, r := range primaryAddrs {
- // If r is not valid for outgoing connections, it is not a valid endpoint.
- if !r.isValidForOutgoingRLocked() {
- continue
- }
-
- addr := r.address()
- scope, err := header.ScopeForIPv6Address(addr)
- if err != nil {
- // Should never happen as we got r from the primary IPv6 endpoint list and
- // ScopeForIPv6Address only returns an error if addr is not an IPv6
- // address.
- panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
- }
-
- cs = append(cs, ipv6AddrCandidate{
- ref: r,
- scope: scope,
- })
- }
-
- remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
- if err != nil {
- // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
- panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
- }
-
- // Sort the addresses as per RFC 6724 section 5 rules 1-3.
- //
- // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
- sort.Slice(cs, func(i, j int) bool {
- sa := cs[i]
- sb := cs[j]
-
- // Prefer same address as per RFC 6724 section 5 rule 1.
- if sa.ref.address() == remoteAddr {
- return true
- }
- if sb.ref.address() == remoteAddr {
- return false
- }
-
- // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
- if sa.scope < sb.scope {
- return sa.scope >= remoteScope
- } else if sb.scope < sa.scope {
- return sb.scope < remoteScope
- }
-
- // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
- if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep {
- // If sa is not deprecated, it is preferred over sb.
- return sbDep
- }
-
- // Prefer temporary addresses as per RFC 6724 section 5 rule 7.
- if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp {
- return saTemp
- }
-
- // sa and sb are equal, return the endpoint that is closest to the front of
- // the primary endpoint list.
- return i < j
- })
-
- // Return the most preferred address that can have its reference count
- // incremented.
- for _, c := range cs {
- if r := c.ref; r.tryIncRef() {
- return r
- }
- }
-
- return nil
-}
-
-// hasPermanentAddrLocked returns true if n has a permanent (including currently
-// tentative) address, addr.
-func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
- ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
+ ep, ok := n.networkEndpoints[protocol]
if !ok {
- return false
+ return nil
}
- kind := ref.getKind()
-
- return kind == permanent || kind == permanentTentative
+ return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
}
-type getRefBehaviour int
+type getAddressBehaviour int
const (
// spoofing indicates that the NIC's spoofing flag should be observed when
- // getting a NIC's referenced network endpoint.
- spoofing getRefBehaviour = iota
+ // getting a NIC's address endpoint.
+ spoofing getAddressBehaviour = iota
// promiscuous indicates that the NIC's promiscuous flag should be observed
- // when getting a NIC's referenced network endpoint.
+ // when getting a NIC's address endpoint.
promiscuous
)
-func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
+func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint {
+ return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
// findEndpoint finds the endpoint, if any, with the given address.
-func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, address, peb, spoofing)
+func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
+ return n.getAddressOrCreateTemp(protocol, address, peb, spoofing)
}
-// getRefEpOrCreateTemp returns the referenced network endpoint for the given
-// protocol and address.
+// getAddressEpOrCreateTemp returns the address endpoint for the given protocol
+// and address.
//
// If none exists a temporary one may be created if we are in promiscuous mode
// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
@@ -641,9 +289,8 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
//
// If the address is the IPv4 broadcast address for an endpoint's network, that
// endpoint will be returned.
-func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
+func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint {
n.mu.RLock()
-
var spoofingOrPromiscuous bool
switch tempRef {
case spoofing:
@@ -651,294 +298,54 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
case promiscuous:
spoofingOrPromiscuous = n.mu.promiscuous
}
-
- if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
- // An endpoint with this id exists, check if it can be used and return it.
- if !ref.isAssignedRLocked(spoofingOrPromiscuous) {
- n.mu.RUnlock()
- return nil
- }
-
- if ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
- }
- }
-
- // Check if address is a broadcast address for the endpoint's network.
- //
- // Only IPv4 has a notion of broadcast addresses.
- if protocol == header.IPv4ProtocolNumber {
- if ref := n.getRefForBroadcastRLocked(address); ref != nil {
- n.mu.RUnlock()
- return ref
- }
- }
-
- // A usable reference was not found, create a temporary one if requested by
- // the caller or if the IPv4 address is found in the NIC's subnets and the NIC
- // is a loopback interface.
- createTempEP := spoofingOrPromiscuous
- if !createTempEP && n.isLoopback() && protocol == header.IPv4ProtocolNumber {
- for _, r := range n.mu.endpoints {
- addr := r.addrWithPrefix()
- subnet := addr.Subnet()
- if subnet.Contains(address) {
- createTempEP = true
- break
- }
- }
- }
n.mu.RUnlock()
-
- if !createTempEP {
- return nil
- }
-
- // Try again with the lock in exclusive mode. If we still can't get the
- // endpoint, create a new "temporary" endpoint. It will only exist while
- // there's a route through it.
- n.mu.Lock()
- ref := n.getRefOrCreateTempLocked(protocol, address, peb)
- n.mu.Unlock()
- return ref
+ return n.getAddressOrCreateTempInner(protocol, address, spoofingOrPromiscuous, peb)
}
-// getRefForBroadcastLocked returns an endpoint where address is the IPv4
-// broadcast address for the endpoint's network.
-//
-// n.mu MUST be read locked.
-func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint {
- for _, ref := range n.mu.endpoints {
- // Only IPv4 has a notion of broadcast addresses.
- if ref.protocol != header.IPv4ProtocolNumber {
- continue
- }
-
- addr := ref.addrWithPrefix()
- subnet := addr.Subnet()
- if subnet.IsBroadcast(address) && ref.tryIncRef() {
- return ref
- }
+// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean
+// is passed to indicate whether or not we should generate temporary endpoints.
+func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
+ if ep, ok := n.networkEndpoints[protocol]; ok {
+ return ep.AcquireAssignedAddress(address, createTemp, peb)
}
return nil
}
-/// getRefOrCreateTempLocked returns an existing endpoint for address or creates
-/// and returns a temporary endpoint.
-//
-// If the address is the IPv4 broadcast address for an endpoint's network, that
-// endpoint will be returned.
-//
-// n.mu must be write locked.
-func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
- if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
- // No need to check the type as we are ok with expired endpoints at this
- // point.
- if ref.tryIncRef() {
- return ref
- }
- // tryIncRef failing means the endpoint is scheduled to be removed once the
- // lock is released. Remove it here so we can create a new (temporary) one.
- // The removal logic waiting for the lock handles this case.
- n.removeEndpointLocked(ref)
- }
-
- // Check if address is a broadcast address for an endpoint's network.
- //
- // Only IPv4 has a notion of broadcast addresses.
- if protocol == header.IPv4ProtocolNumber {
- if ref := n.getRefForBroadcastRLocked(address); ref != nil {
- return ref
- }
- }
-
- // Add a new temporary endpoint.
- netProto, ok := n.stack.networkProtocols[protocol]
- if !ok {
- return nil
- }
- ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, peb, temporary, static, false)
- return ref
-}
-
-// addAddressLocked adds a new protocolAddress to n.
-//
-// If n already has the address in a non-permanent state, and the kind given is
-// permanent, that address will be promoted in place and its properties set to
-// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress.
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
- // TODO(b/141022673): Validate IP addresses before adding them.
-
- // Sanity check.
- id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address}
- if ref, ok := n.mu.endpoints[id]; ok {
- // Endpoint already exists.
- if kind != permanent {
- return nil, tcpip.ErrDuplicateAddress
- }
- switch ref.getKind() {
- case permanentTentative, permanent:
- // The NIC already have a permanent endpoint with that address.
- return nil, tcpip.ErrDuplicateAddress
- case permanentExpired, temporary:
- // Promote the endpoint to become permanent and respect the new peb,
- // configType and deprecated status.
- if ref.tryIncRef() {
- // TODO(b/147748385): Perform Duplicate Address Detection when promoting
- // an IPv6 endpoint to permanent.
- ref.setKind(permanent)
- ref.deprecated = deprecated
- ref.configType = configType
-
- refs := n.mu.primary[ref.protocol]
- for i, r := range refs {
- if r == ref {
- switch peb {
- case CanBePrimaryEndpoint:
- return ref, nil
- case FirstPrimaryEndpoint:
- if i == 0 {
- return ref, nil
- }
- n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
- case NeverPrimaryEndpoint:
- n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
- return ref, nil
- }
- }
- }
-
- n.insertPrimaryEndpointLocked(ref, peb)
-
- return ref, nil
- }
- // tryIncRef failing means the endpoint is scheduled to be removed once
- // the lock is released. Remove it here so we can create a new
- // (permanent) one. The removal logic waiting for the lock handles this
- // case.
- n.removeEndpointLocked(ref)
- }
- }
-
+// addAddress adds a new address to n, so that it starts accepting packets
+// targeted at the given address (and network protocol).
+func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
- return nil, tcpip.ErrUnknownProtocol
- }
-
- isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
-
- // If the address is an IPv6 address and it is a permanent address,
- // mark it as tentative so it goes through the DAD process if the NIC is
- // enabled. If the NIC is not enabled, DAD will be started when the NIC is
- // enabled.
- if isIPv6Unicast && kind == permanent {
- kind = permanentTentative
- }
-
- ref := &referencedNetworkEndpoint{
- refs: 1,
- addr: protocolAddress.AddressWithPrefix,
- ep: ep,
- nic: n,
- protocol: protocolAddress.Protocol,
- kind: kind,
- configType: configType,
- deprecated: deprecated,
- }
-
- // Set up resolver if link address resolution exists for this protocol.
- if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok {
- ref.linkCache = n.stack
- ref.linkRes = linkRes
- }
- }
-
- // If we are adding an IPv6 unicast address, join the solicited-node
- // multicast address.
- if isIPv6Unicast {
- snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address)
- if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil {
- return nil, err
- }
+ return tcpip.ErrUnknownProtocol
}
- n.mu.endpoints[id] = ref
-
- n.insertPrimaryEndpointLocked(ref, peb)
-
- // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled.
- if isIPv6Unicast && kind == permanentTentative && n.mu.enabled {
- if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
- return nil, err
- }
+ addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ if err == nil {
+ // We have no need for the address endpoint.
+ addressEndpoint.DecRef()
}
-
- return ref, nil
-}
-
-// AddAddress adds a new address to n, so that it starts accepting packets
-// targeted at the given address (and network protocol).
-func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
- // Add the endpoint.
- n.mu.Lock()
- _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */)
- n.mu.Unlock()
-
return err
}
-// AllAddresses returns all addresses (primary and non-primary) associated with
+// allPermanentAddresses returns all permanent addresses associated with
// this NIC.
-func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints))
- for _, ref := range n.mu.endpoints {
- // Don't include tentative, expired or temporary endpoints to
- // avoid confusion and prevent the caller from using those.
- switch ref.getKind() {
- case permanentExpired, temporary:
- continue
+func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
+ var addrs []tcpip.ProtocolAddress
+ for p, ep := range n.networkEndpoints {
+ for _, a := range ep.PermanentAddresses() {
+ addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
-
- addrs = append(addrs, tcpip.ProtocolAddress{
- Protocol: ref.protocol,
- AddressWithPrefix: ref.addrWithPrefix(),
- })
}
return addrs
}
-// PrimaryAddresses returns the primary addresses associated with this NIC.
-func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
+// primaryAddresses returns the primary addresses associated with this NIC.
+func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
- for proto, list := range n.mu.primary {
- for _, ref := range list {
- // Don't include tentative, expired or tempory endpoints
- // to avoid confusion and prevent the caller from using
- // those.
- switch ref.getKind() {
- case permanentTentative, permanentExpired, temporary:
- continue
- }
-
- addrs = append(addrs, tcpip.ProtocolAddress{
- Protocol: proto,
- AddressWithPrefix: ref.addrWithPrefix(),
- })
+ for p, ep := range n.networkEndpoints {
+ for _, a := range ep.PrimaryAddresses() {
+ addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
}
return addrs
@@ -950,147 +357,25 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
// address exists. If no non-deprecated address exists, the first deprecated
// address will be returned.
func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- list, ok := n.mu.primary[proto]
+ ep, ok := n.networkEndpoints[proto]
if !ok {
return tcpip.AddressWithPrefix{}
}
- var deprecatedEndpoint *referencedNetworkEndpoint
- for _, ref := range list {
- // Don't include tentative, expired or tempory endpoints to avoid confusion
- // and prevent the caller from using those.
- switch ref.getKind() {
- case permanentTentative, permanentExpired, temporary:
- continue
- }
-
- if !ref.deprecated {
- return ref.addrWithPrefix()
- }
-
- if deprecatedEndpoint == nil {
- deprecatedEndpoint = ref
- }
- }
-
- if deprecatedEndpoint != nil {
- return deprecatedEndpoint.addrWithPrefix()
- }
-
- return tcpip.AddressWithPrefix{}
-}
-
-// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required
-// by peb.
-//
-// n MUST be locked.
-func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) {
- switch peb {
- case CanBePrimaryEndpoint:
- n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r)
- case FirstPrimaryEndpoint:
- n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...)
- }
-}
-
-func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
- id := NetworkEndpointID{LocalAddress: r.address()}
-
- // Nothing to do if the reference has already been replaced with a different
- // one. This happens in the case where 1) this endpoint's ref count hit zero
- // and was waiting (on the lock) to be removed and 2) the same address was
- // re-added in the meantime by removing this endpoint from the list and
- // adding a new one.
- if n.mu.endpoints[id] != r {
- return
- }
-
- if r.getKind() == permanent {
- panic("Reference count dropped to zero before being removed")
- }
-
- delete(n.mu.endpoints, id)
- refs := n.mu.primary[r.protocol]
- for i, ref := range refs {
- if ref == r {
- n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
- refs[len(refs)-1] = nil
- break
- }
- }
-}
-
-func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
- n.mu.Lock()
- n.removeEndpointLocked(r)
- n.mu.Unlock()
-}
-
-func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
- r, ok := n.mu.endpoints[NetworkEndpointID{addr}]
- if !ok {
- return tcpip.ErrBadLocalAddress
- }
-
- kind := r.getKind()
- if kind != permanent && kind != permanentTentative {
- return tcpip.ErrBadLocalAddress
- }
-
- switch r.protocol {
- case header.IPv6ProtocolNumber:
- return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */)
- default:
- r.expireLocked()
- return nil
- }
+ return ep.MainAddress()
}
-func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
- addr := r.addrWithPrefix()
-
- isIPv6Unicast := header.IsV6UnicastAddress(addr.Address)
-
- if isIPv6Unicast {
- n.mu.ndp.stopDuplicateAddressDetection(addr.Address)
-
- // If we are removing an address generated via SLAAC, cleanup
- // its SLAAC resources and notify the integrator.
- switch r.configType {
- case slaac:
- n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
- case slaacTemp:
- n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
- }
- }
-
- r.expireLocked()
-
- // At this point the endpoint is deleted.
-
- // If we are removing an IPv6 unicast address, leave the solicited-node
- // multicast address.
- //
- // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node
- // multicast group may be left by user action.
- if isIPv6Unicast {
- snmc := header.SolicitedNodeAddr(addr.Address)
- if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
+// removeAddress removes an address from n.
+func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error {
+ for _, ep := range n.networkEndpoints {
+ if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
+ continue
+ } else {
return err
}
}
- return nil
-}
-
-// RemoveAddress removes an address from n.
-func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
- return n.removePermanentAddressLocked(addr)
+ return tcpip.ErrBadLocalAddress
}
func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
@@ -1141,91 +426,66 @@ func (n *NIC) clearNeighbors() *tcpip.Error {
// joinGroup adds a new endpoint for the given multicast address, if none
// exists yet. Otherwise it just increments its count.
func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- return n.joinGroupLocked(protocol, addr)
-}
-
-// joinGroupLocked adds a new endpoint for the given multicast address, if none
-// exists yet. Otherwise it just increments its count. n MUST be locked before
-// joinGroupLocked is called.
-func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
// TODO(b/143102137): When implementing MLD, make sure MLD packets are
// not sent unless a valid link-local address is available for use on n
// as an MLD packet's source address must be a link-local address as
// outlined in RFC 3810 section 5.
- id := NetworkEndpointID{addr}
- joins := n.mu.mcastJoins[id]
- if joins == 0 {
- netProto, ok := n.stack.networkProtocols[protocol]
- if !ok {
- return tcpip.ErrUnknownProtocol
- }
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
- return err
- }
+ ep, ok := n.networkEndpoints[protocol]
+ if !ok {
+ return tcpip.ErrNotSupported
}
- n.mu.mcastJoins[id] = joins + 1
- return nil
+
+ gep, ok := ep.(GroupAddressableEndpoint)
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+
+ _, err := gep.JoinGroup(addr)
+ return err
}
// leaveGroup decrements the count for the given multicast address, and when it
// reaches zero removes the endpoint for this address.
-func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- return n.leaveGroupLocked(addr, false /* force */)
-}
+func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ ep, ok := n.networkEndpoints[protocol]
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
-// leaveGroupLocked decrements the count for the given multicast address, and
-// when it reaches zero removes the endpoint for this address. n MUST be locked
-// before leaveGroupLocked is called.
-//
-// If force is true, then the count for the multicast addres is ignored and the
-// endpoint will be removed immediately.
-func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error {
- id := NetworkEndpointID{addr}
- joins, ok := n.mu.mcastJoins[id]
+ gep, ok := ep.(GroupAddressableEndpoint)
if !ok {
- // There are no joins with this address on this NIC.
- return tcpip.ErrBadLocalAddress
+ return tcpip.ErrNotSupported
}
- joins--
- if force || joins == 0 {
- // There are no outstanding joins or we are forced to leave, clean up.
- delete(n.mu.mcastJoins, id)
- return n.removePermanentAddressLocked(addr)
+ if _, err := gep.LeaveGroup(addr); err != nil {
+ return err
}
- n.mu.mcastJoins[id] = joins
return nil
}
// isInGroup returns true if n has joined the multicast group addr.
func (n *NIC) isInGroup(addr tcpip.Address) bool {
- n.mu.RLock()
- joins := n.mu.mcastJoins[NetworkEndpointID{addr}]
- n.mu.RUnlock()
+ for _, ep := range n.networkEndpoints {
+ gep, ok := ep.(GroupAddressableEndpoint)
+ if !ok {
+ continue
+ }
- return joins != 0
+ if gep.IsInGroup(addr) {
+ return true
+ }
+ }
+
+ return false
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) {
- r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
+func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
+ r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
-
- ref.ep.HandlePacket(&r, pkt)
- ref.decRef()
+ addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt)
+ addressEndpoint.DecRef()
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
@@ -1236,7 +496,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// the ownership of the items is not retained by the caller.
func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
n.mu.RLock()
- enabled := n.mu.enabled
+ enabled := n.Enabled()
// If the NIC is not yet enabled, don't receive any packets.
if !enabled {
n.mu.RUnlock()
@@ -1262,9 +522,9 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
local = n.linkEP.LinkAddress()
}
- // Are any packet sockets listening for this network protocol?
+ // Are any packet type sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Add any other packet sockets that maybe listening for all protocols.
+ // Add any other packet type sockets that may be listening for all protocols.
packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
@@ -1285,6 +545,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
return
}
if hasTransportHdr {
+ pkt.TransportProtocolNumber = transProtoNum
// Parse the transport header if present.
if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
state.proto.Parse(pkt)
@@ -1293,29 +554,33 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
- if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
- // The source address is one of our own, so we never should have gotten a
- // packet like this unless handleLocal is false. Loopback also calls this
- // function even though the packets didn't come from the physical interface
- // so don't drop those.
- n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
- return
+ if n.stack.handleLocal && !n.IsLoopback() {
+ if r := n.getAddress(protocol, src); r != nil {
+ r.DecRef()
+
+ // The source address is one of our own, so we never should have gotten a
+ // packet like this unless handleLocal is false. Loopback also calls this
+ // function even though the packets didn't come from the physical interface
+ // so don't drop those.
+ n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
}
- // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
// Loopback traffic skips the prerouting chain.
- if protocol == header.IPv4ProtocolNumber && !n.isLoopback() {
+ if !n.IsLoopback() {
// iptables filtering.
ipt := n.stack.IPTables()
address := n.primaryAddress(protocol)
if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
// iptables is telling us to drop the packet.
+ n.stack.stats.IP.IPTablesPreroutingDropped.Increment()
return
}
}
- if ref := n.getRef(protocol, dst); ref != nil {
- handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt)
+ if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil {
+ n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt)
return
}
@@ -1323,7 +588,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// packet and forward it to the NIC.
//
// TODO: Should we be forwarding the packet even if promiscuous?
- if n.stack.Forwarding() {
+ if n.stack.Forwarding(protocol) {
r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
if err != nil {
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
@@ -1331,25 +596,26 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
// Found a NIC.
- n := r.ref.nic
- n.mu.RLock()
- ref, ok := n.mu.endpoints[NetworkEndpointID{dst}]
- ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef()
- n.mu.RUnlock()
- if ok {
- r.LocalLinkAddress = n.linkEP.LinkAddress()
- r.RemoteLinkAddress = remote
- r.RemoteAddress = src
- // TODO(b/123449044): Update the source NIC as well.
- ref.ep.HandlePacket(&r, pkt)
- ref.decRef()
- r.Release()
- return
+ n := r.nic
+ if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
+ if n.isValidForOutgoing(addressEndpoint) {
+ r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.RemoteLinkAddress = remote
+ r.RemoteAddress = src
+ // TODO(b/123449044): Update the source NIC as well.
+ addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt)
+ addressEndpoint.DecRef()
+ r.Release()
+ return
+ }
+
+ addressEndpoint.DecRef()
}
// n doesn't have a destination endpoint.
// Send the packet out of n.
// TODO(b/128629022): move this logic to route.WritePacket.
+ // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
@@ -1416,11 +682,11 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
- return
+ return TransportPacketProtocolUnreachable
}
transProto := state.proto
@@ -1441,41 +707,47 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// we parse it using the minimum size.
if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok {
n.stack.stats.MalformedRcvdPackets.Increment()
- return
+ // We consider a malformed transport packet handled because there is
+ // nothing the caller can do.
+ return TransportPacketHandled
}
- } else {
- // This is either a bad packet or was re-assembled from fragments.
- transProto.Parse(pkt)
+ } else if !transProto.Parse(pkt) {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return TransportPacketHandled
}
}
- if pkt.TransportHeader().View().Size() < transProto.MinimumPacketSize() {
- n.stack.stats.MalformedRcvdPackets.Increment()
- return
- }
-
srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View())
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
- return
+ return TransportPacketHandled
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
if n.stack.demux.deliverPacket(r, protocol, pkt, id) {
- return
+ return TransportPacketHandled
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
if state.defaultHandler(r, id, pkt) {
- return
+ return TransportPacketHandled
}
}
- // We could not find an appropriate destination for this packet, so
- // deliver it to the global handler.
- if !transProto.HandleUnknownDestinationPacket(r, id, pkt) {
+ // We could not find an appropriate destination for this packet so
+ // give the protocol specific error handler a chance to handle it.
+ // If it doesn't handle it then we should do so.
+ switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res {
+ case UnknownDestinationPacketMalformed:
n.stack.stats.MalformedRcvdPackets.Increment()
+ return TransportPacketHandled
+ case UnknownDestinationPacketUnhandled:
+ return TransportPacketDestinationPortUnreachable
+ case UnknownDestinationPacketHandled:
+ return TransportPacketHandled
+ default:
+ panic(fmt.Sprintf("unrecognized result from HandleUnknownDestinationPacket = %d", res))
}
}
@@ -1508,96 +780,23 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
}
}
-// ID returns the identifier of n.
+// ID implements NetworkInterface.
func (n *NIC) ID() tcpip.NICID {
return n.id
}
-// Name returns the name of n.
+// Name implements NetworkInterface.
func (n *NIC) Name() string {
return n.name
}
-// Stack returns the instance of the Stack that owns this NIC.
-func (n *NIC) Stack() *Stack {
- return n.stack
-}
-
-// LinkEndpoint returns the link endpoint of n.
+// LinkEndpoint implements NetworkInterface.
func (n *NIC) LinkEndpoint() LinkEndpoint {
return n.linkEP
}
-// isAddrTentative returns true if addr is tentative on n.
-//
-// Note that if addr is not associated with n, then this function will return
-// false. It will only return true if the address is associated with the NIC
-// AND it is tentative.
-func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
- if !ok {
- return false
- }
-
- return ref.getKind() == permanentTentative
-}
-
-// dupTentativeAddrDetected attempts to inform n that a tentative addr is a
-// duplicate on a link.
-//
-// dupTentativeAddrDetected will remove the tentative address if it exists. If
-// the address was generated via SLAAC, an attempt will be made to generate a
-// new address.
-func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
- if !ok {
- return tcpip.ErrBadAddress
- }
-
- if ref.getKind() != permanentTentative {
- return tcpip.ErrInvalidEndpointState
- }
-
- // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a
- // new address will be generated for it.
- if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil {
- return err
- }
-
- prefix := ref.addrWithPrefix().Subnet()
-
- switch ref.configType {
- case slaac:
- n.mu.ndp.regenerateSLAACAddr(prefix)
- case slaacTemp:
- // Do not reset the generation attempts counter for the prefix as the
- // temporary address is being regenerated in response to a DAD conflict.
- n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
- }
-
- return nil
-}
-
-// setNDPConfigs sets the NDP configurations for n.
-//
-// Note, if c contains invalid NDP configuration values, it will be fixed to
-// use default values for the erroneous values.
-func (n *NIC) setNDPConfigs(c NDPConfigurations) {
- c.validate()
-
- n.mu.Lock()
- n.mu.ndp.configs = c
- n.mu.Unlock()
-}
-
-// NUDConfigs gets the NUD configurations for n.
-func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) {
+// nudConfigs gets the NUD configurations for n.
+func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
if n.neigh == nil {
return NUDConfigurations{}, tcpip.ErrNotSupported
}
@@ -1617,49 +816,6 @@ func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error {
return nil
}
-// handleNDPRA handles an NDP Router Advertisement message that arrived on n.
-func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- n.mu.ndp.handleRA(ip, ra)
-}
-
-type networkEndpointKind int32
-
-const (
- // A permanentTentative endpoint is a permanent address that is not yet
- // considered to be fully bound to an interface in the traditional
- // sense. That is, the address is associated with a NIC, but packets
- // destined to the address MUST NOT be accepted and MUST be silently
- // dropped, and the address MUST NOT be used as a source address for
- // outgoing packets. For IPv6, addresses will be of this kind until
- // NDP's Duplicate Address Detection has resolved, or be deleted if
- // the process results in detecting a duplicate address.
- permanentTentative networkEndpointKind = iota
-
- // A permanent endpoint is created by adding a permanent address (vs. a
- // temporary one) to the NIC. Its reference count is biased by 1 to avoid
- // removal when no route holds a reference to it. It is removed by explicitly
- // removing the permanent address from the NIC.
- permanent
-
- // An expired permanent endpoint is a permanent endpoint that had its address
- // removed from the NIC, and it is waiting to be removed once no more routes
- // hold a reference to it. This is achieved by decreasing its reference count
- // by 1. If its address is re-added before the endpoint is removed, its type
- // changes back to permanent and its reference count increases by 1 again.
- permanentExpired
-
- // A temporary endpoint is created for spoofing outgoing packets, or when in
- // promiscuous mode and accepting incoming packets that don't match any
- // permanent endpoint. Its reference count is not biased by 1 and the
- // endpoint is removed immediately when no more route holds a reference to
- // it. A temporary endpoint can be promoted to permanent if its address
- // is added permanently.
- temporary
-)
-
func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -1690,153 +846,12 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
}
}
-type networkEndpointConfigType int32
-
-const (
- // A statically configured endpoint is an address that was added by
- // some user-specified action (adding an explicit address, joining a
- // multicast group).
- static networkEndpointConfigType = iota
-
- // A SLAAC configured endpoint is an IPv6 endpoint that was added by
- // SLAAC as per RFC 4862 section 5.5.3.
- slaac
-
- // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by
- // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are
- // not expected to be valid (or preferred) forever; hence the term temporary.
- slaacTemp
-)
-
-type referencedNetworkEndpoint struct {
- ep NetworkEndpoint
- addr tcpip.AddressWithPrefix
- nic *NIC
- protocol tcpip.NetworkProtocolNumber
-
- // linkCache is set if link address resolution is enabled for this
- // protocol. Set to nil otherwise.
- linkCache LinkAddressCache
-
- // linkRes is set if link address resolution is enabled for this protocol.
- // Set to nil otherwise.
- linkRes LinkAddressResolver
-
- // refs is counting references held for this endpoint. When refs hits zero it
- // triggers the automatic removal of the endpoint from the NIC.
- refs int32
-
- // networkEndpointKind must only be accessed using {get,set}Kind().
- kind networkEndpointKind
-
- // configType is the method that was used to configure this endpoint.
- // This must never change except during endpoint creation and promotion to
- // permanent.
- configType networkEndpointConfigType
-
- // deprecated indicates whether or not the endpoint should be considered
- // deprecated. That is, when deprecated is true, other endpoints that are not
- // deprecated should be preferred.
- deprecated bool
-}
-
-func (r *referencedNetworkEndpoint) address() tcpip.Address {
- return r.addr.Address
-}
-
-func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix {
- return r.addr
-}
-
-func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
- return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind)))
-}
-
-func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
- atomic.StoreInt32((*int32)(&r.kind), int32(kind))
-}
-
// isValidForOutgoing returns true if the endpoint can be used to send out a
// packet. It requires the endpoint to not be marked expired (i.e., its address)
// has been removed) unless the NIC is in spoofing mode, or temporary.
-func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
- r.nic.mu.RLock()
- defer r.nic.mu.RUnlock()
-
- return r.isValidForOutgoingRLocked()
-}
-
-// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires
-// r.nic.mu to be read locked.
-func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
- if !r.nic.mu.enabled {
- return false
- }
-
- return r.isAssignedRLocked(r.nic.mu.spoofing)
-}
-
-// isAssignedRLocked returns true if r is considered to be assigned to the NIC.
-//
-// r.nic.mu must be read locked.
-func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool {
- switch r.getKind() {
- case permanentTentative:
- return false
- case permanentExpired:
- return spoofingOrPromiscuous
- default:
- return true
- }
-}
-
-// expireLocked decrements the reference count and marks the permanent endpoint
-// as expired.
-func (r *referencedNetworkEndpoint) expireLocked() {
- r.setKind(permanentExpired)
- r.decRefLocked()
-}
-
-// decRef decrements the ref count and cleans up the endpoint once it reaches
-// zero.
-func (r *referencedNetworkEndpoint) decRef() {
- if atomic.AddInt32(&r.refs, -1) == 0 {
- r.nic.removeEndpoint(r)
- }
-}
-
-// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
-// locked.
-func (r *referencedNetworkEndpoint) decRefLocked() {
- if atomic.AddInt32(&r.refs, -1) == 0 {
- r.nic.removeEndpointLocked(r)
- }
-}
-
-// incRef increments the ref count. It must only be called when the caller is
-// known to be holding a reference to the endpoint, otherwise tryIncRef should
-// be used.
-func (r *referencedNetworkEndpoint) incRef() {
- atomic.AddInt32(&r.refs, 1)
-}
-
-// tryIncRef attempts to increment the ref count from n to n+1, but only if n is
-// not zero. That is, it will increment the count if the endpoint is still
-// alive, and do nothing if it has already been clean up.
-func (r *referencedNetworkEndpoint) tryIncRef() bool {
- for {
- v := atomic.LoadInt32(&r.refs)
- if v == 0 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
- return true
- }
- }
-}
-
-// stack returns the Stack instance that owns the underlying endpoint.
-func (r *referencedNetworkEndpoint) stack() *Stack {
- return r.nic.stack
+func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
+ n.mu.RLock()
+ spoofing := n.mu.spoofing
+ n.mu.RUnlock()
+ return n.Enabled() && ep.IsAssigned(spoofing)
}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index dd6474297..fdd49b77f 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -15,96 +15,40 @@
package stack
import (
- "math"
"testing"
- "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-var _ LinkEndpoint = (*testLinkEndpoint)(nil)
+var _ AddressableEndpoint = (*testIPv6Endpoint)(nil)
+var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
+var _ NDPEndpoint = (*testIPv6Endpoint)(nil)
-// A LinkEndpoint that throws away outgoing packets.
+// An IPv6 NetworkEndpoint that throws away outgoing packets.
//
-// We use this instead of the channel endpoint as the channel package depends on
+// We use this instead of ipv6.endpoint because the ipv6 package depends on
// the stack package which this test lives in, causing a cyclic dependency.
-type testLinkEndpoint struct {
- dispatcher NetworkDispatcher
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
- e.dispatcher = dispatcher
-}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (e *testLinkEndpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (*testLinkEndpoint) MTU() uint32 {
- return math.MaxUint16
-}
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities {
- return CapabilityResolutionRequired
-}
+type testIPv6Endpoint struct {
+ AddressableEndpointState
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*testLinkEndpoint) MaxHeaderLength() uint16 {
- return 0
-}
+ nicID tcpip.NICID
+ linkEP LinkEndpoint
+ protocol *testIPv6Protocol
-// LinkAddress returns the link address of this endpoint.
-func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress {
- return ""
+ invalidatedRtr tcpip.Address
}
-// Wait implements LinkEndpoint.Wait.
-func (*testLinkEndpoint) Wait() {}
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error {
+func (*testIPv6Endpoint) Enable() *tcpip.Error {
return nil
}
-// WritePackets implements LinkEndpoint.WritePackets.
-func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // Our tests don't use this so we don't support it.
- return 0, tcpip.ErrNotSupported
-}
-
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- // Our tests don't use this so we don't support it.
- return tcpip.ErrNotSupported
-}
-
-// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
-func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
- panic("not implemented")
-}
-
-// AddHeader implements stack.LinkEndpoint.AddHeader.
-func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- panic("not implemented")
+func (*testIPv6Endpoint) Enabled() bool {
+ return true
}
-var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
-
-// An IPv6 NetworkEndpoint that throws away outgoing packets.
-//
-// We use this instead of ipv6.endpoint because the ipv6 package depends on
-// the stack package which this test lives in, causing a cyclic dependency.
-type testIPv6Endpoint struct {
- nicID tcpip.NICID
- linkEP LinkEndpoint
- protocol *testIPv6Protocol
-}
+func (*testIPv6Endpoint) Disable() {}
// DefaultTTL implements NetworkEndpoint.DefaultTTL.
func (*testIPv6Endpoint) DefaultTTL() uint8 {
@@ -116,11 +60,6 @@ func (e *testIPv6Endpoint) MTU() uint32 {
return e.linkEP.MTU() - header.IPv6MinimumSize
}
-// Capabilities implements NetworkEndpoint.Capabilities.
-func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
@@ -144,23 +83,24 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip
return tcpip.ErrNotSupported
}
-// NICID implements NetworkEndpoint.NICID.
-func (e *testIPv6Endpoint) NICID() tcpip.NICID {
- return e.nicID
-}
-
// HandlePacket implements NetworkEndpoint.HandlePacket.
func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
}
// Close implements NetworkEndpoint.Close.
-func (*testIPv6Endpoint) Close() {}
+func (e *testIPv6Endpoint) Close() {
+ e.AddressableEndpointState.Cleanup()
+}
// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber.
func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return header.IPv6ProtocolNumber
}
+func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
+ e.invalidatedRtr = rtr
+}
+
var _ NetworkProtocol = (*testIPv6Protocol)(nil)
// An IPv6 NetworkProtocol that supports the bare minimum to make a stack
@@ -192,12 +132,14 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
}
// NewEndpoint implements NetworkProtocol.NewEndpoint.
-func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint {
- return &testIPv6Endpoint{
- nicID: nicID,
- linkEP: linkEP,
+func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint {
+ e := &testIPv6Endpoint{
+ nicID: nic.ID(),
+ linkEP: nic.LinkEndpoint(),
protocol: p,
}
+ e.AddressableEndpointState.Init(e)
+ return e
}
// SetOption implements NetworkProtocol.SetOption.
@@ -241,38 +183,6 @@ func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAdd
return "", false
}
-// Test the race condition where a NIC is removed and an RS timer fires at the
-// same time.
-func TestRemoveNICWhileHandlingRSTimer(t *testing.T) {
- const (
- nicID = 1
-
- maxRtrSolicitations = 5
- )
-
- e := testLinkEndpoint{}
- s := New(Options{
- NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}},
- NDPConfigs: NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: minimumRtrSolicitationInterval,
- },
- })
-
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err)
- }
-
- s.mu.Lock()
- // Wait for the router solicitation timer to fire and block trying to obtain
- // the stack lock when doing link address resolution.
- time.Sleep(minimumRtrSolicitationInterval * 2)
- if err := s.removeNICLocked(nicID); err != nil {
- t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err)
- }
- s.mu.Unlock()
-}
-
func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
// When the NIC is disabled, the only field that matters is the stats field.
// This test is limited to stats counter checks.
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index 2b97e5972..8cffb9fc6 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -60,7 +60,7 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) {
// A neighbor cache is required to store NUDConfigurations. The networking
// stack will only allocate neighbor caches if a protocol providing link
// address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
UseNeighborCache: true,
})
@@ -137,7 +137,7 @@ func TestDefaultNUDConfigurations(t *testing.T) {
// A neighbor cache is required to store NUDConfigurations. The networking
// stack will only allocate neighbor caches if a protocol providing link
// address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: stack.DefaultNUDConfigurations(),
UseNeighborCache: true,
})
@@ -192,7 +192,7 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) {
// A neighbor cache is required to store NUDConfigurations. The
// networking stack will only allocate neighbor caches if a protocol
// providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
UseNeighborCache: true,
})
@@ -249,7 +249,7 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) {
// A neighbor cache is required to store NUDConfigurations. The
// networking stack will only allocate neighbor caches if a protocol
// providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
UseNeighborCache: true,
})
@@ -329,7 +329,7 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) {
// A neighbor cache is required to store NUDConfigurations. The
// networking stack will only allocate neighbor caches if a protocol
// providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
UseNeighborCache: true,
})
@@ -391,7 +391,7 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) {
// A neighbor cache is required to store NUDConfigurations. The
// networking stack will only allocate neighbor caches if a protocol
// providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
UseNeighborCache: true,
})
@@ -443,7 +443,7 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) {
// A neighbor cache is required to store NUDConfigurations. The
// networking stack will only allocate neighbor caches if a protocol
// providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
UseNeighborCache: true,
})
@@ -495,7 +495,7 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) {
// A neighbor cache is required to store NUDConfigurations. The
// networking stack will only allocate neighbor caches if a protocol
// providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
UseNeighborCache: true,
})
@@ -547,7 +547,7 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
// A neighbor cache is required to store NUDConfigurations. The
// networking stack will only allocate neighbor caches if a protocol
// providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
UseNeighborCache: true,
})
@@ -599,7 +599,7 @@ func TestNUDConfigurationsUnreachableTime(t *testing.T) {
// A neighbor cache is required to store NUDConfigurations. The
// networking stack will only allocate neighbor caches if a protocol
// providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
UseNeighborCache: true,
})
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 17b8beebb..105583c49 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
type headerType int
@@ -80,11 +81,17 @@ type PacketBuffer struct {
// data are held in the same underlying buffer storage.
header buffer.Prependable
- // NetworkProtocol is only valid when NetworkHeader is set.
+ // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty()
+ // returns false.
// TODO(gvisor.dev/issue/3574): Remove the separately passed protocol
// numbers in registration APIs that take a PacketBuffer.
NetworkProtocolNumber tcpip.NetworkProtocolNumber
+ // TransportProtocol is only valid if it is non zero.
+ // TODO(gvisor.dev/issue/3810): This and the network protocol number should
+ // be moved into the headerinfo. This should resolve the validity issue.
+ TransportProtocolNumber tcpip.TransportProtocolNumber
+
// Hash is the transport layer hash of this packet. A value of zero
// indicates no valid hash has been set.
Hash uint32
@@ -234,20 +241,35 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
// underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
newPk := &PacketBuffer{
- PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
- headers: pk.headers,
- header: pk.header,
- Hash: pk.Hash,
- Owner: pk.Owner,
- EgressRoute: pk.EgressRoute,
- GSOOptions: pk.GSOOptions,
- NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ headers: pk.headers,
+ header: pk.header,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ EgressRoute: pk.EgressRoute,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ TransportProtocolNumber: pk.TransportProtocolNumber,
}
return newPk
}
+// Network returns the network header as a header.Network.
+//
+// Network should only be called when NetworkHeader has been set.
+func (pk *PacketBuffer) Network() header.Network {
+ switch netProto := pk.NetworkProtocolNumber; netProto {
+ case header.IPv4ProtocolNumber:
+ return header.IPv4(pk.NetworkHeader().View())
+ case header.IPv6ProtocolNumber:
+ return header.IPv6(pk.NetworkHeader().View())
+ default:
+ panic(fmt.Sprintf("unknown network protocol number %d", netProto))
+ }
+}
+
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
// buf is the memorized slice for both prepended and consumed header.
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 2d88fa1f7..16f854e1f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -15,6 +15,8 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -125,6 +127,26 @@ type PacketEndpoint interface {
HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
+// UnknownDestinationPacketDisposition enumerates the possible return vaues from
+// HandleUnknownDestinationPacket().
+type UnknownDestinationPacketDisposition int
+
+const (
+ // UnknownDestinationPacketMalformed denotes that the packet was malformed
+ // and no further processing should be attempted other than updating
+ // statistics.
+ UnknownDestinationPacketMalformed UnknownDestinationPacketDisposition = iota
+
+ // UnknownDestinationPacketUnhandled tells the caller that the packet was
+ // well formed but that the issue was not handled and the stack should take
+ // the default action.
+ UnknownDestinationPacketUnhandled
+
+ // UnknownDestinationPacketHandled tells the caller that it should do
+ // no further processing.
+ UnknownDestinationPacketHandled
+)
+
// TransportProtocol is the interface that needs to be implemented by transport
// protocols (e.g., tcp, udp) that want to be part of the networking stack.
type TransportProtocol interface {
@@ -132,10 +154,10 @@ type TransportProtocol interface {
Number() tcpip.TransportProtocolNumber
// NewEndpoint creates a new endpoint of the transport protocol.
- NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// NewRawEndpoint creates a new raw endpoint of the transport protocol.
- NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
@@ -147,24 +169,22 @@ type TransportProtocol interface {
ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this
- // protocol but that don't match any existing endpoint. For example,
- // it is targeted at a port that have no listeners.
- //
- // The return value indicates whether the packet was well-formed (for
- // stats purposes only).
+ // protocol that don't match any existing endpoint. For example,
+ // it is targeted at a port that has no listeners.
//
- // HandleUnknownDestinationPacket takes ownership of pkt.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+ // HandleUnknownDestinationPacket takes ownership of pkt if it handles
+ // the issue.
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
- SetOption(option interface{}) *tcpip.Error
+ SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
- Option(option interface{}) *tcpip.Error
+ Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
@@ -179,6 +199,25 @@ type TransportProtocol interface {
Parse(pkt *PacketBuffer) (ok bool)
}
+// TransportPacketDisposition is the result from attempting to deliver a packet
+// to the transport layer.
+type TransportPacketDisposition int
+
+const (
+ // TransportPacketHandled indicates that a transport packet was handled by the
+ // transport layer and callers need not take any further action.
+ TransportPacketHandled TransportPacketDisposition = iota
+
+ // TransportPacketProtocolUnreachable indicates that the transport
+ // protocol requested in the packet is not supported.
+ TransportPacketProtocolUnreachable
+
+ // TransportPacketDestinationPortUnreachable indicates that there weren't any
+ // listeners interested in the packet and the transport protocol has no means
+ // to notify the sender.
+ TransportPacketDestinationPortUnreachable
+)
+
// TransportDispatcher contains the methods used by the network stack to deliver
// packets to the appropriate transport endpoint after it has been handled by
// the network layer.
@@ -189,7 +228,7 @@ type TransportDispatcher interface {
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
// DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer)
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -226,9 +265,257 @@ type NetworkHeaderParams struct {
TOS uint8
}
+// GroupAddressableEndpoint is an endpoint that supports group addressing.
+//
+// An endpoint is considered to support group addressing when one or more
+// endpoints may associate themselves with the same identifier (group address).
+type GroupAddressableEndpoint interface {
+ // JoinGroup joins the spcified group.
+ //
+ // Returns true if the group was newly joined.
+ JoinGroup(group tcpip.Address) (bool, *tcpip.Error)
+
+ // LeaveGroup attempts to leave the specified group.
+ //
+ // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group.
+ LeaveGroup(group tcpip.Address) (bool, *tcpip.Error)
+
+ // IsInGroup returns true if the endpoint is a member of the specified group.
+ IsInGroup(group tcpip.Address) bool
+}
+
+// PrimaryEndpointBehavior is an enumeration of an AddressEndpoint's primary
+// behavior.
+type PrimaryEndpointBehavior int
+
+const (
+ // CanBePrimaryEndpoint indicates the endpoint can be used as a primary
+ // endpoint for new connections with no local address. This is the
+ // default when calling NIC.AddAddress.
+ CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
+
+ // FirstPrimaryEndpoint indicates the endpoint should be the first
+ // primary endpoint considered. If there are multiple endpoints with
+ // this behavior, they are ordered by recency.
+ FirstPrimaryEndpoint
+
+ // NeverPrimaryEndpoint indicates the endpoint should never be a
+ // primary endpoint.
+ NeverPrimaryEndpoint
+)
+
+// AddressConfigType is the method used to add an address.
+type AddressConfigType int
+
+const (
+ // AddressConfigStatic is a statically configured address endpoint that was
+ // added by some user-specified action (adding an explicit address, joining a
+ // multicast group).
+ AddressConfigStatic AddressConfigType = iota
+
+ // AddressConfigSlaac is an address endpoint added by SLAAC, as per RFC 4862
+ // section 5.5.3.
+ AddressConfigSlaac
+
+ // AddressConfigSlaacTemp is a temporary address endpoint added by SLAAC as
+ // per RFC 4941. Temporary SLAAC addresses are short-lived and are not
+ // to be valid (or preferred) forever; hence the term temporary.
+ AddressConfigSlaacTemp
+)
+
+// AssignableAddressEndpoint is a reference counted address endpoint that may be
+// assigned to a NetworkEndpoint.
+type AssignableAddressEndpoint interface {
+ // NetworkEndpoint returns the NetworkEndpoint the receiver is associated
+ // with.
+ NetworkEndpoint() NetworkEndpoint
+
+ // AddressWithPrefix returns the endpoint's address.
+ AddressWithPrefix() tcpip.AddressWithPrefix
+
+ // IsAssigned returns whether or not the endpoint is considered bound
+ // to its NetworkEndpoint.
+ IsAssigned(allowExpired bool) bool
+
+ // IncRef increments this endpoint's reference count.
+ //
+ // Returns true if it was successfully incremented. If it returns false, then
+ // the endpoint is considered expired and should no longer be used.
+ IncRef() bool
+
+ // DecRef decrements this endpoint's reference count.
+ DecRef()
+}
+
+// AddressEndpoint is an endpoint representing an address assigned to an
+// AddressableEndpoint.
+type AddressEndpoint interface {
+ AssignableAddressEndpoint
+
+ // GetKind returns the address kind for this endpoint.
+ GetKind() AddressKind
+
+ // SetKind sets the address kind for this endpoint.
+ SetKind(AddressKind)
+
+ // ConfigType returns the method used to add the address.
+ ConfigType() AddressConfigType
+
+ // Deprecated returns whether or not this endpoint is deprecated.
+ Deprecated() bool
+
+ // SetDeprecated sets this endpoint's deprecated status.
+ SetDeprecated(bool)
+}
+
+// AddressKind is the kind of of an address.
+//
+// See the values of AddressKind for more details.
+type AddressKind int
+
+const (
+ // PermanentTentative is a permanent address endpoint that is not yet
+ // considered to be fully bound to an interface in the traditional
+ // sense. That is, the address is associated with a NIC, but packets
+ // destined to the address MUST NOT be accepted and MUST be silently
+ // dropped, and the address MUST NOT be used as a source address for
+ // outgoing packets. For IPv6, addresses are of this kind until NDP's
+ // Duplicate Address Detection (DAD) resolves. If DAD fails, the address
+ // is removed.
+ PermanentTentative AddressKind = iota
+
+ // Permanent is a permanent endpoint (vs. a temporary one) assigned to the
+ // NIC. Its reference count is biased by 1 to avoid removal when no route
+ // holds a reference to it. It is removed by explicitly removing the address
+ // from the NIC.
+ Permanent
+
+ // PermanentExpired is a permanent endpoint that had its address removed from
+ // the NIC, and it is waiting to be removed once no references to it are held.
+ //
+ // If the address is re-added before the endpoint is removed, its type
+ // changes back to Permanent.
+ PermanentExpired
+
+ // Temporary is an endpoint, created on a one-off basis to temporarily
+ // consider the NIC bound an an address that it is not explictiy bound to
+ // (such as a permanent address). Its reference count must not be biased by 1
+ // so that the address is removed immediately when references to it are no
+ // longer held.
+ //
+ // A temporary endpoint may be promoted to permanent if the address is added
+ // permanently.
+ Temporary
+)
+
+// IsPermanent returns true if the AddressKind represents a permanent address.
+func (k AddressKind) IsPermanent() bool {
+ switch k {
+ case Permanent, PermanentTentative:
+ return true
+ case Temporary, PermanentExpired:
+ return false
+ default:
+ panic(fmt.Sprintf("unrecognized address kind = %d", k))
+ }
+}
+
+// AddressableEndpoint is an endpoint that supports addressing.
+//
+// An endpoint is considered to support addressing when the endpoint may
+// associate itself with an identifier (address).
+type AddressableEndpoint interface {
+ // AddAndAcquirePermanentAddress adds the passed permanent address.
+ //
+ // Returns tcpip.ErrDuplicateAddress if the address exists.
+ //
+ // Acquires and returns the AddressEndpoint for the added address.
+ AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error)
+
+ // RemovePermanentAddress removes the passed address if it is a permanent
+ // address.
+ //
+ // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed
+ // permanent address.
+ RemovePermanentAddress(addr tcpip.Address) *tcpip.Error
+
+ // MainAddress returns the endpoint's primary permanent address.
+ MainAddress() tcpip.AddressWithPrefix
+
+ // AcquireAssignedAddress returns an address endpoint for the passed address
+ // that is considered bound to the endpoint, optionally creating a temporary
+ // endpoint if requested and no existing address exists.
+ //
+ // The returned endpoint's reference count is incremented.
+ //
+ // Returns nil if the specified address is not local to this endpoint.
+ AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint
+
+ // AcquireOutgoingPrimaryAddress returns a primary address that may be used as
+ // a source address when sending packets to the passed remote address.
+ //
+ // If allowExpired is true, expired addresses may be returned.
+ //
+ // The returned endpoint's reference count is incremented.
+ //
+ // Returns nil if a primary address is not available.
+ AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint
+
+ // PrimaryAddresses returns the primary addresses.
+ PrimaryAddresses() []tcpip.AddressWithPrefix
+
+ // PermanentAddresses returns all the permanent addresses.
+ PermanentAddresses() []tcpip.AddressWithPrefix
+}
+
+// NDPEndpoint is a network endpoint that supports NDP.
+type NDPEndpoint interface {
+ NetworkEndpoint
+
+ // InvalidateDefaultRouter invalidates a default router discovered through
+ // NDP.
+ InvalidateDefaultRouter(tcpip.Address)
+}
+
+// NetworkInterface is a network interface.
+type NetworkInterface interface {
+ // ID returns the interface's ID.
+ ID() tcpip.NICID
+
+ // IsLoopback returns true if the interface is a loopback interface.
+ IsLoopback() bool
+
+ // Name returns the name of the interface.
+ //
+ // May return an empty string if the interface is not configured with a name.
+ Name() string
+
+ // Enabled returns true if the interface is enabled.
+ Enabled() bool
+
+ // LinkEndpoint returns the link endpoint backing the interface.
+ LinkEndpoint() LinkEndpoint
+}
+
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
+ AddressableEndpoint
+
+ // Enable enables the endpoint.
+ //
+ // Must only be called when the stack is in a state that allows the endpoint
+ // to send and receive packets.
+ //
+ // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled.
+ Enable() *tcpip.Error
+
+ // Enabled returns true if the endpoint is enabled.
+ Enabled() bool
+
+ // Disable disables the endpoint.
+ Disable()
+
// DefaultTTL is the default time-to-live value (or hop limit, in ipv6)
// for this endpoint.
DefaultTTL() uint8
@@ -238,10 +525,6 @@ type NetworkEndpoint interface {
// minus the network endpoint max header length.
MTU() uint32
- // Capabilities returns the set of capabilities supported by the
- // underlying link-layer endpoint.
- Capabilities() LinkEndpointCapabilities
-
// MaxHeaderLength returns the maximum size the network (and lower
// level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -262,9 +545,6 @@ type NetworkEndpoint interface {
// header to the given destination address. It takes ownership of pkt.
WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
- // NICID returns the id of the NIC this endpoint belongs to.
- NICID() tcpip.NICID
-
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint. It sets pkt.NetworkHeader.
//
@@ -279,6 +559,17 @@ type NetworkEndpoint interface {
NetworkProtocolNumber() tcpip.NetworkProtocolNumber
}
+// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets.
+type ForwardingNetworkProtocol interface {
+ NetworkProtocol
+
+ // Forwarding returns the forwarding configuration.
+ Forwarding() bool
+
+ // SetForwarding sets the forwarding configuration.
+ SetForwarding(bool)
+}
+
// NetworkProtocol is the interface that needs to be implemented by network
// protocols (e.g., ipv4, ipv6) that want to be part of the networking stack.
type NetworkProtocol interface {
@@ -298,7 +589,7 @@ type NetworkProtocol interface {
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint
+ NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -427,8 +718,8 @@ type LinkEndpoint interface {
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
//
- // Attach will be called with a nil dispatcher if the receiver's associated
- // NIC is being removed.
+ // Attach is called with a nil dispatcher when the endpoint's NIC is being
+ // removed.
Attach(dispatcher NetworkDispatcher)
// IsAttached returns whether a NetworkDispatcher is attached to the
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index c2eabde9e..effe30155 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -42,21 +42,27 @@ type Route struct {
// NetProto is the network-layer protocol.
NetProto tcpip.NetworkProtocolNumber
- // ref a reference to the network endpoint through which the route
- // starts.
- ref *referencedNetworkEndpoint
-
// Loop controls where WritePacket should send packets.
Loop PacketLooping
- // directedBroadcast indicates whether this route is sending a directed
- // broadcast packet.
- directedBroadcast bool
+ // nic is the NIC the route goes through.
+ nic *NIC
+
+ // addressEndpoint is the local address this route is associated with.
+ addressEndpoint AssignableAddressEndpoint
+
+ // linkCache is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkCache LinkAddressCache
+
+ // linkRes is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkRes LinkAddressResolver
}
// makeRoute initializes a new route. It takes ownership of the provided
-// reference to a network endpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, handleLocal, multicastLoop bool) Route {
+// AssignableAddressEndpoint.
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
loop := PacketOut
if handleLocal && localAddr != "" && remoteAddr == localAddr {
loop = PacketLoop
@@ -66,29 +72,40 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop |= PacketLoop
}
- return Route{
+ linkEP := nic.LinkEndpoint()
+ r := Route{
NetProto: netProto,
LocalAddress: localAddr,
- LocalLinkAddress: localLinkAddr,
+ LocalLinkAddress: linkEP.LinkAddress(),
RemoteAddress: remoteAddr,
- ref: ref,
+ addressEndpoint: addressEndpoint,
+ nic: nic,
Loop: loop,
}
+
+ if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ r.linkRes = linkRes
+ r.linkCache = nic.stack
+ }
+ }
+
+ return r
}
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
- return r.ref.ep.NICID()
+ return r.nic.ID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
- return r.ref.ep.MaxHeaderLength()
+ return r.addressEndpoint.NetworkEndpoint().MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
func (r *Route) Stats() tcpip.Stats {
- return r.ref.nic.stack.Stats()
+ return r.nic.stack.Stats()
}
// PseudoHeaderChecksum forwards the call to the network endpoint's
@@ -99,12 +116,12 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
// Capabilities returns the link-layer capabilities of the route.
func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.ref.ep.Capabilities()
+ return r.nic.LinkEndpoint().Capabilities()
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
- if gso, ok := r.ref.ep.(GSOEndpoint); ok {
+ if gso, ok := r.addressEndpoint.NetworkEndpoint().(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
@@ -142,8 +159,8 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
nextAddr = r.RemoteAddress
}
- if r.ref.nic.neigh != nil {
- entry, ch, err := r.ref.nic.neigh.entry(nextAddr, r.LocalAddress, r.ref.linkRes, waker)
+ if neigh := r.nic.neigh; neigh != nil {
+ entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker)
if err != nil {
return ch, err
}
@@ -151,7 +168,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
return nil, nil
}
- linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
if err != nil {
return ch, err
}
@@ -166,12 +183,12 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
nextAddr = r.RemoteAddress
}
- if r.ref.nic.neigh != nil {
- r.ref.nic.neigh.removeWaker(nextAddr, waker)
+ if neigh := r.nic.neigh; neigh != nil {
+ neigh.removeWaker(nextAddr, waker)
return
}
- r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
+ r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker)
}
// IsResolutionRequired returns true if Resolve() must be called to resolve
@@ -179,104 +196,94 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
//
// The NIC r uses must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if r.ref.nic.neigh != nil {
- return r.ref.isValidForOutgoing() && r.ref.linkRes != nil && r.RemoteLinkAddress == ""
+ if r.nic.neigh != nil {
+ return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == ""
}
- return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+ return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
- if !r.ref.isValidForOutgoing() {
+ if !r.nic.isValidForOutgoing(r.addressEndpoint) {
return tcpip.ErrInvalidEndpointState
}
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
- err := r.ref.ep.WritePacket(r, gso, params, pkt)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- } else {
- r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ if err := r.addressEndpoint.NetworkEndpoint().WritePacket(r, gso, params, pkt); err != nil {
+ return err
}
- return err
+
+ r.nic.stats.Tx.Packets.Increment()
+ r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ return nil
}
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
- if !r.ref.isValidForOutgoing() {
+ if !r.nic.isValidForOutgoing(r.addressEndpoint) {
return 0, tcpip.ErrInvalidEndpointState
}
- // WritePackets takes ownership of pkt, calculate length first.
- numPkts := pkts.Len()
-
- n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n))
- }
- r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
-
+ n, err := r.addressEndpoint.NetworkEndpoint().WritePackets(r, gso, pkts, params)
+ r.nic.stats.Tx.Packets.IncrementBy(uint64(n))
writtenBytes := 0
for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
writtenBytes += pb.Size()
}
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
+ r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
return n, err
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
- if !r.ref.isValidForOutgoing() {
+ if !r.nic.isValidForOutgoing(r.addressEndpoint) {
return tcpip.ErrInvalidEndpointState
}
// WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Data.Size()
- if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ if err := r.addressEndpoint.NetworkEndpoint().WriteHeaderIncludedPacket(r, pkt); err != nil {
return err
}
- r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ r.nic.stats.Tx.Packets.Increment()
+ r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
return nil
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
- return r.ref.ep.DefaultTTL()
+ return r.addressEndpoint.NetworkEndpoint().DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
- return r.ref.ep.MTU()
+ return r.addressEndpoint.NetworkEndpoint().MTU()
}
// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying
// network endpoint.
func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return r.ref.ep.NetworkProtocolNumber()
+ return r.addressEndpoint.NetworkEndpoint().NetworkProtocolNumber()
}
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.ref != nil {
- r.ref.decRef()
- r.ref = nil
+ if r.addressEndpoint != nil {
+ r.addressEndpoint.DecRef()
+ r.addressEndpoint = nil
}
}
-// Clone Clone a route such that the original one can be released and the new
-// one will remain valid.
+// Clone clones the route.
func (r *Route) Clone() Route {
- if r.ref != nil {
- r.ref.incRef()
+ if r.addressEndpoint != nil {
+ _ = r.addressEndpoint.IncRef()
}
return *r
}
@@ -300,27 +307,30 @@ func (r *Route) MakeLoopedRoute() Route {
// Stack returns the instance of the Stack that owns this route.
func (r *Route) Stack() *Stack {
- return r.ref.stack()
+ return r.nic.stack
+}
+
+func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
+ if addr == header.IPv4Broadcast {
+ return true
+ }
+
+ subnet := r.addressEndpoint.AddressWithPrefix().Subnet()
+ return subnet.IsBroadcast(addr)
}
// IsOutboundBroadcast returns true if the route is for an outbound broadcast
// packet.
func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
- return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast
+ return r.isV4Broadcast(r.RemoteAddress)
}
// IsInboundBroadcast returns true if the route is for an inbound broadcast
// packet.
func (r *Route) IsInboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
- if r.LocalAddress == header.IPv4Broadcast {
- return true
- }
-
- addr := r.ref.addrWithPrefix()
- subnet := addr.Subnet()
- return subnet.IsBroadcast(r.LocalAddress)
+ return r.isV4Broadcast(r.LocalAddress)
}
// ReverseRoute returns new route with given source and destination address.
@@ -331,7 +341,10 @@ func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
LocalLinkAddress: r.RemoteLinkAddress,
RemoteAddress: src,
RemoteLinkAddress: r.LocalLinkAddress,
- ref: r.ref,
Loop: r.Loop,
+ addressEndpoint: r.addressEndpoint,
+ nic: r.nic,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
}
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index c86ee1c13..57d8e79e0 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -144,10 +144,7 @@ type TCPReceiverState struct {
// PendingBufUsed is the number of bytes pending in the receive
// queue.
- PendingBufUsed seqnum.Size
-
- // PendingBufSize is the size of the socket receive buffer.
- PendingBufSize seqnum.Size
+ PendingBufUsed int
}
// TCPSenderState holds a copy of the internal state of the sender for
@@ -366,38 +363,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 {
return atomic.AddUint64((*uint64)(u), 1)
}
-// NICNameFromID is a function that returns a stable name for the specified NIC,
-// even if different NIC IDs are used to refer to the same NIC in different
-// program runs. It is used when generating opaque interface identifiers (IIDs).
-// If the NIC was created with a name, it will be passed to NICNameFromID.
-//
-// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
-// generated for the same prefix on differnt NICs.
-type NICNameFromID func(tcpip.NICID, string) string
-
-// OpaqueInterfaceIdentifierOptions holds the options related to the generation
-// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
-type OpaqueInterfaceIdentifierOptions struct {
- // NICNameFromID is a function that returns a stable name for a specified NIC,
- // even if the NIC ID changes over time.
- //
- // Must be specified to generate the opaque IID.
- NICNameFromID NICNameFromID
-
- // SecretKey is a pseudo-random number used as the secret key when generating
- // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
- // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
- // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
- // change between program runs, unless explicitly changed.
- //
- // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
- // MUST NOT be modified after Stack is created.
- //
- // May be nil, but a nil value is highly discouraged to maintain
- // some level of randomness between nodes.
- SecretKey []byte
-}
-
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -415,10 +380,12 @@ type Stack struct {
linkAddrCache *linkAddrCache
- mu sync.RWMutex
- nics map[tcpip.NICID]*NIC
- forwarding bool
- cleanupEndpoints map[TransportEndpoint]struct{}
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
+
+ // cleanupEndpointsMu protects cleanupEndpoints.
+ cleanupEndpointsMu sync.Mutex
+ cleanupEndpoints map[TransportEndpoint]struct{}
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
@@ -429,7 +396,7 @@ type Stack struct {
// If not nil, then any new endpoints will have this probe function
// invoked everytime they receive a TCP segment.
- tcpProbeFunc TCPProbeFunc
+ tcpProbeFunc atomic.Value // TCPProbeFunc
// clock is used to generate user-visible times.
clock tcpip.Clock
@@ -455,9 +422,6 @@ type Stack struct {
// TODO(gvisor.dev/issue/940): S/R this field.
seed uint32
- // ndpConfigs is the default NDP configurations used by interfaces.
- ndpConfigs NDPConfigurations
-
// nudConfigs is the default NUD configurations used by interfaces.
nudConfigs NUDConfigurations
@@ -465,15 +429,6 @@ type Stack struct {
// by the NIC's neighborCache instead of linkAddrCache.
useNeighborCache bool
- // autoGenIPv6LinkLocal determines whether or not the stack will attempt
- // to auto-generate an IPv6 link-local address for newly enabled non-loopback
- // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
- autoGenIPv6LinkLocal bool
-
- // ndpDisp is the NDP event dispatcher that is used to send the netstack
- // integrator NDP related events.
- ndpDisp NDPDispatcher
-
// nudDisp is the NUD event dispatcher that is used to send the netstack
// integrator NUD related events.
nudDisp NUDDispatcher
@@ -481,14 +436,6 @@ type Stack struct {
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
- // opaqueIIDOpts hold the options for generating opaque interface identifiers
- // (IIDs) as outlined by RFC 7217.
- opaqueIIDOpts OpaqueInterfaceIdentifierOptions
-
- // tempIIDSeed is used to seed the initial temporary interface identifier
- // history value used to generate IIDs for temporary SLAAC addresses.
- tempIIDSeed []byte
-
// forwarder holds the packets that wait for their link-address resolutions
// to complete, and forwards them when each resolution is done.
forwarder *forwardQueue
@@ -511,13 +458,25 @@ type UniqueID interface {
UniqueID() uint64
}
+// NetworkProtocolFactory instantiates a network protocol.
+//
+// NetworkProtocolFactory must not attempt to modify the stack, it may only
+// query the stack.
+type NetworkProtocolFactory func(*Stack) NetworkProtocol
+
+// TransportProtocolFactory instantiates a transport protocol.
+//
+// TransportProtocolFactory must not attempt to modify the stack, it may only
+// query the stack.
+type TransportProtocolFactory func(*Stack) TransportProtocol
+
// Options contains optional Stack configuration.
type Options struct {
// NetworkProtocols lists the network protocols to enable.
- NetworkProtocols []NetworkProtocol
+ NetworkProtocols []NetworkProtocolFactory
// TransportProtocols lists the transport protocols to enable.
- TransportProtocols []TransportProtocol
+ TransportProtocols []TransportProtocolFactory
// Clock is an optional clock source used for timestampping packets.
//
@@ -535,13 +494,6 @@ type Options struct {
// UniqueID is an optional generator of unique identifiers.
UniqueID UniqueID
- // NDPConfigs is the default NDP configurations used by interfaces.
- //
- // By default, NDPConfigs will have a zero value for its
- // DupAddrDetectTransmits field, implying that DAD will not be performed
- // before assigning an address to a NIC.
- NDPConfigs NDPConfigurations
-
// NUDConfigs is the default NUD configurations used by interfaces.
NUDConfigs NUDConfigurations
@@ -552,24 +504,6 @@ type Options struct {
// and ClearNeighbors.
UseNeighborCache bool
- // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
- // auto-generate an IPv6 link-local address for newly enabled non-loopback
- // NICs.
- //
- // Note, setting this to true does not mean that a link-local address
- // will be assigned right away, or at all. If Duplicate Address Detection
- // is enabled, an address will only be assigned if it successfully resolves.
- // If it fails, no further attempt will be made to auto-generate an IPv6
- // link-local address.
- //
- // The generated link-local address will follow RFC 4291 Appendix A
- // guidelines.
- AutoGenIPv6LinkLocal bool
-
- // NDPDisp is the NDP event dispatcher that an integrator can provide to
- // receive NDP related events.
- NDPDisp NDPDispatcher
-
// NUDDisp is the NUD event dispatcher that an integrator can provide to
// receive NUD related events.
NUDDisp NUDDispatcher
@@ -578,31 +512,12 @@ type Options struct {
// this is non-nil.
RawFactory RawFactory
- // OpaqueIIDOpts hold the options for generating opaque interface
- // identifiers (IIDs) as outlined by RFC 7217.
- OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
-
// RandSource is an optional source to use to generate random
// numbers. If omitted it defaults to a Source seeded by the data
// returned by rand.Read().
//
// RandSource must be thread-safe.
RandSource mathrand.Source
-
- // TempIIDSeed is used to seed the initial temporary interface identifier
- // history value used to generate IIDs for temporary SLAAC addresses.
- //
- // Temporary SLAAC adresses are short-lived addresses which are unpredictable
- // and random from the perspective of other nodes on the network. It is
- // recommended that the seed be a random byte buffer of at least
- // header.IIDSize bytes to make sure that temporary SLAAC addresses are
- // sufficiently random. It should follow minimum randomness requirements for
- // security as outlined by RFC 4086.
- //
- // Note: using a nil value, the same seed across netstack program runs, or a
- // seed that is too small would reduce randomness and increase predictability,
- // defeating the purpose of temporary SLAAC addresses.
- TempIIDSeed []byte
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -705,36 +620,28 @@ func New(opts Options) *Stack {
randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
}
- // Make sure opts.NDPConfigs contains valid values only.
- opts.NDPConfigs.validate()
-
opts.NUDConfigs.resetInvalidFields()
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
- nics: make(map[tcpip.NICID]*NIC),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: DefaultTables(),
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
- ndpConfigs: opts.NDPConfigs,
- nudConfigs: opts.NUDConfigs,
- useNeighborCache: opts.UseNeighborCache,
- autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
- uniqueIDGenerator: opts.UniqueID,
- ndpDisp: opts.NDPDisp,
- nudDisp: opts.NUDDisp,
- opaqueIIDOpts: opts.OpaqueIIDOpts,
- tempIIDSeed: opts.TempIIDSeed,
- forwarder: newForwardQueue(),
- randomGenerator: mathrand.New(randSrc),
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: DefaultTables(),
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ nudConfigs: opts.NUDConfigs,
+ useNeighborCache: opts.UseNeighborCache,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ forwarder: newForwardQueue(),
+ randomGenerator: mathrand.New(randSrc),
sendBufferSize: SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -748,7 +655,8 @@ func New(opts Options) *Stack {
}
// Add specified network protocols.
- for _, netProto := range opts.NetworkProtocols {
+ for _, netProtoFactory := range opts.NetworkProtocols {
+ netProto := netProtoFactory(s)
s.networkProtocols[netProto.Number()] = netProto
if r, ok := netProto.(LinkAddressResolver); ok {
s.linkAddrResolvers[r.LinkAddressProtocol()] = r
@@ -756,7 +664,8 @@ func New(opts Options) *Stack {
}
// Add specified transport protocols.
- for _, transProto := range opts.TransportProtocols {
+ for _, transProtoFactory := range opts.TransportProtocols {
+ transProto := transProtoFactory(s)
s.transportProtocols[transProto.Number()] = &transportProtocolState{
proto: transProto,
}
@@ -814,7 +723,7 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
-func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
return tcpip.ErrUnknownProtocol
@@ -829,7 +738,7 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb
// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
// ...
// }
-func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
return tcpip.ErrUnknownProtocol
@@ -863,46 +772,37 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwarding enables or disables the packet forwarding between NICs.
-//
-// When forwarding becomes enabled, any host-only state on all NICs will be
-// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started.
-// When forwarding becomes disabled and if IPv6 is enabled, NDP Router
-// Solicitations will be stopped.
-func (s *Stack) SetForwarding(enable bool) {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.Lock()
- defer s.mu.Unlock()
+// SetForwarding enables or disables packet forwarding between NICs for the
+// passed protocol.
+func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error {
+ protocol, ok := s.networkProtocols[protocolNum]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
- // If forwarding status didn't change, do nothing further.
- if s.forwarding == enable {
- return
+ forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ if !ok {
+ return tcpip.ErrNotSupported
}
- s.forwarding = enable
+ forwardingProtocol.SetForwarding(enable)
+ return nil
+}
- // If this stack does not support IPv6, do nothing further.
- if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok {
- return
+// Forwarding returns true if packet forwarding between NICs is enabled for the
+// passed protocol.
+func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
+ protocol, ok := s.networkProtocols[protocolNum]
+ if !ok {
+ return false
}
- if enable {
- for _, nic := range s.nics {
- nic.becomeIPv6Router()
- }
- } else {
- for _, nic := range s.nics {
- nic.becomeIPv6Host()
- }
+ forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ if !ok {
+ return false
}
-}
-// Forwarding returns if the packet forwarding between NICs is enabled.
-func (s *Stack) Forwarding() bool {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.forwarding
+ return forwardingProtocol.Forwarding()
}
// SetRouteTable assigns the route table to be used by this stack. It
@@ -937,7 +837,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
return nil, tcpip.ErrUnknownProtocol
}
- return t.proto.NewEndpoint(s, network, waiterQueue)
+ return t.proto.NewEndpoint(network, waiterQueue)
}
// NewRawEndpoint creates a new raw transport layer endpoint of the given
@@ -957,7 +857,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
return nil, tcpip.ErrUnknownProtocol
}
- return t.proto.NewRawEndpoint(s, network, waiterQueue)
+ return t.proto.NewRawEndpoint(network, waiterQueue)
}
// NewPacketEndpoint creates a new packet endpoint listening for the given
@@ -1064,7 +964,8 @@ func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
return tcpip.ErrUnknownNICID
}
- return nic.disable()
+ nic.disable()
+ return nil
}
// CheckNIC checks if a NIC is usable.
@@ -1077,7 +978,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
return false
}
- return nic.enabled()
+ return nic.Enabled()
}
// RemoveNIC removes NIC and all related routes from the network stack.
@@ -1155,14 +1056,14 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
for id, nic := range s.nics {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
- Running: nic.enabled(),
+ Running: nic.Enabled(),
Promiscuous: nic.isPromiscuousMode(),
- Loopback: nic.isLoopback(),
+ Loopback: nic.IsLoopback(),
}
nics[id] = NICInfo{
Name: nic.name,
LinkAddress: nic.linkEP.LinkAddress(),
- ProtocolAddresses: nic.PrimaryAddresses(),
+ ProtocolAddresses: nic.primaryAddresses(),
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
@@ -1226,7 +1127,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return tcpip.ErrUnknownNICID
}
- return nic.AddAddress(protocolAddress, peb)
+ return nic.addAddress(protocolAddress, peb)
}
// RemoveAddress removes an existing network-layer address from the specified
@@ -1236,7 +1137,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- return nic.RemoveAddress(addr)
+ return nic.removeAddress(addr)
}
return tcpip.ErrUnknownNICID
@@ -1250,7 +1151,7 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress)
for id, nic := range s.nics {
- nics[id] = nic.AllAddresses()
+ nics[id] = nic.allPermanentAddresses()
}
return nics
}
@@ -1272,7 +1173,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol
return nic.primaryAddress(protocol), nil
}
-func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
+func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint {
if len(localAddr) == 0 {
return nic.primaryEndpoint(netProto, remoteAddr)
}
@@ -1289,9 +1190,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
- if nic, ok := s.nics[id]; ok && nic.enabled() {
- if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
- return makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
+ if nic, ok := s.nics[id]; ok && nic.Enabled() {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil
}
}
} else {
@@ -1299,22 +1200,20 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
- if nic, ok := s.nics[route.NIC]; ok && nic.enabled() {
- if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
+ if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
// provided will refer to the link local address.
- remoteAddr = ref.address()
+ remoteAddr = addressEndpoint.AddressWithPrefix().Address
}
- r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
- r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr)
-
+ r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback())
if len(route.Gateway) > 0 {
if needRoute {
r.NextHop = route.Gateway
}
- } else if r.directedBroadcast {
+ } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
@@ -1352,21 +1251,20 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto
return 0
}
- ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
- if ref == nil {
+ addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
return 0
}
- ref.decRef()
+ addressEndpoint.DecRef()
return nic.id
}
// Go through all the NICs.
for _, nic := range s.nics {
- ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
- if ref != nil {
- ref.decRef()
+ if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
return nic.id
}
}
@@ -1528,10 +1426,9 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
// StartTransportEndpointCleanup removes the endpoint with the given id from
// the stack transport dispatcher. It also transitions it to the cleanup stage.
func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
+ s.cleanupEndpointsMu.Lock()
s.cleanupEndpoints[ep] = struct{}{}
+ s.cleanupEndpointsMu.Unlock()
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
@@ -1539,9 +1436,9 @@ func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcp
// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
// stage.
func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
- s.mu.Lock()
+ s.cleanupEndpointsMu.Lock()
delete(s.cleanupEndpoints, ep)
- s.mu.Unlock()
+ s.cleanupEndpointsMu.Unlock()
}
// FindTransportEndpoint finds an endpoint that most closely matches the provided
@@ -1584,23 +1481,23 @@ func (s *Stack) RegisteredEndpoints() []TransportEndpoint {
// CleanupEndpoints returns endpoints currently in the cleanup state.
func (s *Stack) CleanupEndpoints() []TransportEndpoint {
- s.mu.Lock()
+ s.cleanupEndpointsMu.Lock()
es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints))
for e := range s.cleanupEndpoints {
es = append(es, e)
}
- s.mu.Unlock()
+ s.cleanupEndpointsMu.Unlock()
return es
}
// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
// for restoring a stack after a save.
func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
- s.mu.Lock()
+ s.cleanupEndpointsMu.Lock()
for _, e := range es {
s.cleanupEndpoints[e] = struct{}{}
}
- s.mu.Unlock()
+ s.cleanupEndpointsMu.Unlock()
}
// Close closes all currently registered transport endpoints.
@@ -1795,18 +1692,17 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra
// guarantee provided on which probe will be invoked. Ideally this should only
// be called once per stack.
func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
- s.mu.Lock()
- s.tcpProbeFunc = probe
- s.mu.Unlock()
+ s.tcpProbeFunc.Store(probe)
}
// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
// otherwise.
func (s *Stack) GetTCPProbe() TCPProbeFunc {
- s.mu.Lock()
- p := s.tcpProbeFunc
- s.mu.Unlock()
- return p
+ p := s.tcpProbeFunc.Load()
+ if p == nil {
+ return nil
+ }
+ return p.(TCPProbeFunc)
}
// RemoveTCPProbe removes an installed TCP probe.
@@ -1815,9 +1711,8 @@ func (s *Stack) GetTCPProbe() TCPProbeFunc {
// have a probe attached. Endpoints already created will continue to invoke
// TCP probe.
func (s *Stack) RemoveTCPProbe() {
- s.mu.Lock()
- s.tcpProbeFunc = nil
- s.mu.Unlock()
+ // This must be TCPProbeFunc(nil) because atomic.Value.Store(nil) panics.
+ s.tcpProbeFunc.Store(TCPProbeFunc(nil))
}
// JoinGroup joins the given multicast group on the given NIC.
@@ -1838,7 +1733,7 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
defer s.mu.RUnlock()
if nic, ok := s.nics[nicID]; ok {
- return nic.leaveGroup(multicastAddr)
+ return nic.leaveGroup(protocol, multicastAddr)
}
return tcpip.ErrUnknownNICID
}
@@ -1890,53 +1785,18 @@ func (s *Stack) AllowICMPMessage() bool {
return s.icmpRateLimiter.Allow()
}
-// IsAddrTentative returns true if addr is tentative on the NIC with ID id.
-//
-// Note that if addr is not associated with a NIC with id ID, then this
-// function will return false. It will only return true if the address is
-// associated with the NIC AND it is tentative.
-func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic, ok := s.nics[id]
- if !ok {
- return false, tcpip.ErrUnknownNICID
- }
-
- return nic.isAddrTentative(addr), nil
-}
-
-// DupTentativeAddrDetected attempts to inform the NIC with ID id that a
-// tentative addr on it is a duplicate on a link.
-func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+// GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol
+// number installed on the specified NIC.
+func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
- nic, ok := s.nics[id]
- if !ok {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.dupTentativeAddrDetected(addr)
-}
-
-// SetNDPConfigurations sets the per-interface NDP configurations on the NIC
-// with ID id to c.
-//
-// Note, if c contains invalid NDP configuration values, it will be fixed to
-// use default values for the erroneous values.
-func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- nic, ok := s.nics[id]
+ nic, ok := s.nics[nicID]
if !ok {
- return tcpip.ErrUnknownNICID
+ return nil, tcpip.ErrUnknownNICID
}
- nic.setNDPConfigs(c)
- return nil
+ return nic.networkEndpoints[proto], nil
}
// NUDConfigurations gets the per-interface NUD configurations.
@@ -1949,7 +1809,7 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Err
return NUDConfigurations{}, tcpip.ErrUnknownNICID
}
- return nic.NUDConfigs()
+ return nic.nudConfigs()
}
// SetNUDConfigurations sets the per-interface NUD configurations.
@@ -1968,22 +1828,6 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip
return nic.setNUDConfigs(c)
}
-// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement
-// message that it needs to handle.
-func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- nic, ok := s.nics[id]
- if !ok {
- return tcpip.ErrUnknownNICID
- }
-
- nic.handleNDPRA(ip, ra)
-
- return nil
-}
-
// Seed returns a 32 bit value that can be used as a seed value for port
// picking, ISN generation etc.
//
@@ -2025,16 +1869,14 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
defer s.mu.RUnlock()
for _, nic := range s.nics {
- id := NetworkEndpointID{address}
-
- if ref, ok := nic.mu.endpoints[id]; ok {
- nic.mu.RLock()
- defer nic.mu.RUnlock()
-
- // An endpoint with this id exists, check if it can be
- // used and return it.
- return ref.ep, nil
+ addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint)
+ if addressEndpoint == nil {
+ continue
}
+
+ ep := addressEndpoint.NetworkEndpoint()
+ addressEndpoint.DecRef()
+ return ep, nil
}
return nil, tcpip.ErrBadAddress
}
@@ -2051,3 +1893,8 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
return nic.Name()
}
+
+// NewJob returns a new tcpip.Job using the stack's clock.
+func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 60b54c244..aa20f750b 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -29,6 +29,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -68,18 +69,41 @@ const (
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fakeNetworkEndpoint struct {
+ stack.AddressableEndpointState
+
+ mu struct {
+ sync.RWMutex
+
+ enabled bool
+ }
+
nicID tcpip.NICID
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
ep stack.LinkEndpoint
}
-func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+func (f *fakeNetworkEndpoint) Enable() *tcpip.Error {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = true
+ return nil
}
-func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicID
+func (f *fakeNetworkEndpoint) Enabled() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.enabled
+}
+
+func (f *fakeNetworkEndpoint) Disable() {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = false
+}
+
+func (f *fakeNetworkEndpoint) MTU() uint32 {
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
@@ -118,10 +142,6 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto
return 0
}
-func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return f.ep.Capabilities()
-}
-
func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -156,7 +176,9 @@ func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack
return tcpip.ErrNotSupported
}
-func (*fakeNetworkEndpoint) Close() {}
+func (f *fakeNetworkEndpoint) Close() {
+ f.AddressableEndpointState.Cleanup()
+}
// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
// number of packets sent and received via endpoints of this protocol. The index
@@ -165,6 +187,11 @@ type fakeNetworkProtocol struct {
packetCount [10]int
sendPacketCount [10]int
defaultTTL uint8
+
+ mu struct {
+ sync.RWMutex
+ forwarding bool
+ }
}
func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -187,13 +214,15 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint {
- return &fakeNetworkEndpoint{
- nicID: nicID,
+func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &fakeNetworkEndpoint{
+ nicID: nic.ID(),
proto: f,
dispatcher: dispatcher,
- ep: ep,
+ ep: nic.LinkEndpoint(),
}
+ e.AddressableEndpointState.Init(e)
+ return e
}
func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
@@ -216,13 +245,13 @@ func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption)
}
}
-// Close implements TransportProtocol.Close.
+// Close implements NetworkProtocol.Close.
func (*fakeNetworkProtocol) Close() {}
-// Wait implements TransportProtocol.Wait.
+// Wait implements NetworkProtocol.Wait.
func (*fakeNetworkProtocol) Wait() {}
-// Parse implements TransportProtocol.Parse.
+// Parse implements NetworkProtocol.Parse.
func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen)
if !ok {
@@ -231,7 +260,21 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
-func fakeNetFactory() stack.NetworkProtocol {
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (f *fakeNetworkProtocol) Forwarding() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (f *fakeNetworkProtocol) SetForwarding(v bool) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.forwarding = v
+}
+
+func fakeNetFactory(*stack.Stack) stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
@@ -268,7 +311,7 @@ func TestNetworkReceive(t *testing.T) {
// addresses attached to it: 1 & 2.
ep := channel.New(10, defaultMTU, "")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -428,7 +471,7 @@ func TestNetworkSend(t *testing.T) {
// existing nic.
ep := channel.New(10, defaultMTU, "")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("NewNIC failed:", err)
@@ -455,7 +498,7 @@ func TestNetworkSendMultiRoute(t *testing.T) {
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(10, defaultMTU, "")
@@ -555,7 +598,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
e := linkEPWithMockedAttach{
@@ -574,7 +617,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) {
func TestDisableUnknownNIC(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
@@ -586,7 +629,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
e := loopback.New()
@@ -633,7 +676,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
func TestRemoveUnknownNIC(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID {
@@ -645,7 +688,7 @@ func TestRemoveNIC(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
e := linkEPWithMockedAttach{
@@ -706,7 +749,7 @@ func TestRouteWithDownNIC(t *testing.T) {
setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(1, defaultMTU, "")
@@ -872,7 +915,7 @@ func TestRoutes(t *testing.T) {
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(10, defaultMTU, "")
@@ -952,7 +995,7 @@ func TestAddressRemoval(t *testing.T) {
remoteAddr := tcpip.Address("\x02")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -999,7 +1042,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
remoteAddr := tcpip.Address("\x02")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1090,7 +1133,7 @@ func TestEndpointExpiration(t *testing.T) {
for _, spoofing := range []bool{true, false} {
t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1248,7 +1291,7 @@ func TestEndpointExpiration(t *testing.T) {
func TestPromiscuousMode(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1300,7 +1343,7 @@ func TestSpoofingWithAddress(t *testing.T) {
dstAddr := tcpip.Address("\x03")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1366,7 +1409,7 @@ func TestSpoofingNoAddress(t *testing.T) {
dstAddr := tcpip.Address("\x02")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1429,7 +1472,7 @@ func verifyRoute(gotRoute, wantRoute stack.Route) error {
func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1472,7 +1515,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// Create a new stack with two NICs.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1573,7 +1616,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1630,8 +1673,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
func TestNetworkOption(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{},
})
opt := tcpip.DefaultTTLOption(5)
@@ -1657,7 +1700,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
for never := 0; never < 3; never++ {
t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1724,7 +1767,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
func TestGetMainNICAddressAddRemove(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1809,7 +1852,7 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
func TestAddAddress(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID, ep); err != nil {
@@ -1836,7 +1879,7 @@ func TestAddAddress(t *testing.T) {
func TestAddProtocolAddress(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID, ep); err != nil {
@@ -1870,7 +1913,7 @@ func TestAddProtocolAddress(t *testing.T) {
func TestAddAddressWithOptions(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID, ep); err != nil {
@@ -1901,7 +1944,7 @@ func TestAddAddressWithOptions(t *testing.T) {
func TestAddProtocolAddressWithOptions(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID, ep); err != nil {
@@ -2022,7 +2065,7 @@ func TestCreateNICWithOptions(t *testing.T) {
func TestNICStats(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
@@ -2089,9 +2132,9 @@ func TestNICForwarding(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID1, ep1); err != nil {
@@ -2213,7 +2256,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName string
autoGen bool
linkAddr tcpip.LinkAddress
- iidOpts stack.OpaqueInterfaceIdentifierOptions
+ iidOpts ipv6.OpaqueInterfaceIdentifierOptions
shouldGen bool
expectedAddr tcpip.Address
}{
@@ -2229,7 +2272,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "nic1",
autoGen: false,
linkAddr: linkAddr1,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:],
},
@@ -2274,7 +2317,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "nic1",
autoGen: true,
linkAddr: linkAddr1,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:],
},
@@ -2286,7 +2329,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
{
name: "OIID Empty MAC and empty nicName",
autoGen: true,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:1],
},
@@ -2298,7 +2341,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "test",
autoGen: true,
linkAddr: "\x01\x02\x03",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:2],
},
@@ -2310,7 +2353,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "test2",
autoGen: true,
linkAddr: "\x01\x02\x03\x04\x05\x06",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:3],
},
@@ -2322,7 +2365,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "test3",
autoGen: true,
linkAddr: "\x00\x00\x00\x00\x00\x00",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
},
shouldGen: true,
@@ -2336,10 +2379,11 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: test.autoGen,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: test.iidOpts,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
+ })},
}
e := channel.New(0, 1280, test.linkAddr)
@@ -2411,15 +2455,15 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
tests := []struct {
name string
- opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions
+ opaqueIIDOpts ipv6.OpaqueInterfaceIdentifierOptions
}{
{
name: "IID From MAC",
- opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{},
+ opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{},
},
{
name: "Opaque IID",
- opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(_ tcpip.NICID, nicName string) string {
return nicName
},
@@ -2430,9 +2474,10 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: true,
- OpaqueIIDOpts: test.opaqueIIDOpts,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
+ })},
}
e := loopback.New()
@@ -2461,12 +2506,13 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
- ndpConfigs := stack.DefaultNDPConfigurations()
+ ndpConfigs := ipv6.DefaultNDPConfigurations()
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: ndpConfigs,
- AutoGenIPv6LinkLocal: true,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ })},
}
e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
@@ -2522,7 +2568,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
for _, ps := range pebs {
t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
@@ -2813,14 +2859,15 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDispatcher{},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDispatcher{},
+ })},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -2869,7 +2916,7 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
e := loopback.New()
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
})
nicOpts := stack.NICOptions{Disabled: true}
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
@@ -2921,7 +2968,7 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
})
e := channel.New(10, 1280, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
@@ -2982,7 +3029,7 @@ func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e := loopback.New()
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
})
nicOpts := stack.NICOptions{Disabled: true}
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
@@ -3059,12 +3106,13 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
dadC: make(chan ndpDADEvent),
}
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ NDPDisp: &ndpDisp,
+ })},
}
e := channel.New(dadTransmits, 1280, linkAddr1)
@@ -3423,7 +3471,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
})
ep := channel.New(0, defaultMTU, "")
if err := s.CreateNIC(nicID1, ep); err != nil {
@@ -3461,7 +3509,7 @@ func TestResolveWith(t *testing.T) {
)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
})
ep := channel.New(0, defaultMTU, "")
ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
@@ -3499,3 +3547,131 @@ func TestResolveWith(t *testing.T) {
t.Fatal("got r.IsResolutionRequired() = true, want = false")
}
}
+
+// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its
+// associated address is removed should not cause a panic.
+func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
+ const (
+ nicID = 1
+ localAddr = tcpip.Address("\x01")
+ remoteAddr = tcpip.Address("\x02")
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ })
+
+ ep := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ r, err := s.FindRoute(nicID, localAddr, remoteAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, localAddr, remoteAddr, fakeNetNumber, err)
+ }
+ // Should not panic.
+ defer r.Release()
+
+ // Check that removing the same address fails.
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, localAddr, err)
+ }
+}
+
+func TestGetNetworkEndpoint(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ protoFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ }{
+ {
+ name: "IPv4",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ },
+ {
+ name: "IPv6",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ },
+ }
+
+ factories := make([]stack.NetworkProtocolFactory, 0, len(tests))
+ for _, test := range tests {
+ factories = append(factories, test.protoFactory)
+ }
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: factories,
+ })
+
+ if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep, err := s.GetNetworkEndpoint(nicID, test.protoNum)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.protoNum, err)
+ }
+
+ if got := ep.NetworkProtocolNumber(); got != test.protoNum {
+ t.Fatalf("got ep.NetworkProtocolNumber() = %d, want = %d", got, test.protoNum)
+ }
+ })
+ }
+}
+
+func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ })
+
+ if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ }
+
+ // Check that we get the right initial address and prefix length.
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ }
+
+ // Should still get the address when the NIC is diabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index b902c6ca9..35e5b1a2e 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -155,7 +155,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
- mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[r.nic.ID()]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isMulticastOrBroadcast(id.LocalAddress) {
+ if isInboundMulticastOrBroadcast(r) {
mpep.handlePacketAll(r, id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
@@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
@@ -544,9 +544,11 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
return true
}
- // If the packet is a TCP packet with a non-unicast source or destination
- // address, then do nothing further and instruct the caller to do the same.
- if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ // If the packet is a TCP packet with a unspecified source or non-unicast
+ // destination address, then do nothing further and instruct the caller to do
+ // the same. The network layer handles address validation for specified source
+ // addresses.
+ if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) {
// TCP can only be used to communicate between a single source and a
// single destination; the addresses must be unicast.
r.Stats().TCP.InvalidSegmentsReceived.Increment()
@@ -626,7 +628,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
epsByNIC.mu.RLock()
eps.mu.RUnlock()
- mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[r.nic.ID()]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -677,10 +679,10 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isMulticastOrBroadcast(addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+func isInboundMulticastOrBroadcast(r *Route) bool {
+ return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
}
-func isUnicast(addr tcpip.Address) bool {
- return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+func isSpecified(addr tcpip.Address) bool {
+ return addr != header.IPv4Any && addr != header.IPv6Any
}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 4d6d62eec..698c8609e 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -51,8 +51,8 @@ type testContext struct {
// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
linkEps := make(map[tcpip.NICID]*channel.Endpoint)
for _, linkEpID := range linkEpIDs {
@@ -182,8 +182,8 @@ func TestTransportDemuxerRegister(t *testing.T) {
} {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index a1458c899..62ab6d92f 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -39,7 +39,7 @@ const (
// use it.
type fakeTransportEndpoint struct {
stack.TransportEndpointInfo
- stack *stack.Stack
+
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
@@ -59,8 +59,8 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats {
func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
-func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+ return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
func (f *fakeTransportEndpoint) Abort() {
@@ -143,7 +143,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
// Find the route.
- r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
+ r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
return tcpip.ErrNoRoute
}
@@ -151,7 +151,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
+ err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -180,7 +180,7 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
return nil
}
-func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
if len(f.acceptQueue) == 0 {
return nil, nil, nil
}
@@ -190,7 +190,7 @@ func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.
}
func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
- if err := f.stack.RegisterTransportEndpoint(
+ if err := f.proto.stack.RegisterTransportEndpoint(
a.NIC,
[]tcpip.NetworkProtocolNumber{fakeNetNumber},
fakeTransNumber,
@@ -218,7 +218,6 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
f.proto.packetCount++
if f.acceptQueue != nil {
f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- stack: f.stack,
TransportEndpointInfo: stack.TransportEndpointInfo{
ID: f.ID,
NetProto: f.NetProto,
@@ -262,6 +261,8 @@ type fakeTransportProtocolOptions struct {
// fakeTransportProtocol is a transport-layer protocol descriptor. It
// aggregates the number of packets received via endpoints of this protocol.
type fakeTransportProtocol struct {
+ stack *stack.Stack
+
packetCount int
controlCount int
opts fakeTransportProtocolOptions
@@ -271,11 +272,11 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
return fakeTransNumber
}
-func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
+func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil
}
-func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol
}
@@ -287,26 +288,24 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
- return true
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ return stack.UnknownDestinationPacketHandled
}
-func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error {
+func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case fakeTransportGoodOption:
- f.opts.good = bool(v)
+ case *tcpip.TCPModerateReceiveBufferOption:
+ f.opts.good = bool(*v)
return nil
- case fakeTransportInvalidValueOption:
- return tcpip.ErrInvalidOptionValue
default:
return tcpip.ErrUnknownProtocolOption
}
}
-func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
+func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case *fakeTransportGoodOption:
- *v = fakeTransportGoodOption(f.opts.good)
+ case *tcpip.TCPModerateReceiveBufferOption:
+ *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good)
return nil
default:
return tcpip.ErrUnknownProtocolOption
@@ -328,15 +327,15 @@ func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
return ok
}
-func fakeTransFactory() stack.TransportProtocol {
- return &fakeTransportProtocol{}
+func fakeTransFactory(s *stack.Stack) stack.TransportProtocol {
+ return &fakeTransportProtocol{stack: s}
}
func TestTransportReceive(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
})
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -406,8 +405,8 @@ func TestTransportReceive(t *testing.T) {
func TestTransportControlReceive(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
})
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -483,8 +482,8 @@ func TestTransportControlReceive(t *testing.T) {
func TestTransportSend(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
})
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -529,54 +528,29 @@ func TestTransportSend(t *testing.T) {
func TestTransportOptions(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
})
- // Try an unsupported transport protocol.
- if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
- t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
- }
-
- testCases := []struct {
- option interface{}
- wantErr *tcpip.Error
- verifier func(t *testing.T, p stack.TransportProtocol)
- }{
- {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) {
- t.Helper()
- fakeTrans := p.(*fakeTransportProtocol)
- if fakeTrans.opts.good != true {
- t.Fatalf("fakeTrans.opts.good = false, want = true")
- }
- var v fakeTransportGoodOption
- if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
- t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err)
- }
- if v != true {
- t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v)
- }
-
- }},
- {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
- {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
- }
- for _, tc := range testCases {
- if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr {
- t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr)
- }
- if tc.verifier != nil {
- tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber))
- }
+ v := tcpip.TCPModerateReceiveBufferOption(true)
+ if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil {
+ t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err)
+ }
+ v = false
+ if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
+ t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err)
+ }
+ if !v {
+ t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true")
}
}
func TestTransportForwarding(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
// TODO(b/123449044): Change this to a channel NIC.
ep1 := loopback.New()
@@ -631,7 +605,7 @@ func TestTransportForwarding(t *testing.T) {
Data: req.ToVectorisedView(),
}))
- aep, _, err := ep.Accept()
+ aep, _, err := ep.Accept(nil)
if err != nil || aep == nil {
t.Fatalf("Accept failed: %v, %v", aep, err)
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 47a8d7c86..c42bb0991 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -237,6 +237,14 @@ type Timer interface {
// network node. Or, in the case of unix endpoints, it may represent a path.
type Address string
+// WithPrefix returns the address with a prefix that represents a point subnet.
+func (a Address) WithPrefix() AddressWithPrefix {
+ return AddressWithPrefix{
+ Address: a,
+ PrefixLen: len(a) * 8,
+ }
+}
+
// AddressMask is a bitmask for an address.
type AddressMask string
@@ -561,7 +569,10 @@ type Endpoint interface {
// block if no new connections are available.
//
// The returned Queue is the wait queue for the newly created endpoint.
- Accept() (Endpoint, *waiter.Queue, *Error)
+ //
+ // If peerAddr is not nil then it is populated with the peer address of the
+ // returned endpoint.
+ Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error)
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
@@ -861,12 +872,93 @@ func (*DefaultTTLOption) isGettableNetworkProtocolOption() {}
func (*DefaultTTLOption) isSettableNetworkProtocolOption() {}
-// AvailableCongestionControlOption is used to query the supported congestion
-// control algorithms.
-type AvailableCongestionControlOption string
+// GettableTransportProtocolOption is a marker interface for transport protocol
+// options that may be queried.
+type GettableTransportProtocolOption interface {
+ isGettableTransportProtocolOption()
+}
+
+// SettableTransportProtocolOption is a marker interface for transport protocol
+// options that may be set.
+type SettableTransportProtocolOption interface {
+ isSettableTransportProtocolOption()
+}
+
+// TCPSACKEnabled the SACK option for TCP.
+//
+// See: https://tools.ietf.org/html/rfc2018.
+type TCPSACKEnabled bool
+
+func (*TCPSACKEnabled) isGettableTransportProtocolOption() {}
+
+func (*TCPSACKEnabled) isSettableTransportProtocolOption() {}
+
+// TCPRecovery is the loss deteoction algorithm used by TCP.
+type TCPRecovery int32
+
+func (*TCPRecovery) isGettableTransportProtocolOption() {}
+
+func (*TCPRecovery) isSettableTransportProtocolOption() {}
+
+const (
+ // TCPRACKLossDetection indicates RACK is used for loss detection and
+ // recovery.
+ TCPRACKLossDetection TCPRecovery = 1 << iota
+
+ // TCPRACKStaticReoWnd indicates the reordering window should not be
+ // adjusted when DSACK is received.
+ TCPRACKStaticReoWnd
+
+ // TCPRACKNoDupTh indicates RACK should not consider the classic three
+ // duplicate acknowledgements rule to mark the segments as lost. This
+ // is used when reordering is not detected.
+ TCPRACKNoDupTh
+)
+
+// TCPDelayEnabled enables/disables Nagle's algorithm in TCP.
+type TCPDelayEnabled bool
+
+func (*TCPDelayEnabled) isGettableTransportProtocolOption() {}
+
+func (*TCPDelayEnabled) isSettableTransportProtocolOption() {}
+
+// TCPSendBufferSizeRangeOption is the send buffer size range for TCP.
+type TCPSendBufferSizeRangeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+func (*TCPSendBufferSizeRangeOption) isGettableTransportProtocolOption() {}
+
+func (*TCPSendBufferSizeRangeOption) isSettableTransportProtocolOption() {}
+
+// TCPReceiveBufferSizeRangeOption is the receive buffer size range for TCP.
+type TCPReceiveBufferSizeRangeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+func (*TCPReceiveBufferSizeRangeOption) isGettableTransportProtocolOption() {}
+
+func (*TCPReceiveBufferSizeRangeOption) isSettableTransportProtocolOption() {}
+
+// TCPAvailableCongestionControlOption is the supported congestion control
+// algorithms for TCP
+type TCPAvailableCongestionControlOption string
+
+func (*TCPAvailableCongestionControlOption) isGettableTransportProtocolOption() {}
+
+func (*TCPAvailableCongestionControlOption) isSettableTransportProtocolOption() {}
+
+// TCPModerateReceiveBufferOption enables/disables receive buffer moderation
+// for TCP.
+type TCPModerateReceiveBufferOption bool
+
+func (*TCPModerateReceiveBufferOption) isGettableTransportProtocolOption() {}
-// ModerateReceiveBufferOption is used by buffer moderation.
-type ModerateReceiveBufferOption bool
+func (*TCPModerateReceiveBufferOption) isSettableTransportProtocolOption() {}
// GettableSocketOption is a marker interface for socket options that may be
// queried.
@@ -932,6 +1024,10 @@ func (*CongestionControlOption) isGettableSocketOption() {}
func (*CongestionControlOption) isSettableSocketOption() {}
+func (*CongestionControlOption) isGettableTransportProtocolOption() {}
+
+func (*CongestionControlOption) isSettableTransportProtocolOption() {}
+
// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state
// before being marked closed.
@@ -941,6 +1037,10 @@ func (*TCPLingerTimeoutOption) isGettableSocketOption() {}
func (*TCPLingerTimeoutOption) isSettableSocketOption() {}
+func (*TCPLingerTimeoutOption) isGettableTransportProtocolOption() {}
+
+func (*TCPLingerTimeoutOption) isSettableTransportProtocolOption() {}
+
// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
// maximum duration for which a socket lingers in the TIME_WAIT state
// before being marked closed.
@@ -950,6 +1050,10 @@ func (*TCPTimeWaitTimeoutOption) isGettableSocketOption() {}
func (*TCPTimeWaitTimeoutOption) isSettableSocketOption() {}
+func (*TCPTimeWaitTimeoutOption) isGettableTransportProtocolOption() {}
+
+func (*TCPTimeWaitTimeoutOption) isSettableTransportProtocolOption() {}
+
// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a
// accept to return a completed connection only when there is data to be
// read. This usually means the listening socket will drop the final ACK
@@ -968,6 +1072,10 @@ func (*TCPMinRTOOption) isGettableSocketOption() {}
func (*TCPMinRTOOption) isSettableSocketOption() {}
+func (*TCPMinRTOOption) isGettableTransportProtocolOption() {}
+
+func (*TCPMinRTOOption) isSettableTransportProtocolOption() {}
+
// TCPMaxRTOOption is use by SetSockOpt/GetSockOpt to allow overriding
// default MaxRTO used by the Stack.
type TCPMaxRTOOption time.Duration
@@ -976,6 +1084,10 @@ func (*TCPMaxRTOOption) isGettableSocketOption() {}
func (*TCPMaxRTOOption) isSettableSocketOption() {}
+func (*TCPMaxRTOOption) isGettableTransportProtocolOption() {}
+
+func (*TCPMaxRTOOption) isSettableTransportProtocolOption() {}
+
// TCPMaxRetriesOption is used by SetSockOpt/GetSockOpt to set/get the
// maximum number of retransmits after which we time out the connection.
type TCPMaxRetriesOption uint64
@@ -984,6 +1096,10 @@ func (*TCPMaxRetriesOption) isGettableSocketOption() {}
func (*TCPMaxRetriesOption) isSettableSocketOption() {}
+func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {}
+
+func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {}
+
// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
// the number of endpoints that can be in SYN-RCVD state before the stack
// switches to using SYN cookies.
@@ -993,6 +1109,10 @@ func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {}
func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {}
+func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {}
+
+func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {}
+
// TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide
// default for number of times SYN is retransmitted before aborting a connect.
type TCPSynRetriesOption uint8
@@ -1001,6 +1121,10 @@ func (*TCPSynRetriesOption) isGettableSocketOption() {}
func (*TCPSynRetriesOption) isSettableSocketOption() {}
+func (*TCPSynRetriesOption) isGettableTransportProtocolOption() {}
+
+func (*TCPSynRetriesOption) isSettableTransportProtocolOption() {}
+
// MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a
// default interface for multicast.
type MulticastInterfaceOption struct {
@@ -1059,6 +1183,10 @@ func (*TCPTimeWaitReuseOption) isGettableSocketOption() {}
func (*TCPTimeWaitReuseOption) isSettableSocketOption() {}
+func (*TCPTimeWaitReuseOption) isGettableTransportProtocolOption() {}
+
+func (*TCPTimeWaitReuseOption) isSettableTransportProtocolOption() {}
+
const (
// TCPTimeWaitReuseDisabled indicates reuse of port bound by endponts in TIME-WAIT cannot
// be reused for new connections.
@@ -1354,6 +1482,18 @@ type IPStats struct {
// MalformedFragmentsReceived is the total number of IP Fragments that were
// dropped due to the fragment failing validation checks.
MalformedFragmentsReceived *StatCounter
+
+ // IPTablesPreroutingDropped is the total number of IP packets dropped
+ // in the Prerouting chain.
+ IPTablesPreroutingDropped *StatCounter
+
+ // IPTablesInputDropped is the total number of IP packets dropped in
+ // the Input chain.
+ IPTablesInputDropped *StatCounter
+
+ // IPTablesOutputDropped is the total number of IP packets dropped in
+ // the Output chain.
+ IPTablesOutputDropped *StatCounter
}
// TCPStats collects TCP-specific stats.
@@ -1482,9 +1622,6 @@ type UDPStats struct {
// ChecksumErrors is the number of datagrams dropped due to bad checksums.
ChecksumErrors *StatCounter
-
- // InvalidSourceAddress is the number of invalid sourced datagrams dropped.
- InvalidSourceAddress *StatCounter
}
// Stats holds statistics about the networking stack.
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 1b18023c5..e8caf09ba 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -16,6 +16,7 @@ package integration_test
import (
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -29,6 +30,69 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
+
+type ndpDispatcher struct{}
+
+func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+}
+
+func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
+ return false
+}
+
+func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {}
+
+func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
+ return false
+}
+
+func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {}
+
+func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
+ return true
+}
+
+func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {}
+
+func (*ndpDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {}
+
+func (*ndpDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {}
+
+func (*ndpDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {}
+
+func (*ndpDispatcher) OnDHCPv6Configuration(tcpip.NICID, ipv6.DHCPv6ConfigurationFromNDPRA) {}
+
+// TestInitialLoopbackAddresses tests that the loopback interface does not
+// auto-generate a link-local address when it is brought up.
+func TestInitialLoopbackAddresses(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDispatcher{},
+ AutoGenIPv6LinkLocal: true,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(nicID tcpip.NICID, nicName string) string {
+ t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName)
+ return ""
+ },
+ },
+ })},
+ })
+
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ nicsInfo := s.NICInfo()
+ if nicInfo, ok := nicsInfo[nicID]; !ok {
+ t.Fatalf("did not find NIC with ID = %d in s.NICInfo() = %#v", nicID, nicsInfo)
+ } else if got := len(nicInfo.ProtocolAddresses); got != 0 {
+ t.Fatalf("got len(nicInfo.ProtocolAddresses) = %d, want = 0; nicInfo.ProtocolAddresses = %#v", got, nicInfo.ProtocolAddresses)
+ }
+}
+
// TestLoopbackAcceptAllInSubnet tests that a loopback interface considers
// itself bound to all addresses in the subnet of an assigned address.
func TestLoopbackAcceptAllInSubnet(t *testing.T) {
@@ -120,8 +184,8 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
@@ -187,3 +251,64 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
})
}
}
+
+// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address
+// in a loopback interface's associated subnet is bound to the permanently bound
+// address.
+func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
+ const nicID = 1
+
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ }
+ addrBytes := []byte(ipv4Addr.Address)
+ addrBytes[len(addrBytes)-1]++
+ otherAddr := tcpip.Address(addrBytes)
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ r, err := s.FindRoute(nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ params := stack.NetworkHeaderParams{
+ Protocol: 111,
+ TTL: 64,
+ TOS: stack.DefaultTOS,
+ }
+ data := buffer.View([]byte{1, 2, 3, 4})
+ if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ })); err != nil {
+ t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err)
+ }
+
+ // Removing the address should make the endpoint invalid.
+ if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err)
+ }
+ if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ })); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState)
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 52c27e045..72d86b5ab 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -139,11 +140,9 @@ func TestPingMulticastBroadcast(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ipv4Proto := ipv4.NewProtocol()
- ipv6Proto := ipv6.NewProtocol()
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4Proto, ipv6Proto},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4(), icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
})
// We only expect a single packet in response to our ICMP Echo Request.
e := channel.New(1, defaultMTU, "")
@@ -175,18 +174,18 @@ func TestPingMulticastBroadcast(t *testing.T) {
var rxICMP func(*channel.Endpoint, tcpip.Address)
var expectedSrc tcpip.Address
var expectedDst tcpip.Address
- var proto stack.NetworkProtocol
+ var protoNum tcpip.NetworkProtocolNumber
switch l := len(test.dstAddr); l {
case header.IPv4AddressSize:
rxICMP = rxIPv4ICMP
expectedSrc = ipv4Addr.Address
expectedDst = remoteIPv4Addr
- proto = ipv4Proto
+ protoNum = header.IPv4ProtocolNumber
case header.IPv6AddressSize:
rxICMP = rxIPv6ICMP
expectedSrc = ipv6Addr.Address
expectedDst = remoteIPv6Addr
- proto = ipv6Proto
+ protoNum = header.IPv6ProtocolNumber
default:
t.Fatalf("got unexpected address length = %d bytes", l)
}
@@ -204,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst)
}
- src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View())
+ src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View())
if src != expectedSrc {
t.Errorf("got pkt source = %s, want = %s", src, expectedSrc)
}
@@ -379,8 +378,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
e := channel.New(0, defaultMTU, "")
if err := s.CreateNIC(nicID, e); err != nil {
@@ -436,3 +435,122 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
})
}
}
+
+// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all
+// interested endpoints.
+func TestReuseAddrAndBroadcast(t *testing.T) {
+ const (
+ nicID = 1
+ localPort = 9000
+ loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff")
+ )
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+
+ tests := []struct {
+ name string
+ broadcastAddr tcpip.Address
+ }{
+ {
+ name: "Subnet directed broadcast",
+ broadcastAddr: loopbackBroadcast,
+ },
+ {
+ name: "IPv4 broadcast",
+ broadcastAddr: header.IPv4Broadcast,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x7f\x00\x00\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ // We use the empty subnet instead of just the loopback subnet so we
+ // also have a route to the IPv4 Broadcast address.
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // We create endpoints that bind to both the wildcard address and the
+ // broadcast address to make sure both of these types of "broadcast
+ // interested" endpoints receive broadcast packets.
+ wq := waiter.Queue{}
+ var eps []tcpip.Endpoint
+ for _, bindWildcard := range []bool{false, true} {
+ // Create multiple endpoints for each type of "broadcast interested"
+ // endpoint so we can test that all endpoints receive the broadcast
+ // packet.
+ for i := 0; i < 2; i++ {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err)
+ }
+
+ bindAddr := tcpip.FullAddress{Port: localPort}
+ if bindWildcard {
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ } else {
+ bindAddr.Addr = test.broadcastAddr
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ }
+
+ eps = append(eps, ep)
+ }
+ }
+
+ for i, wep := range eps {
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{
+ Addr: test.broadcastAddr,
+ Port: localPort,
+ },
+ }
+ if n, _, err := wep.Write(data, writeOpts); err != nil {
+ t.Fatalf("eps[%d].Write(_, _): %s", i, err)
+ } else if want := int64(len(data)); n != want {
+ t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want)
+ }
+
+ for j, rep := range eps {
+ if gotPayload, _, err := rep.Read(nil); err != nil {
+ t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
+ } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 346ca4bda..41eb0ca44 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -74,6 +74,8 @@ type endpoint struct {
route stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -344,9 +346,14 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch opt.(type) {
+ switch v := opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
+
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
}
return nil
}
@@ -415,8 +422,17 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ *o = e.linger
+ e.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
@@ -430,6 +446,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
pkt.Owner = owner
icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
copy(icmpv4, data)
// Set the ident to the user-specified port. Sequence number should
// already be set by the user.
@@ -462,6 +479,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
})
icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
copy(icmpv6, data)
// Set the ident. Sequence number is provided by the user.
icmpv6.SetIdent(ident)
@@ -597,7 +615,7 @@ func (*endpoint) Listen(int) *tcpip.Error {
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 74ef6541e..87d510f96 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -13,12 +13,7 @@
// limitations under the License.
// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
-// protocols for use in ping. To use it in the networking stack, this package
-// must be added to the project, and activated on the stack by passing
-// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport
-// protocols when calling stack.New(). Then endpoints can be created by passing
-// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
-// when calling Stack.NewEndpoint().
+// protocols for use in ping.
package icmp
import (
@@ -42,6 +37,8 @@ const (
// protocol implements stack.TransportProtocol.
type protocol struct {
+ stack *stack.Stack
+
number tcpip.TransportProtocolNumber
}
@@ -62,20 +59,20 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
// NewEndpoint creates a new icmp endpoint. It implements
// stack.TransportProtocol.NewEndpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return newEndpoint(stack, netProto, p.number, waiterQueue)
+ return newEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// NewRawEndpoint creates a new raw icmp endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return raw.NewEndpoint(stack, netProto, p.number, waiterQueue)
+ return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// MinimumPacketSize returns the minimum valid icmp packet size.
@@ -104,17 +101,17 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
- return true
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ return stack.UnknownDestinationPacketHandled
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(option interface{}) *tcpip.Error {
+func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(option interface{}) *tcpip.Error {
+func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
@@ -135,11 +132,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
}
// NewProtocol4 returns an ICMPv4 transport protocol.
-func NewProtocol4() stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
+func NewProtocol4(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber4}
}
// NewProtocol6 returns an ICMPv6 transport protocol.
-func NewProtocol6() stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
+func NewProtocol6(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 81093e9ca..072601d2d 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -83,6 +83,8 @@ type endpoint struct {
stats tcpip.TransportEndpointStats `state:"nosave"`
bound bool
boundNIC tcpip.NICID
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
// lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
@@ -192,13 +194,13 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
return ep.ReadPacket(addr, nil)
}
-func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// TODO(gvisor.dev/issue/173): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
}
// Peek implements tcpip.Endpoint.Peek.
-func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -210,25 +212,25 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
// connected, and this function always returnes tcpip.ErrNotSupported.
-func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
// with Shutdown, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
// Listen, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+func (*endpoint) Listen(backlog int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
// Accept, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
@@ -267,12 +269,12 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotSupported
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// Even a connected socket doesn't return a remote address.
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -298,10 +300,16 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch opt.(type) {
+ switch v := opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
+ case *tcpip.LingerOption:
+ ep.mu.Lock()
+ ep.linger = *v
+ ep.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -366,12 +374,21 @@ func (ep *endpoint) LastError() *tcpip.Error {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ ep.mu.Lock()
+ *o = ep.linger
+ ep.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrNotSupported
+ }
}
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return false, tcpip.ErrNotSupported
}
@@ -508,7 +525,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
}
// State implements socket.Socket.State.
-func (ep *endpoint) State() uint32 {
+func (*endpoint) State() uint32 {
return 0
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 71feeb748..e37c00523 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -84,6 +84,8 @@ type endpoint struct {
// Connect(), and is valid only when conneted is true.
route stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -446,12 +448,12 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Listen implements tcpip.Endpoint.Listen.
-func (e *endpoint) Listen(backlog int) *tcpip.Error {
+func (*endpoint) Listen(backlog int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept implements tcpip.Endpoint.Accept.
-func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
@@ -482,12 +484,12 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotSupported
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// Even a connected socket doesn't return a remote address.
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -511,10 +513,16 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch opt.(type) {
+ switch v := opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -577,8 +585,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ *o = e.linger
+ e.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 234fb95ce..518449602 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -69,6 +69,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
@@ -93,6 +94,7 @@ go_test(
shard_count = 10,
deps = [
":tcp",
+ "//pkg/rand",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 72df5c2a1..6891fd245 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -522,7 +522,7 @@ func (h *handshake) execute() *tcpip.Error {
s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
defer s.Done()
- var sackEnabled SACKEnabled
+ var sackEnabled tcpip.TCPSACKEnabled
if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
// If stack returned an error when checking for SACKEnabled
// status then just default to switching off SACK negotiation.
@@ -747,6 +747,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV
func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
optLen := len(tf.opts)
tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen))
+ pkt.TransportProtocolNumber = header.TCPProtocolNumber
tcp.Encode(&header.TCPFields{
SrcPort: tf.id.LocalPort,
DstPort: tf.id.RemotePort,
@@ -897,7 +898,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
@@ -1002,9 +1003,8 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// (indicated by a negative send window scale).
e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
// Bootstrap the auto tuning algorithm. Starting at zero will
// result in a really large receive window after the first auto
// tuning adjustment.
@@ -1135,12 +1135,11 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
}
cont, err := e.handleSegment(s)
+ s.decRef()
if err != nil {
- s.decRef()
return err
}
if !cont {
- s.decRef()
return nil
}
}
@@ -1220,6 +1219,12 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
return true, nil
}
+ // Increase counter if after processing the segment we would potentially
+ // advertise a zero window.
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above {
+ e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
+ }
+
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
@@ -1232,7 +1237,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// or a notification from the protocolMainLoop (caller goroutine).
// This means that with this return, the segment dequeue below can
// never occur on a closed endpoint.
- s.decRef()
return false, nil
}
@@ -1424,10 +1428,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.rcv.nonZeroWindow()
}
- if n&notifyReceiveWindowChanged != 0 {
- e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
- }
-
if n&notifyMTUChanged != 0 {
e.sndBufMu.Lock()
count := e.packetTooBigCount
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 6074cc24e..560b4904c 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -78,8 +78,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network
ackCheckers := append(checkers, checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
))
checker.IPv4(t, c.GetPacket(), ackCheckers...)
@@ -185,8 +185,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network
ackCheckers := append(checkers, checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
))
checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
@@ -283,7 +283,7 @@ func TestV4RefuseOnV6Only(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
}
@@ -319,7 +319,7 @@ func TestV6RefuseOnBoundToV4Mapped(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
}
@@ -352,7 +352,7 @@ func testV4Accept(t *testing.T, c *context.Context) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
@@ -371,12 +371,12 @@ func testV4Accept(t *testing.T, c *context.Context) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- nep, _, err := c.EP.Accept()
+ nep, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -492,7 +492,7 @@ func TestV6AcceptOnV6(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
@@ -510,13 +510,13 @@ func TestV6AcceptOnV6(t *testing.T) {
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
-
- nep, _, err := c.EP.Accept()
+ var addr tcpip.FullAddress
+ nep, _, err := c.EP.Accept(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(&addr)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -526,20 +526,14 @@ func TestV6AcceptOnV6(t *testing.T) {
}
}
+ if addr.Addr != context.TestV6Addr {
+ t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr)
+ }
+
// Make sure we can still query the v6 only status of the new endpoint,
// that is, that it is in fact a v6 socket.
if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
- t.Fatalf("GetSockOpt failed failed: %v", err)
- }
-
- // Check the peer address.
- addr, err := nep.GetRemoteAddress()
- if err != nil {
- t.Fatalf("GetRemoteAddress failed failed: %v", err)
- }
-
- if addr.Addr != context.TestV6Addr {
- t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr)
+ t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err)
}
}
@@ -566,8 +560,9 @@ func TestV4AcceptOnV4(t *testing.T) {
func testV4ListenClose(t *testing.T, c *context.Context) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
const n = uint16(32)
@@ -610,12 +605,12 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- nep, _, err := c.EP.Accept()
+ nep, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index c5d9eba5d..ae817091a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -63,6 +63,17 @@ const (
StateClosing
)
+const (
+ // rcvAdvWndScale is used to split the available socket buffer into
+ // application buffer and the window to be advertised to the peer. This is
+ // currently hard coded to split the available space equally.
+ rcvAdvWndScale = 1
+
+ // SegOverheadFactor is used to multiply the value provided by the
+ // user on a SetSockOpt for setting the socket send/receive buffer sizes.
+ SegOverheadFactor = 2
+)
+
// connected returns true when s is one of the states representing an
// endpoint connected to a peer.
func (s EndpointState) connected() bool {
@@ -149,7 +160,6 @@ func (s EndpointState) String() string {
// Reasons for notifying the protocol goroutine.
const (
notifyNonZeroReceiveWindow = 1 << iota
- notifyReceiveWindowChanged
notifyClose
notifyMTUChanged
notifyDrain
@@ -384,13 +394,26 @@ type endpoint struct {
// to indicate to users that no more data is coming.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
- rcvList segmentList `state:"wait"`
- rcvClosed bool
- rcvBufSize int
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
+ rcvClosed bool
+ // rcvBufSize is the total size of the receive buffer.
+ rcvBufSize int
+ // rcvBufUsed is the actual number of payload bytes held in the receive buffer
+ // not counting any overheads of the segments itself. NOTE: This will always
+ // be strictly <= rcvMemUsed below.
rcvBufUsed int
rcvAutoParams rcvBufAutoTuneParams
+ // rcvMemUsed tracks the total amount of memory in use by received segments
+ // held in rcvList, pendingRcvdSegments and the segment queue. This is used to
+ // compute the window and the actual available buffer space. This is distinct
+ // from rcvBufUsed above which is the actual number of payload bytes held in
+ // the buffer not including any segment overheads.
+ //
+ // rcvMemUsed must be accessed atomically.
+ rcvMemUsed int32
+
// mu protects all endpoint fields unless documented otherwise. mu must
// be acquired before interacting with the endpoint fields.
mu sync.Mutex `state:"nosave"`
@@ -852,12 +875,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
maxSynRetries: DefaultSynRetries,
}
- var ss SendBufferSizeOption
+ var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
e.sndBufSize = ss.Default
}
- var rs ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
e.rcvBufSize = rs.Default
}
@@ -867,12 +890,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.cc = cs
}
- var mrb tcpip.ModerateReceiveBufferOption
+ var mrb tcpip.TCPModerateReceiveBufferOption
if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
e.rcvAutoParams.disabled = !bool(mrb)
}
- var de DelayEnabled
+ var de tcpip.TCPDelayEnabled
if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
e.SetSockOptBool(tcpip.DelayOption, true)
}
@@ -891,7 +914,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.probe = p
}
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.segmentQueue.ep = e
e.tsOffset = timeStampOffset()
e.acceptCond = sync.NewCond(&e.acceptMu)
@@ -904,7 +927,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result := waiter.EventMask(0)
switch e.EndpointState() {
- case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
+ case StateInitial, StateBound:
+ // This prevents blocking of new sockets which are not
+ // connected when SO_LINGER is set.
+ result |= waiter.EventHUp
+
+ case StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
case StateClose, StateError, StateTimeWait:
@@ -1075,6 +1103,8 @@ func (e *endpoint) closeNoShutdownLocked() {
e.notifyProtocolGoroutine(notifyClose)
} else {
e.transitionToStateCloseLocked()
+ // Notify that the endpoint is closed.
+ e.waiterQueue.Notify(waiter.EventHUp)
}
}
@@ -1129,10 +1159,16 @@ func (e *endpoint) cleanupLocked() {
tcpip.DeleteDanglingEndpoint(e)
}
+// wndFromSpace returns the window that we can advertise based on the available
+// receive buffer space.
+func wndFromSpace(space int) int {
+ return space / (1 << rcvAdvWndScale)
+}
+
// initialReceiveWindow returns the initial receive window to advertise in the
// SYN/SYN-ACK.
func (e *endpoint) initialReceiveWindow() int {
- rcvWnd := e.receiveBufferAvailable()
+ rcvWnd := wndFromSpace(e.receiveBufferAvailable())
if rcvWnd > math.MaxUint16 {
rcvWnd = math.MaxUint16
}
@@ -1209,14 +1245,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// reject valid data that might already be in flight as the
// acceptable window will shrink.
if rcvWnd > e.rcvBufSize {
- availBefore := e.receiveBufferAvailableLocked()
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvBufSize = rcvWnd
- availAfter := e.receiveBufferAvailableLocked()
- mask := uint32(notifyReceiveWindowChanged)
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- mask |= notifyNonZeroReceiveWindow
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
- e.notifyProtocolGoroutine(mask)
}
// We only update prevCopied when we grow the buffer because in cases
@@ -1293,18 +1327,22 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
v := views[s.viewToDeliver]
s.viewToDeliver++
+ var delta int
if s.viewToDeliver >= len(views) {
e.rcvList.Remove(s)
+ // We only free up receive buffer space when the segment is released as the
+ // segment is still holding on to the views even though some views have been
+ // read out to the user.
+ delta = s.segMemSize()
s.decRef()
}
e.rcvBufUsed -= len(v)
-
// If the window was small before this read and if the read freed up
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1317,14 +1355,17 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
- // The endpoint cannot be written to if it's not connected.
- if !e.EndpointState().connected() {
- switch e.EndpointState() {
- case StateError:
- return 0, e.HardError
- default:
- return 0, tcpip.ErrClosedForSend
- }
+ switch s := e.EndpointState(); {
+ case s == StateError:
+ return 0, e.HardError
+ case !s.connecting() && !s.connected():
+ return 0, tcpip.ErrClosedForSend
+ case s.connecting():
+ // As per RFC793, page 56, a send request arriving when in connecting
+ // state, can be queued to be completed after the state becomes
+ // connected. Return an error code for the caller of endpoint Write to
+ // try again, until the connection handshake is complete.
+ return 0, tcpip.ErrWouldBlock
}
// Check if the connection has already been closed for sends.
@@ -1478,11 +1519,11 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
}
// windowCrossedACKThresholdLocked checks if the receive window to be announced
-// now would be under aMSS or under half receive buffer, whichever smaller. This
-// is useful as a receive side silly window syndrome prevention mechanism. If
-// window grows to reasonable value, we should send ACK to the sender to inform
-// the rx space is now large. We also want ensure a series of small read()'s
-// won't trigger a flood of spurious tiny ACK's.
+// would be under aMSS or under the window derived from half receive buffer,
+// whichever smaller. This is useful as a receive side silly window syndrome
+// prevention mechanism. If window grows to reasonable value, we should send ACK
+// to the sender to inform the rx space is now large. We also want ensure a
+// series of small read()'s won't trigger a flood of spurious tiny ACK's.
//
// For large receive buffers, the threshold is aMSS - once reader reads more
// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of
@@ -1493,17 +1534,18 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
//
// Precondition: e.mu and e.rcvListMu must be held.
func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
- newAvail := e.receiveBufferAvailableLocked()
+ newAvail := wndFromSpace(e.receiveBufferAvailableLocked())
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
oldAvail = 0
}
-
threshold := int(e.amss)
- if threshold > e.rcvBufSize/2 {
- threshold = e.rcvBufSize / 2
+ // rcvBufFraction is the inverse of the fraction of receive buffer size that
+ // is used to decide if the available buffer space is now above it.
+ const rcvBufFraction = 2
+ if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold {
+ threshold = wndThreshold
}
-
switch {
case oldAvail < threshold && newAvail >= threshold:
return true, true
@@ -1632,18 +1674,24 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
- var rs ReceiveBufferSizeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err))
+ }
+
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ if v < math.MaxInt32/SegOverheadFactor {
+ v *= SegOverheadFactor
if v < rs.Min {
v = rs.Min
}
- if v > rs.Max {
- v = rs.Max
- }
+ } else {
+ v = math.MaxInt32
}
- mask := uint32(notifyReceiveWindowChanged)
-
e.LockUser()
e.rcvListMu.Lock()
@@ -1657,14 +1705,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
v = 1 << scale
}
- // Make sure 2*size doesn't overflow.
- if v > math.MaxInt32/2 {
- v = math.MaxInt32 / 2
- }
-
- availBefore := e.receiveBufferAvailableLocked()
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvBufSize = v
- availAfter := e.receiveBufferAvailableLocked()
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvAutoParams.disabled = true
@@ -1672,24 +1715,31 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// syndrome prevetion, when our available space grows above aMSS
// or half receive buffer, whichever smaller.
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- mask |= notifyNonZeroReceiveWindow
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
e.rcvListMu.Unlock()
e.UnlockUser()
- e.notifyProtocolGoroutine(mask)
case tcpip.SendBufferSizeOption:
// Make sure the send buffer size is within the min and max
// allowed.
- var ss SendBufferSizeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ var ss tcpip.TCPSendBufferSizeRangeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err))
+ }
+
+ if v > ss.Max {
+ v = ss.Max
+ }
+
+ if v < math.MaxInt32/SegOverheadFactor {
+ v *= SegOverheadFactor
if v < ss.Min {
v = ss.Min
}
- if v > ss.Max {
- v = ss.Max
- }
+ } else {
+ v = math.MaxInt32
}
e.sndBufMu.Lock()
@@ -1722,7 +1772,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return tcpip.ErrInvalidOptionValue
}
}
- var rs ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
if v < rs.Min/2 {
v = rs.Min / 2
@@ -1771,7 +1821,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
// Query the available cc algorithms in the stack and
// validate that the specified algorithm is actually
// supported in the stack.
- var avail tcpip.AvailableCongestionControlOption
+ var avail tcpip.TCPAvailableCongestionControlOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil {
return err
}
@@ -2047,8 +2097,10 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
e.UnlockUser()
case *tcpip.OriginalDestinationOption:
+ e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.ID)
+ addr, port, err := ipt.OriginalDst(e.ID, e.NetProto)
+ e.UnlockUser()
if err != nil {
return err
}
@@ -2486,7 +2538,9 @@ func (e *endpoint) startAcceptedLoop() {
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode.
-func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+//
+// addr if not-nil will contain the peer address of the returned endpoint.
+func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -2508,6 +2562,9 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
default:
return nil, nil, tcpip.ErrWouldBlock
}
+ if peerAddr != nil {
+ *peerAddr = n.getRemoteAddress()
+ }
return n, n.waiterQueue, nil
}
@@ -2610,11 +2667,15 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
+ return e.getRemoteAddress(), nil
+}
+
+func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
return tcpip.FullAddress{
Addr: e.ID.RemoteAddress,
Port: e.ID.RemotePort,
NIC: e.boundNICID,
- }, nil
+ }
}
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
@@ -2685,13 +2746,8 @@ func (e *endpoint) updateSndBufferUsage(v int) {
func (e *endpoint) readyToRead(s *segment) {
e.rcvListMu.Lock()
if s != nil {
+ e.rcvBufUsed += s.payloadSize()
s.incRef()
- e.rcvBufUsed += s.data.Size()
- // Increase counter if the receive window falls down below MSS
- // or half receive buffer size, whichever smaller.
- if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
- e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
- }
e.rcvList.PushBack(s)
} else {
e.rcvClosed = true
@@ -2706,15 +2762,17 @@ func (e *endpoint) readyToRead(s *segment) {
func (e *endpoint) receiveBufferAvailableLocked() int {
// We may use more bytes than the buffer size when the receive buffer
// shrinks.
- if e.rcvBufUsed >= e.rcvBufSize {
+ memUsed := e.receiveMemUsed()
+ if memUsed >= e.rcvBufSize {
return 0
}
- return e.rcvBufSize - e.rcvBufUsed
+ return e.rcvBufSize - memUsed
}
// receiveBufferAvailable calculates how many bytes are still available in the
-// receive buffer.
+// receive buffer based on the actual memory used by all segments held in
+// receive buffer/pending and segment queue.
func (e *endpoint) receiveBufferAvailable() int {
e.rcvListMu.Lock()
available := e.receiveBufferAvailableLocked()
@@ -2722,16 +2780,37 @@ func (e *endpoint) receiveBufferAvailable() int {
return available
}
+// receiveBufferUsed returns the amount of in-use receive buffer.
+func (e *endpoint) receiveBufferUsed() int {
+ e.rcvListMu.Lock()
+ used := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+ return used
+}
+
+// receiveBufferSize returns the current size of the receive buffer.
func (e *endpoint) receiveBufferSize() int {
e.rcvListMu.Lock()
size := e.rcvBufSize
e.rcvListMu.Unlock()
-
return size
}
+// receiveMemUsed returns the total memory in use by segments held by this
+// endpoint.
+func (e *endpoint) receiveMemUsed() int {
+ return int(atomic.LoadInt32(&e.rcvMemUsed))
+}
+
+// updateReceiveMemUsed adds the provided delta to e.rcvMemUsed.
+func (e *endpoint) updateReceiveMemUsed(delta int) {
+ atomic.AddInt32(&e.rcvMemUsed, int32(delta))
+}
+
+// maxReceiveBufferSize returns the stack wide maximum receive buffer size for
+// an endpoint.
func (e *endpoint) maxReceiveBufferSize() int {
- var rs ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
// As a fallback return the hardcoded max buffer size.
return MaxBufferSize
@@ -2811,7 +2890,7 @@ func timeStampOffset() uint32 {
// if the SYN options indicate that the SACK option was negotiated and the TCP
// stack is configured to enable TCP SACK option.
func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
- var v SACKEnabled
+ var v tcpip.TCPSACKEnabled
if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
return
@@ -2880,7 +2959,6 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
RcvAcc: e.rcv.rcvAcc,
RcvWndScale: e.rcv.rcvWndScale,
PendingBufUsed: e.rcv.pendingBufUsed,
- PendingBufSize: e.rcv.pendingBufSize,
}
// Copy sender state.
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 723e47ddc..b25431467 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -44,7 +44,7 @@ func (e *endpoint) drainSegmentLocked() {
// beforeSave is invoked by stateify.
func (e *endpoint) beforeSave() {
// Stop incoming packets.
- e.segmentQueue.setLimit(0)
+ e.segmentQueue.freeze()
e.mu.Lock()
defer e.mu.Unlock()
@@ -178,18 +178,18 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.segmentQueue.thaw()
epState := e.origEndpointState
switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
- var ss SendBufferSizeOption
+ var ss tcpip.TCPSendBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
}
}
- var rs ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index c5afa2680..5bce73605 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -12,12 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package tcp contains the implementation of the TCP transport protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing tcp.NewProtocol() as one of the
-// transport protocols when calling stack.New(). Then endpoints can be created
-// by passing tcp.ProtocolNumber as the transport protocol number when calling
-// Stack.NewEndpoint().
+// Package tcp contains the implementation of the TCP transport protocol.
package tcp
import (
@@ -29,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
@@ -79,50 +75,6 @@ const (
ccCubic = "cubic"
)
-// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to
-// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018.
-type SACKEnabled bool
-
-// Recovery is used by stack.(*Stack).TransportProtocolOption to
-// set loss detection algorithm in TCP.
-type Recovery int32
-
-const (
- // RACKLossDetection indicates RACK is used for loss detection and
- // recovery.
- RACKLossDetection Recovery = 1 << iota
-
- // RACKStaticReoWnd indicates the reordering window should not be
- // adjusted when DSACK is received.
- RACKStaticReoWnd
-
- // RACKNoDupTh indicates RACK should not consider the classic three
- // duplicate acknowledgements rule to mark the segments as lost. This
- // is used when reordering is not detected.
- RACKNoDupTh
-)
-
-// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to
-// enable/disable Nagle's algorithm in TCP.
-type DelayEnabled bool
-
-// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption
-// to get/set the default, min and max TCP send buffer sizes.
-type SendBufferSizeOption struct {
- Min int
- Default int
- Max int
-}
-
-// ReceiveBufferSizeOption is used by
-// stack.(Stack*).TransportProtocolOption to get/set the default, min and max
-// TCP receive buffer sizes.
-type ReceiveBufferSizeOption struct {
- Min int
- Default int
- Max int
-}
-
// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
// value is protected by a mutex so that we can increment only when it's
// guaranteed not to go above a threshold.
@@ -181,12 +133,14 @@ func (s *synRcvdCounter) Threshold() uint64 {
}
type protocol struct {
+ stack *stack.Stack
+
mu sync.RWMutex
sackEnabled bool
- recovery Recovery
+ recovery tcpip.TCPRecovery
delayEnabled bool
- sendBufferSize SendBufferSizeOption
- recvBufferSize ReceiveBufferSizeOption
+ sendBufferSize tcpip.TCPSendBufferSizeRangeOption
+ recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
@@ -207,14 +161,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid tcp packet size.
@@ -244,21 +198,20 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
s := newSegment(r, id, pkt)
defer s.decRef()
if !s.parse() || !s.csumValid {
- return false
+ return stack.UnknownDestinationPacketMalformed
}
- // There's nothing to do if this is already a reset packet.
- if s.flagIsSet(header.TCPFlagRst) {
- return true
+ if !s.flagIsSet(header.TCPFlagRst) {
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
}
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
- return true
+ return stack.UnknownDestinationPacketHandled
}
// replyWithReset replies to the given segment with a reset segment.
@@ -296,49 +249,49 @@ func replyWithReset(s *segment, tos, ttl uint8) {
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case SACKEnabled:
+ case *tcpip.TCPSACKEnabled:
p.mu.Lock()
- p.sackEnabled = bool(v)
+ p.sackEnabled = bool(*v)
p.mu.Unlock()
return nil
- case Recovery:
+ case *tcpip.TCPRecovery:
p.mu.Lock()
- p.recovery = Recovery(v)
+ p.recovery = *v
p.mu.Unlock()
return nil
- case DelayEnabled:
+ case *tcpip.TCPDelayEnabled:
p.mu.Lock()
- p.delayEnabled = bool(v)
+ p.delayEnabled = bool(*v)
p.mu.Unlock()
return nil
- case SendBufferSizeOption:
+ case *tcpip.TCPSendBufferSizeRangeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.sendBufferSize = v
+ p.sendBufferSize = *v
p.mu.Unlock()
return nil
- case ReceiveBufferSizeOption:
+ case *tcpip.TCPReceiveBufferSizeRangeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.recvBufferSize = v
+ p.recvBufferSize = *v
p.mu.Unlock()
return nil
- case tcpip.CongestionControlOption:
+ case *tcpip.CongestionControlOption:
for _, c := range p.availableCongestionControl {
- if string(v) == c {
+ if string(*v) == c {
p.mu.Lock()
- p.congestionControl = string(v)
+ p.congestionControl = string(*v)
p.mu.Unlock()
return nil
}
@@ -347,75 +300,79 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
// is specified.
return tcpip.ErrNoSuchFile
- case tcpip.ModerateReceiveBufferOption:
+ case *tcpip.TCPModerateReceiveBufferOption:
p.mu.Lock()
- p.moderateReceiveBuffer = bool(v)
+ p.moderateReceiveBuffer = bool(*v)
p.mu.Unlock()
return nil
- case tcpip.TCPLingerTimeoutOption:
- if v < 0 {
- v = 0
- }
+ case *tcpip.TCPLingerTimeoutOption:
p.mu.Lock()
- p.lingerTimeout = time.Duration(v)
+ if *v < 0 {
+ p.lingerTimeout = 0
+ } else {
+ p.lingerTimeout = time.Duration(*v)
+ }
p.mu.Unlock()
return nil
- case tcpip.TCPTimeWaitTimeoutOption:
- if v < 0 {
- v = 0
- }
+ case *tcpip.TCPTimeWaitTimeoutOption:
p.mu.Lock()
- p.timeWaitTimeout = time.Duration(v)
+ if *v < 0 {
+ p.timeWaitTimeout = 0
+ } else {
+ p.timeWaitTimeout = time.Duration(*v)
+ }
p.mu.Unlock()
return nil
- case tcpip.TCPTimeWaitReuseOption:
- if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly {
+ case *tcpip.TCPTimeWaitReuseOption:
+ if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.timeWaitReuse = v
+ p.timeWaitReuse = *v
p.mu.Unlock()
return nil
- case tcpip.TCPMinRTOOption:
- if v < 0 {
- v = tcpip.TCPMinRTOOption(MinRTO)
- }
+ case *tcpip.TCPMinRTOOption:
p.mu.Lock()
- p.minRTO = time.Duration(v)
+ if *v < 0 {
+ p.minRTO = MinRTO
+ } else {
+ p.minRTO = time.Duration(*v)
+ }
p.mu.Unlock()
return nil
- case tcpip.TCPMaxRTOOption:
- if v < 0 {
- v = tcpip.TCPMaxRTOOption(MaxRTO)
- }
+ case *tcpip.TCPMaxRTOOption:
p.mu.Lock()
- p.maxRTO = time.Duration(v)
+ if *v < 0 {
+ p.maxRTO = MaxRTO
+ } else {
+ p.maxRTO = time.Duration(*v)
+ }
p.mu.Unlock()
return nil
- case tcpip.TCPMaxRetriesOption:
+ case *tcpip.TCPMaxRetriesOption:
p.mu.Lock()
- p.maxRetries = uint32(v)
+ p.maxRetries = uint32(*v)
p.mu.Unlock()
return nil
- case tcpip.TCPSynRcvdCountThresholdOption:
+ case *tcpip.TCPSynRcvdCountThresholdOption:
p.mu.Lock()
- p.synRcvdCount.SetThreshold(uint64(v))
+ p.synRcvdCount.SetThreshold(uint64(*v))
p.mu.Unlock()
return nil
- case tcpip.TCPSynRetriesOption:
- if v < 1 || v > 255 {
+ case *tcpip.TCPSynRetriesOption:
+ if *v < 1 || *v > 255 {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.synRetries = uint8(v)
+ p.synRetries = uint8(*v)
p.mu.Unlock()
return nil
@@ -425,33 +382,33 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
// Option implements stack.TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case *SACKEnabled:
+ case *tcpip.TCPSACKEnabled:
p.mu.RLock()
- *v = SACKEnabled(p.sackEnabled)
+ *v = tcpip.TCPSACKEnabled(p.sackEnabled)
p.mu.RUnlock()
return nil
- case *Recovery:
+ case *tcpip.TCPRecovery:
p.mu.RLock()
- *v = Recovery(p.recovery)
+ *v = tcpip.TCPRecovery(p.recovery)
p.mu.RUnlock()
return nil
- case *DelayEnabled:
+ case *tcpip.TCPDelayEnabled:
p.mu.RLock()
- *v = DelayEnabled(p.delayEnabled)
+ *v = tcpip.TCPDelayEnabled(p.delayEnabled)
p.mu.RUnlock()
return nil
- case *SendBufferSizeOption:
+ case *tcpip.TCPSendBufferSizeRangeOption:
p.mu.RLock()
*v = p.sendBufferSize
p.mu.RUnlock()
return nil
- case *ReceiveBufferSizeOption:
+ case *tcpip.TCPReceiveBufferSizeRangeOption:
p.mu.RLock()
*v = p.recvBufferSize
p.mu.RUnlock()
@@ -463,15 +420,15 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.RUnlock()
return nil
- case *tcpip.AvailableCongestionControlOption:
+ case *tcpip.TCPAvailableCongestionControlOption:
p.mu.RLock()
- *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ *v = tcpip.TCPAvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
p.mu.RUnlock()
return nil
- case *tcpip.ModerateReceiveBufferOption:
+ case *tcpip.TCPModerateReceiveBufferOption:
p.mu.RLock()
- *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer)
+ *v = tcpip.TCPModerateReceiveBufferOption(p.moderateReceiveBuffer)
p.mu.RUnlock()
return nil
@@ -546,33 +503,19 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter {
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
- // TCP header is variable length, peek at it first.
- hdrLen := header.TCPMinimumSize
- hdr, ok := pkt.Data.PullUp(hdrLen)
- if !ok {
- return false
- }
-
- // If the header has options, pull those up as well.
- if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
- // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of
- // packets.
- hdrLen = offset
- }
-
- _, ok = pkt.TransportHeader().Consume(hdrLen)
- return ok
+ return parse.TCP(pkt)
}
// NewProtocol returns a TCP transport protocol.
-func NewProtocol() stack.TransportProtocol {
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
p := protocol{
- sendBufferSize: SendBufferSizeOption{
+ stack: s,
+ sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{
Min: MinBufferSize,
Default: DefaultSendBufferSize,
Max: MaxBufferSize,
},
- recvBufferSize: ReceiveBufferSizeOption{
+ recvBufferSize: tcpip.TCPReceiveBufferSizeRangeOption{
Min: MinBufferSize,
Default: DefaultReceiveBufferSize,
Max: MaxBufferSize,
@@ -587,7 +530,7 @@ func NewProtocol() stack.TransportProtocol {
minRTO: MinRTO,
maxRTO: MaxRTO,
maxRetries: MaxRetries,
- recovery: RACKLossDetection,
+ recovery: tcpip.TCPRACKLossDetection,
}
p.dispatcher.init(runtime.GOMAXPROCS(0))
return &p
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 5e0bfe585..48bf196d8 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -47,22 +47,24 @@ type receiver struct {
closed bool
+ // pendingRcvdSegments is bounded by the receive buffer size of the
+ // endpoint.
pendingRcvdSegments segmentHeap
- pendingBufUsed seqnum.Size
- pendingBufSize seqnum.Size
+ // pendingBufUsed tracks the total number of bytes (including segment
+ // overhead) currently queued in pendingRcvdSegments.
+ pendingBufUsed int
// Time when the last ack was received.
lastRcvdAckTime time.Time `state:".(unixTime)"`
}
-func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
return &receiver{
ep: ep,
rcvNxt: irs + 1,
rcvAcc: irs.Add(rcvWnd + 1),
rcvWnd: rcvWnd,
rcvWndScale: rcvWndScale,
- pendingBufSize: pendingBufSize,
lastRcvdAckTime: time.Now(),
}
}
@@ -85,15 +87,30 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// getSendParams returns the parameters needed by the sender when building
// segments to send.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
- // Calculate the window size based on the available buffer space.
- receiveBufferAvailable := r.ep.receiveBufferAvailable()
- acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable))
- if r.rcvAcc.LessThan(acc) {
- r.rcvAcc = acc
+ avail := wndFromSpace(r.ep.receiveBufferAvailable())
+ if avail == 0 {
+ // We have no space available to accept any data, move to zero window
+ // state.
+ r.rcvWnd = 0
+ return r.rcvNxt, 0
+ }
+
+ acc := r.rcvNxt.Add(seqnum.Size(avail))
+ newWnd := r.rcvNxt.Size(acc)
+ curWnd := r.rcvNxt.Size(r.rcvAcc)
+
+ // Update rcvAcc only if new window is > previously advertised window. We
+ // should never shrink the acceptable sequence space once it has been
+ // advertised the peer. If we shrink the acceptable sequence space then we
+ // would end up dropping bytes that might already be in flight.
+ if newWnd > curWnd {
+ r.rcvAcc = r.rcvNxt.Add(newWnd)
+ } else {
+ newWnd = curWnd
}
// Stash away the non-scaled receive window as we use it for measuring
// receiver's estimated RTT.
- r.rcvWnd = r.rcvNxt.Size(r.rcvAcc)
+ r.rcvWnd = newWnd
return r.rcvNxt, r.rcvWnd >> r.rcvWndScale
}
@@ -195,7 +212,9 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
}
for i := first; i < len(r.pendingRcvdSegments); i++ {
+ r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize()
r.pendingRcvdSegments[i].decRef()
+
// Note that slice truncation does not allow garbage collection of
// truncated items, thus truncated items must be set to nil to avoid
// memory leaks.
@@ -268,14 +287,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// If we are in one of the shutdown states then we need to do
// additional checks before we try and process the segment.
switch state {
- case StateCloseWait:
- // If the ACK acks something not yet sent then we send an ACK.
- if r.ep.snd.sndNxt.LessThan(s.ackNumber) {
- r.ep.snd.sendAck()
- return true, nil
- }
- fallthrough
- case StateClosing, StateLastAck:
+ case StateCloseWait, StateClosing, StateLastAck:
if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
// Just drop the segment as we have
// already received a FIN and this
@@ -284,9 +296,31 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
return true, nil
}
fallthrough
- case StateFinWait1:
- fallthrough
- case StateFinWait2:
+ case StateFinWait1, StateFinWait2:
+ // If the ACK acks something not yet sent then we send an ACK.
+ //
+ // RFC793, page 37: If the connection is in a synchronized state,
+ // (ESTABLISHED, FIN-WAIT-1, FIN-WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK,
+ // TIME-WAIT), any unacceptable segment (out of window sequence number
+ // or unacceptable acknowledgment number) must elicit only an empty
+ // acknowledgment segment containing the current send-sequence number
+ // and an acknowledgment indicating the next sequence number expected
+ // to be received, and the connection remains in the same state.
+ //
+ // Just as on Linux, we do not apply this behavior when state is
+ // ESTABLISHED.
+ // Linux receive processing for all states except ESTABLISHED and
+ // TIME_WAIT is here where if the ACK check fails, we attempt to
+ // reply back with an ACK with correct seq/ack numbers.
+ // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L6186
+ // The ESTABLISHED state processing is here where if the ACK check
+ // fails, we ignore the packet:
+ // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591
+ if r.ep.snd.sndNxt.LessThan(s.ackNumber) {
+ r.ep.snd.sendAck()
+ return true, nil
+ }
+
// If we are closed for reads (either due to an
// incoming FIN or the user calling shutdown(..,
// SHUT_RD) then any data past the rcvNxt should
@@ -369,10 +403,16 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
// Defer segment processing if it can't be consumed now.
if !r.consumeSegment(s, segSeq, segLen) {
if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
- // We only store the segment if it's within our buffer
- // size limit.
- if r.pendingBufUsed < r.pendingBufSize {
- r.pendingBufUsed += seqnum.Size(s.segMemSize())
+ // We only store the segment if it's within our buffer size limit.
+ //
+ // Only use 75% of the receive buffer queue for out-of-order
+ // segments. This ensures that we always leave some space for the inorder
+ // segments to arrive allowing pending segments to be processed and
+ // delivered to the user.
+ if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 {
+ r.ep.rcvListMu.Lock()
+ r.pendingBufUsed += s.segMemSize()
+ r.ep.rcvListMu.Unlock()
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
@@ -406,7 +446,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
}
heap.Pop(&r.pendingRcvdSegments)
- r.pendingBufUsed -= seqnum.Size(s.segMemSize())
+ r.ep.rcvListMu.Lock()
+ r.pendingBufUsed -= s.segMemSize()
+ r.ep.rcvListMu.Unlock()
s.decRef()
}
return false, nil
@@ -421,6 +463,13 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// Just silently drop any RST packets in TIME_WAIT. We do not support
// TIME_WAIT assasination as a result we confirm w/ fix 1 as described
// in https://tools.ietf.org/html/rfc1337#section-3.
+ //
+ // This behavior overrides RFC793 page 70 where we transition to CLOSED
+ // on receiving RST, which is also default Linux behavior.
+ // On Linux the RST can be ignored by setting sysctl net.ipv4.tcp_rfc1337.
+ //
+ // As we do not yet support PAWS, we are being conservative in ignoring
+ // RSTs by default.
if s.flagIsSet(header.TCPFlagRst) {
return false, false
}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 94307d31a..13acaf753 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "fmt"
"sync/atomic"
"time"
@@ -24,6 +25,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// queueFlags are used to indicate which queue of an endpoint a particular segment
+// belongs to. This is used to track memory accounting correctly.
+type queueFlags uint8
+
+const (
+ recvQ queueFlags = 1 << iota
+ sendQ
+)
+
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
// segment is mostly immutable, the only field allowed to change is viewToDeliver.
@@ -32,6 +42,8 @@ import (
type segment struct {
segmentEntry
refCnt int32
+ ep *endpoint
+ qFlags queueFlags
id stack.TransportEndpointID `state:"manual"`
route stack.Route `state:"manual"`
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
@@ -100,6 +112,8 @@ func (s *segment) clone() *segment {
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
+ ep: s.ep,
+ qFlags: s.qFlags,
}
t.data = s.data.Clone(t.views[:])
return t
@@ -115,8 +129,34 @@ func (s *segment) flagsAreSet(flags uint8) bool {
return s.flags&flags == flags
}
+// setOwner sets the owning endpoint for this segment. Its required
+// to be called to ensure memory accounting for receive/send buffer
+// queues is done properly.
+func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) {
+ switch qFlags {
+ case recvQ:
+ ep.updateReceiveMemUsed(s.segMemSize())
+ case sendQ:
+ // no memory account for sendQ yet.
+ default:
+ panic(fmt.Sprintf("unexpected queue flag %b", qFlags))
+ }
+ s.ep = ep
+ s.qFlags = qFlags
+}
+
func (s *segment) decRef() {
if atomic.AddInt32(&s.refCnt, -1) == 0 {
+ if s.ep != nil {
+ switch s.qFlags {
+ case recvQ:
+ s.ep.updateReceiveMemUsed(-s.segMemSize())
+ case sendQ:
+ // no memory accounting for sendQ yet.
+ default:
+ panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags))
+ }
+ }
s.route.Release()
}
}
@@ -138,6 +178,11 @@ func (s *segment) logicalLen() seqnum.Size {
return l
}
+// payloadSize is the size of s.data.
+func (s *segment) payloadSize() int {
+ return s.data.Size()
+}
+
// segMemSize is the amount of memory used to hold the segment data and
// the associated metadata.
func (s *segment) segMemSize() int {
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index 48a257137..54545a1b1 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -22,16 +22,16 @@ import (
//
// +stateify savable
type segmentQueue struct {
- mu sync.Mutex `state:"nosave"`
- list segmentList `state:"wait"`
- limit int
- used int
+ mu sync.Mutex `state:"nosave"`
+ list segmentList `state:"wait"`
+ ep *endpoint
+ frozen bool
}
// emptyLocked determines if the queue is empty.
// Preconditions: q.mu must be held.
func (q *segmentQueue) emptyLocked() bool {
- return q.used == 0
+ return q.list.Empty()
}
// empty determines if the queue is empty.
@@ -43,14 +43,6 @@ func (q *segmentQueue) empty() bool {
return r
}
-// setLimit updates the limit. No segments are immediately dropped in case the
-// queue becomes full due to the new limit.
-func (q *segmentQueue) setLimit(limit int) {
- q.mu.Lock()
- q.limit = limit
- q.mu.Unlock()
-}
-
// enqueue adds the given segment to the queue.
//
// Returns true when the segment is successfully added to the queue, in which
@@ -58,15 +50,23 @@ func (q *segmentQueue) setLimit(limit int) {
// false if the queue is full, in which case ownership is retained by the
// caller.
func (q *segmentQueue) enqueue(s *segment) bool {
+ // q.ep.receiveBufferParams() must be called without holding q.mu to
+ // avoid lock order inversion.
+ bufSz := q.ep.receiveBufferSize()
+ used := q.ep.receiveMemUsed()
q.mu.Lock()
- r := q.used < q.limit
- if r {
+ // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue
+ // is currently full).
+ allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen
+
+ if allow {
q.list.PushBack(s)
- q.used++
+ // Set the owner now that the endpoint owns the segment.
+ s.setOwner(q.ep, recvQ)
}
q.mu.Unlock()
- return r
+ return allow
}
// dequeue removes and returns the next segment from queue, if one exists.
@@ -77,9 +77,25 @@ func (q *segmentQueue) dequeue() *segment {
s := q.list.Front()
if s != nil {
q.list.Remove(s)
- q.used--
}
q.mu.Unlock()
return s
}
+
+// freeze prevents any more segments from being added to the queue. i.e all
+// future segmentQueue.enqueue will return false and not add the segment to the
+// queue till the queue is unfroze with a corresponding segmentQueue.thaw call.
+func (q *segmentQueue) freeze() {
+ q.mu.Lock()
+ q.frozen = true
+ q.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen queue using segmentQueue.freeze() and
+// allows new segments to be queued again.
+func (q *segmentQueue) thaw() {
+ q.mu.Lock()
+ q.frozen = false
+ q.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 99521f0c1..ef7f5719f 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -46,8 +46,9 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
t.Helper()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err)
+ opt := tcpip.TCPSACKEnabled(enable)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -162,8 +163,9 @@ func TestSackPermittedAccept(t *testing.T) {
// Set the SynRcvd threshold to
// zero to force a syn cookie
// based accept to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
setStackSACKPermitted(t, c, sackEnabled)
@@ -236,8 +238,9 @@ func TestSackDisabledAccept(t *testing.T) {
// Set the SynRcvd threshold to
// zero to force a syn cookie
// based accept to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index adb32e428..5b504d0d1 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -240,6 +241,38 @@ func TestTCPResetsSentIncrement(t *testing.T) {
}
}
+// TestTCPResetsSentNoICMP confirms that we don't get an ICMP
+// DstUnreachable packet when we try send a packet which is not part
+// of an active session.
+func TestTCPResetsSentNoICMP(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ stats := c.Stack().Stats()
+
+ // Send a SYN request for a closed port. This should elicit an RST
+ // but NOT an ICMPv4 DstUnreachable packet.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive whatever comes back.
+ b := c.GetPacket()
+ ipHdr := header.IPv4(b)
+ if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want {
+ t.Errorf("unexpected protocol, got = %d, want = %d", got, want)
+ }
+
+ // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded.
+ sent := stats.ICMP.V4PacketsSent
+ if got, want := sent.DstUnreachable.Value(), uint64(0); got != want {
+ t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want)
+ }
+}
+
// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
// a RST if an ACK is received on the listening socket for which there is no
// active handshake in progress and we are not using SYN cookies.
@@ -291,12 +324,12 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -309,16 +342,16 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// Lower stackwide TIME_WAIT timeout so that the reservations
// are released instantly on Close.
tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
- t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err)
}
c.EP.Close()
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
SrcPort: context.TestPort,
@@ -348,8 +381,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
}
@@ -432,8 +465,9 @@ func TestConnectResetAfterClose(t *testing.T) {
// Set TCPLinger to 3 seconds so that sockets are marked closed
// after 3 second in FIN_WAIT2 state.
tcpLingerTimeout := 3 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err)
+ opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -446,8 +480,8 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -488,8 +522,8 @@ func TestConnectResetAfterClose(t *testing.T) {
// RST is always generated with sndNxt which if the FIN
// has been sent will be 1 higher than the sequence number
// of the FIN itself.
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -506,8 +540,9 @@ func TestCurrentConnectedIncrement(t *testing.T) {
// Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
// after 1 second in TIME_WAIT state.
tcpTimeWaitTimeout := 1 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -527,8 +562,8 @@ func TestCurrentConnectedIncrement(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -563,8 +598,8 @@ func TestCurrentConnectedIncrement(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -610,8 +645,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -631,8 +666,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -691,8 +726,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -743,8 +778,8 @@ func TestSimpleReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -933,8 +968,8 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
// Set the SynRcvd threshold to force a syn cookie based accept to happen.
opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, %#v): %s", tcp.ProtocolNumber, opt, err)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
@@ -996,7 +1031,7 @@ func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
@@ -1024,7 +1059,7 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
@@ -1061,8 +1096,8 @@ func TestTCPAckBeforeAcceptV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
@@ -1099,8 +1134,8 @@ func TestTCPAckBeforeAcceptV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestSendRstOnListenerRxAckV4(t *testing.T) {
@@ -1128,7 +1163,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
func TestSendRstOnListenerRxAckV6(t *testing.T) {
@@ -1156,7 +1191,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
// TestListenShutdown tests for the listening endpoint replying with RST
@@ -1272,8 +1307,8 @@ func TestTOSV4(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790), // Acknum is initial sequence number + 1
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790), // Acknum is initial sequence number + 1
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
checker.TOS(tos, 0),
@@ -1321,8 +1356,8 @@ func TestTrafficClassV6(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
checker.TOS(tos, 0),
@@ -1512,8 +1547,8 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1563,8 +1598,8 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1574,8 +1609,8 @@ func TestOutOfOrderFlood(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Create a new connection with initial window size of 10.
- c.CreateConnected(789, 30000, 10)
+ rcvBufSz := math.MaxUint16
+ c.CreateConnected(789, 30000, rcvBufSz)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
@@ -1596,8 +1631,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1617,8 +1652,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1637,8 +1672,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(793),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(793),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1679,8 +1714,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1694,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
// We shouldn't consume a sequence number on RST.
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
@@ -1748,8 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1762,7 +1797,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
))
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
@@ -1781,7 +1816,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
// RST is always generated with sndNxt which if the FIN
// has been sent will be 1 higher than the sequence
// number of the FIN itself.
- checker.SeqNum(uint32(c.IRS)+2),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
@@ -1827,7 +1862,8 @@ func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, 10)
+ const rcvBufSz = 10
+ c.CreateConnected(789, 30000, rcvBufSz)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1838,8 +1874,13 @@ func TestFullWindowReceive(t *testing.T) {
t.Fatalf("Read failed: %s", err)
}
- // Fill up the window.
- data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+ // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
+ // the provided buffer value by tcp.SegOverheadFactor to calculate the actual
+ // receive buffer size.
+ data := make([]byte, tcp.SegOverheadFactor*rcvBufSz)
+ for i := range data {
+ data[i] = byte(i % 255)
+ }
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
@@ -1860,10 +1901,10 @@ func TestFullWindowReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(0),
+ checker.TCPWindow(0),
),
)
@@ -1886,10 +1927,10 @@ func TestFullWindowReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(10),
+ checker.TCPWindow(10),
),
)
}
@@ -1898,12 +1939,15 @@ func TestNoWindowShrinking(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Start off with a window size of 10, then shrink it to 5.
- c.CreateConnected(789, 30000, 10)
-
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
- }
+ // Start off with a certain receive buffer then cut it in half and verify that
+ // the right edge of the window does not shrink.
+ // NOTE: Netstack doubles the value specified here.
+ rcvBufSize := 65536
+ iss := seqnum.Value(789)
+ // Enable window scaling with a scale of zero from our end.
+ c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{
+ header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
+ })
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1912,14 +1956,15 @@ func TestNoWindowShrinking(t *testing.T) {
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
-
- // Send 3 bytes, check that the peer acknowledges them.
- data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
- c.SendPacket(data[:3], &context.Headers{
+ // Send a 1 byte payload so that we can record the current receive window.
+ // Send a payload of half the size of rcvBufSize.
+ seqNum := iss.Add(1)
+ payload := []byte{1}
+ c.SendPacket(payload, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: seqNum,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1931,46 +1976,93 @@ func TestNoWindowShrinking(t *testing.T) {
t.Fatalf("Timed out waiting for data to arrive")
}
- // Check that data is acknowledged, and that window doesn't go to zero
- // just yet because it was previously set to 10. It must go to 7 now.
- checker.IPv4(t, c.GetPacket(),
+ // Read the 1 byte payload we just sent.
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ if got, want := payload, v; !bytes.Equal(got, want) {
+ t.Fatalf("got data: %v, want: %v", got, want)
+ }
+
+ seqNum = seqNum.Add(1)
+ // Verify that the ACK does not shrink the window.
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(793),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(7),
),
)
+ // Stash the initial window.
+ initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
+ initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd))
+ // Now shrink the receive buffer to half its original size.
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
+ }
- // Send 7 more bytes, check that the window fills up.
- c.SendPacket(data[3:], &context.Headers{
+ data := generateRandomPayload(t, rcvBufSize)
+ // Send a payload of half the size of rcvBufSize.
+ c.SendPacket(data[:rcvBufSize/2], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 793,
+ SeqNum: seqNum,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
+ seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
+ // Verify that the ACK does not shrink the window.
+ pkt = c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
+ newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd))
+ if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) {
+ t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq)
}
+ // Send another payload of half the size of rcvBufSize. This should fill up the
+ // socket receive buffer and we should see a zero window.
+ c.SendPacket(data[rcvBufSize/2:], &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqNum,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
+
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(0),
+ checker.TCPWindow(0),
),
)
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
// Receive data and check it.
- read := make([]byte, 0, 10)
+ read := make([]byte, 0, rcvBufSize)
for len(read) < len(data) {
v, _, err := c.EP.Read(nil)
if err != nil {
@@ -1984,15 +2076,15 @@ func TestNoWindowShrinking(t *testing.T) {
t.Fatalf("got data = %v, want = %v", read, data)
}
- // Check that we get an ACK for the newly non-zero window, which is the
- // new size.
+ // Check that we get an ACK for the newly non-zero window, which is the new
+ // receive buffer size we set after the connection was established.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(5),
+ checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale),
),
)
}
@@ -2017,8 +2109,8 @@ func TestSimpleSend(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2059,8 +2151,8 @@ func TestZeroWindowSend(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2081,8 +2173,8 @@ func TestZeroWindowSend(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2121,16 +2213,16 @@ func TestScaledWindowConnect(t *testing.T) {
t.Fatalf("Write failed: %s", err)
}
- // Check that data is received, and that advertised window is 0xbfff,
+ // Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xbfff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2160,9 +2252,9 @@ func TestNonScaledWindowConnect(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xffff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2196,19 +2288,20 @@ func TestScaledWindowAccept(t *testing.T) {
}
// Do 3-way handshake.
- c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
+ c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2226,16 +2319,16 @@ func TestScaledWindowAccept(t *testing.T) {
t.Fatalf("Write failed: %s", err)
}
- // Check that data is received, and that advertised window is 0xbfff,
+ // Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xbfff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2277,12 +2370,12 @@ func TestNonScaledWindowAccept(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2307,9 +2400,9 @@ func TestNonScaledWindowAccept(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xffff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2322,18 +2415,19 @@ func TestZeroScaledWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Set the window size such that a window scale of 4 will be used.
- const wnd = 65535 * 10
- const ws = uint32(4)
- c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
+ // Set the buffer size such that a window scale of 5 will be used.
+ const bufSz = 65535 * 10
+ const ws = uint32(5)
+ c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
// Write chunks of 50000 bytes.
- remain := wnd
+ remain := 0
sent := 0
data := make([]byte, 50000)
- for remain > len(data) {
+ // Keep writing till the window drops below len(data).
+ for {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
@@ -2343,21 +2437,25 @@ func TestZeroScaledWindowReceive(t *testing.T) {
RcvWnd: 30000,
})
sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(remain>>ws)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Don't reduce window to zero here.
+ if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) {
+ remain = wnd << ws
+ break
+ }
}
// Make the window non-zero, but the scaled window zero.
- if remain >= 16 {
+ for remain >= 16 {
data = data[:remain-15]
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
@@ -2368,22 +2466,35 @@ func TestZeroScaledWindowReceive(t *testing.T) {
RcvWnd: 30000,
})
sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(0),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Since the receive buffer is split between window advertisement and
+ // application data buffer the window does not always reflect the space
+ // available and actual space available can be a bit more than what is
+ // advertised in the window.
+ wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize())
+ if wnd == 0 {
+ break
+ }
+ remain = wnd << ws
}
- // Read at least 1MSS of data. An ack should be sent in response to that.
+ // Read at least 2MSS of data. An ack should be sent in response to that.
+ // Since buffer space is now split in half between window and application
+ // data we need to read more than 1 MSS(65536) of data for a non-zero window
+ // update to be sent. For 1MSS worth of window to be available we need to
+ // read at least 128KB. Since our segments above were 50KB each it means
+ // we need to read at 3 packets.
sz := 0
- for sz < defaultMTU {
+ for sz < defaultMTU*2 {
v, _, err := c.EP.Read(nil)
if err != nil {
t.Fatalf("Read failed: %s", err)
@@ -2395,9 +2506,9 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(sz>>ws)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -2464,8 +2575,8 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize+1),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+uint32(i)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2487,8 +2598,8 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(len(allData)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+11),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+11),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2535,8 +2646,8 @@ func TestDelay(t *testing.T) {
checker.PayloadLen(len(want)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2582,8 +2693,8 @@ func TestUndelay(t *testing.T) {
checker.PayloadLen(len(allData[0])+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2605,8 +2716,8 @@ func TestUndelay(t *testing.T) {
checker.PayloadLen(len(allData[1])+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2667,8 +2778,8 @@ func TestMSSNotDelayed(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2719,8 +2830,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2840,12 +2951,12 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2867,8 +2978,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ opt := tcpip.TCPSynRcvdCountThresholdOption(0)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
// Create EP and start listening.
@@ -2895,12 +3007,12 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2961,7 +3073,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- const wndScale = 2
+ const wndScale = 3
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
}
@@ -2996,7 +3108,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagSyn),
checker.SrcPort(tcpHdr.SourcePort()),
- checker.SeqNum(tcpHdr.SequenceNumber()),
+ checker.TCPSeqNum(tcpHdr.SequenceNumber()),
checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
),
)
@@ -3017,8 +3129,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
),
)
@@ -3146,8 +3258,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
defer c.Cleanup()
const numRetries = 2
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil {
- t.Fatalf("could not set protocol option MaxRetries.\n")
+ opt := tcpip.TCPMaxRetriesOption(numRetries)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
@@ -3206,8 +3319,9 @@ func TestMaxRTO(t *testing.T) {
defer c.Cleanup()
rto := 1 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err)
+ opt := tcpip.TCPMaxRTOOption(rto)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
@@ -3309,8 +3423,8 @@ func TestFinImmediately(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3330,8 +3444,8 @@ func TestFinImmediately(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3352,8 +3466,8 @@ func TestFinRetransmit(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3363,8 +3477,8 @@ func TestFinRetransmit(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3384,8 +3498,8 @@ func TestFinRetransmit(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3408,8 +3522,8 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3433,8 +3547,8 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3455,8 +3569,8 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3483,8 +3597,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3502,8 +3616,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3522,8 +3636,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3543,8 +3657,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3567,8 +3681,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3592,8 +3706,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3608,8 +3722,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3629,8 +3743,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3654,8 +3768,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3675,8 +3789,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3690,8 +3804,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3706,8 +3820,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3798,8 +3912,8 @@ func scaledSendWindow(t *testing.T, scale uint8) {
checker.PayloadLen((1<<scale)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3937,7 +4051,7 @@ func TestReceivedSegmentQueuing(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3964,8 +4078,9 @@ func TestReadAfterClosedState(t *testing.T) {
// Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
// after 1 second in TIME_WAIT state.
tcpTimeWaitTimeout := 1 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -3987,8 +4102,8 @@ func TestReadAfterClosedState(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -4012,8 +4127,8 @@ func TestReadAfterClosedState(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(uint32(791+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(uint32(791+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4185,8 +4300,8 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func TestDefaultBufferSizes(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
// Check the default values.
@@ -4204,11 +4319,15 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default send buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{
- Min: 1,
- Default: tcp.DefaultSendBufferSize * 2,
- Max: tcp.DefaultSendBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPSendBufferSizeRangeOption{
+ Min: 1,
+ Default: tcp.DefaultSendBufferSize * 2,
+ Max: tcp.DefaultSendBufferSize * 20,
+ }
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
ep.Close()
@@ -4221,11 +4340,15 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default receive buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{
- Min: 1,
- Default: tcp.DefaultReceiveBufferSize * 3,
- Max: tcp.DefaultReceiveBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize * 3,
+ Max: tcp.DefaultReceiveBufferSize * 30,
+ }
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
ep.Close()
@@ -4240,8 +4363,8 @@ func TestDefaultBufferSizes(t *testing.T) {
func TestMinMaxBufferSizes(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
// Check the default values.
@@ -4252,22 +4375,28 @@ func TestMinMaxBufferSizes(t *testing.T) {
defer ep.Close()
// Change the min/max values for send/receive
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
- // Set values below the min.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
+ // Set values below the min/2.
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err)
}
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil {
t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err)
}
@@ -4278,19 +4407,21 @@ func TestMinMaxBufferSizes(t *testing.T) {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
}
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
+ // Values above max are capped at max and then doubled.
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2)
if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err)
}
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
+ // Values above max are capped at max and then doubled.
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2)
}
func TestBindToDeviceOption(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}})
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
@@ -4339,11 +4470,11 @@ func TestBindToDeviceOption(t *testing.T) {
func makeStack() (*stack.Stack, *tcpip.Error) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{
- ipv4.NewProtocol(),
- ipv6.NewProtocol(),
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
id := loopback.New()
@@ -4626,8 +4757,8 @@ func TestPathMTUDiscovery(t *testing.T) {
checker.PayloadLen(size+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(seqNum),
- checker.AckNum(790),
+ checker.TCPSeqNum(seqNum),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -4718,8 +4849,8 @@ func TestStackSetCongestionControl(t *testing.T) {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
- t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err {
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err)
}
var cc tcpip.CongestionControlOption
@@ -4751,12 +4882,12 @@ func TestStackAvailableCongestionControl(t *testing.T) {
s := c.Stack()
// Query permitted congestion control algorithms.
- var aCC tcpip.AvailableCongestionControlOption
+ var aCC tcpip.TCPAvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
}
- if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want)
}
}
@@ -4767,18 +4898,18 @@ func TestStackSetAvailableCongestionControl(t *testing.T) {
s := c.Stack()
// Setting AvailableCongestionControlOption should fail.
- aCC := tcpip.AvailableCongestionControlOption("xyz")
+ aCC := tcpip.TCPAvailableCongestionControlOption("xyz")
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC)
}
// Verify that we still get the expected list of congestion control options.
- var cc tcpip.AvailableCongestionControlOption
+ var cc tcpip.TCPAvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
+ t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err)
}
- if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want)
}
}
@@ -4842,8 +4973,8 @@ func TestEndpointSetCongestionControl(t *testing.T) {
func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
opt := tcpip.CongestionControlOption("cubic")
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -4878,8 +5009,8 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4912,8 +5043,8 @@ func TestKeepalive(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -4924,8 +5055,8 @@ func TestKeepalive(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
),
)
@@ -4950,8 +5081,8 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next-1)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(next-1)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4977,8 +5108,8 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(next)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -5018,7 +5149,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
checker.SrcPort(context.StackPort),
checker.DstPort(srcPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
if synCookieInUse {
@@ -5062,7 +5193,7 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo
checker.SrcPort(context.StackPort),
checker.DstPort(srcPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
if synCookieInUse {
@@ -5135,12 +5266,12 @@ func TestListenBacklogFull(t *testing.T) {
defer c.WQ.EventUnregister(&we)
for i := 0; i < listenBacklog; i++ {
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5152,7 +5283,7 @@ func TestListenBacklogFull(t *testing.T) {
}
// Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != tcpip.ErrWouldBlock {
select {
case <-ch:
@@ -5164,12 +5295,12 @@ func TestListenBacklogFull(t *testing.T) {
// Now a new handshake must succeed.
executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */)
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5194,6 +5325,8 @@ func TestListenBacklogFull(t *testing.T) {
func TestListenNoAcceptNonUnicastV4(t *testing.T) {
multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+ subnet := context.StackAddrWithPrefix.Subnet()
+ subnetBroadcastAddr := subnet.Broadcast()
tests := []struct {
name string
@@ -5201,53 +5334,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
dstAddr tcpip.Address
}{
{
- "SourceUnspecified",
- header.IPv4Any,
- context.StackAddr,
+ name: "SourceUnspecified",
+ srcAddr: header.IPv4Any,
+ dstAddr: context.StackAddr,
},
{
- "SourceBroadcast",
- header.IPv4Broadcast,
- context.StackAddr,
+ name: "SourceBroadcast",
+ srcAddr: header.IPv4Broadcast,
+ dstAddr: context.StackAddr,
},
{
- "SourceOurMulticast",
- multicastAddr,
- context.StackAddr,
+ name: "SourceOurMulticast",
+ srcAddr: multicastAddr,
+ dstAddr: context.StackAddr,
},
{
- "SourceOtherMulticast",
- otherMulticastAddr,
- context.StackAddr,
+ name: "SourceOtherMulticast",
+ srcAddr: otherMulticastAddr,
+ dstAddr: context.StackAddr,
},
{
- "DestUnspecified",
- context.TestAddr,
- header.IPv4Any,
+ name: "DestUnspecified",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Any,
},
{
- "DestBroadcast",
- context.TestAddr,
- header.IPv4Broadcast,
+ name: "DestBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Broadcast,
},
{
- "DestOurMulticast",
- context.TestAddr,
- multicastAddr,
+ name: "DestOurMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: multicastAddr,
},
{
- "DestOtherMulticast",
- context.TestAddr,
- otherMulticastAddr,
+ name: "DestOtherMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: otherMulticastAddr,
+ },
+ {
+ name: "SrcSubnetBroadcast",
+ srcAddr: subnetBroadcastAddr,
+ dstAddr: context.StackAddr,
+ },
+ {
+ name: "DestSubnetBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: subnetBroadcastAddr,
},
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5288,7 +5427,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
})
}
}
@@ -5296,8 +5435,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
// non unicast IPv6 address are not accepted.
func TestListenNoAcceptNonUnicastV6(t *testing.T) {
- multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
- otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+ multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
+ otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
tests := []struct {
name string
@@ -5347,11 +5486,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5392,7 +5527,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
})
}
}
@@ -5440,7 +5575,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5476,12 +5611,12 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5505,8 +5640,9 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err)
+ opt := tcpip.TCPSynRcvdCountThresholdOption(1)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
// Create TCP endpoint.
@@ -5552,12 +5688,12 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5568,7 +5704,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
// Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != tcpip.ErrWouldBlock {
select {
case <-ch:
@@ -5617,7 +5753,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5638,8 +5774,8 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.AckNum(uint32(irs) + 1),
- checker.SeqNum(uint32(iss + 1)),
+ checker.TCPAckNum(uint32(irs) + 1),
+ checker.TCPSeqNum(uint32(iss + 1)),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5657,7 +5793,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
RcvWnd: 30000,
})
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err != nil && err != tcpip.ErrWouldBlock {
t.Fatalf("Accept failed: %s", err)
@@ -5672,7 +5808,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5730,12 +5866,12 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
defer c.WQ.EventUnregister(&we)
// Verify that there is only one acceptable connection at this point.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5800,12 +5936,12 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
defer c.WQ.EventUnregister(&we)
// Now check that there is one acceptable connections.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5853,12 +5989,12 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- aep, _, err := ep.Accept()
+ aep, _, err := ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- aep, _, err = ep.Accept()
+ aep, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5906,13 +6042,19 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 500.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
// Enable auto-tuning.
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPModerateReceiveBufferOption(true)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
}
// Change the expected window scale to match the value needed for the
// maximum buffer size defined above.
@@ -5931,16 +6073,14 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
time.Sleep(latency)
rawEP.SendPacketWithTS([]byte{1}, tsVal)
- // Verify that the ACK has the expected window.
- wantRcvWnd := receiveBufferSize
- wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale))
- rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1))
+ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
+ rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize()
time.Sleep(25 * time.Millisecond)
// Allocate a large enough payload for the test.
- b := make([]byte, int(receiveBufferSize)*2)
- offset := 0
- payloadSize := receiveBufferSize - 1
+ payloadSize := receiveBufferSize * 2
+ b := make([]byte, int(payloadSize))
+
worker := (c.EP).(interface {
StopWork()
ResumeWork()
@@ -5949,11 +6089,15 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Stop the worker goroutine.
worker.StopWork()
- start := offset
- end := offset + payloadSize
+ start := 0
+ end := payloadSize / 2
packetsSent := 0
for ; start < end; start += mss {
- rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+ packetEnd := start + mss
+ if start+mss > end {
+ packetEnd = end
+ }
+ rawEP.SendPacketWithTS(b[start:packetEnd], tsVal)
packetsSent++
}
@@ -5961,29 +6105,20 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// are waiting to be read.
worker.ResumeWork()
- // Since we read no bytes the window should goto zero till the
- // application reads some of the data.
- // Discard all intermediate acks except the last one.
- if packetsSent > 100 {
- for i := 0; i < (packetsSent / 100); i++ {
- _ = c.GetPacket()
- }
+ // Since we sent almost the full receive buffer worth of data (some may have
+ // been dropped due to segment overheads), we should get a zero window back.
+ pkt = c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(pkt).Payload())
+ gotRcvWnd := tcpHdr.WindowSize()
+ wantAckNum := tcpHdr.AckNumber()
+ if got, want := int(gotRcvWnd), 0; got != want {
+ t.Fatalf("got rcvWnd: %d, want: %d", got, want)
}
- rawEP.VerifyACKRcvWnd(0)
time.Sleep(25 * time.Millisecond)
- // Verify that sending more data when window is closed is dropped and
- // not acked.
+ // Verify that sending more data when receiveBuffer is exhausted.
rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
- // Verify that the stack sends us back an ACK with the sequence number
- // of the last packet sent indicating it was dropped.
- p := c.GetPacket()
- checker.IPv4(t, p, checker.TCP(
- checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
- checker.Window(0),
- ))
-
// Now read all the data from the endpoint and verify that advertised
// window increases to the full available buffer size.
for {
@@ -5996,23 +6131,26 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Verify that we receive a non-zero window update ACK. When running
// under thread santizer this test can end up sending more than 1
// ack, 1 for the non-zero window
- p = c.GetPacket()
+ p := c.GetPacket()
checker.IPv4(t, p, checker.TCP(
- checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+ checker.TCPAckNum(uint32(wantAckNum)),
func(t *testing.T, h header.Transport) {
tcp, ok := h.(header.TCP)
if !ok {
return
}
- if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
- t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w)
+ // We use 10% here as the error margin upwards as the initial window we
+ // got was afer 1 segment was already in the receive buffer queue.
+ tolerance := 1.1
+ if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) {
+ t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance))
}
},
))
}
-// This test verifies that the auto tuning does not grow the receive buffer if
-// the application is not reading the data actively.
+// This test verifies that the advertised window is auto-tuned up as the
+// application is reading the data that is being received.
func TestReceiveBufferAutoTuning(t *testing.T) {
const mtu = 1500
const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
@@ -6022,26 +6160,33 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// Enable Auto-tuning.
stk := c.Stack()
- // Set lower limits for auto-tuning tests. This is required because the
- // test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
// Enable auto-tuning.
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPModerateReceiveBufferOption(true)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
}
// Change the expected window scale to match the value needed for the
// maximum buffer size used by stack.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
-
- wantRcvWnd := receiveBufferSize
+ tsVal := uint32(rawEP.TSVal)
+ rawEP.NextSeqNum--
+ rawEP.SendPacketWithTS(nil, tsVal)
+ rawEP.NextSeqNum++
+ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
+ curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
scaleRcvWnd := func(rcvWnd int) uint16 {
return uint16(rcvWnd >> uint16(c.WindowScale))
}
@@ -6058,14 +6203,8 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
StopWork()
ResumeWork()
})
- tsVal := rawEP.TSVal
- // We are going to do our own computation of what the moderated receive
- // buffer should be based on sent/copied data per RTT and verify that
- // the advertised window by the stack matches our calculations.
- prevCopied := 0
- done := false
latency := 1 * time.Millisecond
- for i := 0; !done; i++ {
+ for i := 0; i < 5; i++ {
tsVal++
// Stop the worker goroutine.
@@ -6087,15 +6226,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// Give 1ms for the worker to process the packets.
time.Sleep(1 * time.Millisecond)
- // Verify that the advertised window on the ACK is reduced by
- // the total bytes sent.
- expectedWnd := wantRcvWnd - totalSent
- if packetsSent > 100 {
- for i := 0; i < (packetsSent / 100); i++ {
- _ = c.GetPacket()
+ lastACK := c.GetPacket()
+ // Discard any intermediate ACKs and only check the last ACK we get in a
+ // short time period of few ms.
+ for {
+ time.Sleep(1 * time.Millisecond)
+ pkt := c.GetPacketNonBlocking()
+ if pkt == nil {
+ break
}
+ lastACK = pkt
+ }
+ if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want {
+ t.Fatalf("advertised window got: %d, want <= %d", got, want)
}
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd))
// Now read all the data from the endpoint and invoke the
// moderation API to allow for receive buffer auto-tuning
@@ -6120,35 +6264,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
rawEP.NextSeqNum++
-
if i == 0 {
// In the first iteration the receiver based RTT is not
// yet known as a result the moderation code should not
// increase the advertised window.
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd))
- prevCopied = totalCopied
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd))
} else {
- rttCopied := totalCopied
- if i == 1 {
- // The moderation code accumulates copied bytes till
- // RTT is established. So add in the bytes sent in
- // the first iteration to the total bytes for this
- // RTT.
- rttCopied += prevCopied
- // Now reset it to the initial value used by the
- // auto tuning logic.
- prevCopied = tcp.InitialCwnd * mss * 2
- }
- newWnd := rttCopied<<1 + 16*mss
- grow := (newWnd * (rttCopied - prevCopied)) / prevCopied
- newWnd += (grow << 1)
- if newWnd > maxReceiveBufferSize {
- newWnd = maxReceiveBufferSize
- done = true
+ pkt := c.GetPacket()
+ curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
+ // If thew new current window is close maxReceiveBufferSize then terminate
+ // the loop. This can happen before all iterations are done due to timing
+ // differences when running the test.
+ if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 {
+ break
}
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd))
- wantRcvWnd = newWnd
- prevCopied = rttCopied
// Increase the latency after first two iterations to
// establish a low RTT value in the receiver since it
// only tracks the lowest value. This ensures that when
@@ -6161,6 +6290,12 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
offset += payloadSize
payloadSize *= 2
}
+ // Check that at the end of our iterations the receive window grew close to the maximum
+ // permissible size of maxReceiveBufferSize/2
+ if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want {
+ t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want)
+ }
+
}
func TestDelayEnabled(t *testing.T) {
@@ -6169,7 +6304,7 @@ func TestDelayEnabled(t *testing.T) {
checkDelayOption(t, c, false, false) // Delay is disabled by default.
for _, v := range []struct {
- delayEnabled tcp.DelayEnabled
+ delayEnabled tcpip.TCPDelayEnabled
wantDelayOption bool
}{
{delayEnabled: false, wantDelayOption: false},
@@ -6177,17 +6312,17 @@ func TestDelayEnabled(t *testing.T) {
} {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
- t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err)
}
checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
}
}
-func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) {
+func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) {
t.Helper()
- var gotDelayEnabled tcp.DelayEnabled
+ var gotDelayEnabled tcpip.TCPDelayEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
}
@@ -6293,12 +6428,12 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6312,8 +6447,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6330,8 +6465,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Now send a RST and this should be ignored and not
@@ -6359,8 +6494,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
}
@@ -6412,12 +6547,12 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6431,8 +6566,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6449,8 +6584,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Out of order ACK should generate an immediate ACK in
@@ -6466,8 +6601,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
}
@@ -6519,12 +6654,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6538,8 +6673,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6556,8 +6691,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Send a SYN request w/ sequence number lower than
@@ -6602,12 +6737,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
c.SendPacket(nil, ackHeaders)
// Try to accept the connection.
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6625,8 +6760,9 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
// after 5 seconds in TIME_WAIT state.
tcpTimeWaitTimeout := 5 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
}
want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
@@ -6675,12 +6811,12 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6694,8 +6830,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6712,8 +6848,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
time.Sleep(2 * time.Second)
@@ -6727,8 +6863,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Sleep for 4 seconds so at this point we are 1 second past the
@@ -6756,8 +6892,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
@@ -6775,8 +6911,9 @@ func TestTCPCloseWithData(t *testing.T) {
// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
// after 5 seconds in TIME_WAIT state.
tcpTimeWaitTimeout := 5 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
}
wq := &waiter.Queue{}
@@ -6824,12 +6961,12 @@ func TestTCPCloseWithData(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6855,8 +6992,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Now write a few bytes and then close the endpoint.
@@ -6874,8 +7011,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -6889,8 +7026,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
- checker.AckNum(uint32(iss+2)),
+ checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))),
+ checker.TCPAckNum(uint32(iss+2)),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
// First send a partial ACK.
@@ -6935,8 +7072,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
}
@@ -6972,8 +7109,8 @@ func TestTCPUserTimeout(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -7007,8 +7144,8 @@ func TestTCPUserTimeout(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(next)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -7069,8 +7206,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -7095,8 +7232,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -7112,9 +7249,9 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
}
}
-func TestIncreaseWindowOnReceive(t *testing.T) {
+func TestIncreaseWindowOnRead(t *testing.T) {
// This test ensures that the endpoint sends an ack,
- // after recv() when the window grows to more than 1 MSS.
+ // after read() when the window grows by more than 1 MSS.
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -7123,10 +7260,9 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
- remain := rcvBuf
+ remain := rcvBuf * 2
sent := 0
data := make([]byte, defaultMTU/2)
- lastWnd := uint16(0)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
@@ -7139,46 +7275,43 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
})
sent += len(data)
remain -= len(data)
-
- lastWnd = uint16(remain)
- if remain > 0xffff {
- lastWnd = 0xffff
- }
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(lastWnd),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Break once the window drops below defaultMTU/2
+ if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 {
+ break
+ }
}
- if lastWnd == 0xffff || lastWnd == 0 {
- t.Fatalf("expected small, non-zero window: %d", lastWnd)
- }
-
- // We now have < 1 MSS in the buffer space. Read the data! An
- // ack should be sent in response to that. The window was not
- // zero, but it grew to larger than MSS.
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %s", err)
+ // We now have < 1 MSS in the buffer space. Read at least > 2 MSS
+ // worth of data as receive buffer space
+ read := 0
+ // defaultMTU is a good enough estimate for the MSS used for this
+ // connection.
+ for read < defaultMTU*2 {
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ read += len(v)
}
- // After reading two packets, we surely crossed MSS. See the ack:
+ // After reading > MSS worth of data, we surely crossed MSS. See the ack:
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(0xffff)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -7195,10 +7328,9 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
- remain := rcvBuf
+ remain := rcvBuf * 2
sent := 0
data := make([]byte, defaultMTU/2)
- lastWnd := uint16(0)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
@@ -7212,38 +7344,29 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
sent += len(data)
remain -= len(data)
- lastWnd = uint16(remain)
- if remain > 0xffff {
- lastWnd = 0xffff
- }
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(lastWnd),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindowLessThanEq(0xffff),
checker.TCPFlags(header.TCPFlagAck),
),
)
}
- if lastWnd == 0xffff || lastWnd == 0 {
- t.Fatalf("expected small, non-zero window: %d", lastWnd)
- }
-
// Increasing the buffer from should generate an ACK,
// since window grew from small value to larger equal MSS
c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
- // After reading two packets, we surely crossed MSS. See the ack:
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(0xffff)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -7271,8 +7394,8 @@ func TestTCPDeferAccept(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
+ if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Send data. This should result in an acceptable endpoint.
@@ -7288,14 +7411,14 @@ func TestTCPDeferAccept(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
// Give a bit of time for the socket to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept()
+ aep, _, err := c.EP.Accept(nil)
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
}
aep.Close()
@@ -7303,8 +7426,8 @@ func TestTCPDeferAccept(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestTCPDeferAcceptTimeout(t *testing.T) {
@@ -7329,8 +7452,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
+ if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Sleep for a little of the tcpDeferAccept timeout.
@@ -7341,7 +7464,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
// Send data. This should result in an acceptable endpoint.
c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
@@ -7357,14 +7480,14 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
// Give sometime for the endpoint to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept()
+ aep, _, err := c.EP.Accept(nil)
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
}
aep.Close()
@@ -7373,8 +7496,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestResetDuringClose(t *testing.T) {
@@ -7399,8 +7522,8 @@ func TestResetDuringClose(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(irs.Add(1))),
- checker.AckNum(uint32(iss.Add(5)))))
+ checker.TCPSeqNum(uint32(irs.Add(1))),
+ checker.TCPAckNum(uint32(iss.Add(5)))))
// Close in a separate goroutine so that we can trigger
// a race with the RST we send below. This should not
@@ -7462,9 +7585,10 @@ func TestSetStackTimeWaitReuse(t *testing.T) {
}
for _, tc := range testCases {
- err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitReuseOption(tc.v))
+ opt := tcpip.TCPTimeWaitReuseOption(tc.v)
+ err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)
if got, want := err, tc.err; got != want {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.v, err, tc.err)
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err)
}
if tc.err != nil {
continue
@@ -7480,3 +7604,14 @@ func TestSetStackTimeWaitReuse(t *testing.T) {
}
}
}
+
+// generateRandomPayload generates a random byte slice of the specified length
+// causing a fatal test failure if it is unable to do so.
+func generateRandomPayload(t *testing.T, n int) []byte {
+ t.Helper()
+ buf := make([]byte, n)
+ if _, err := rand.Read(buf); err != nil {
+ t.Fatalf("rand.Read(buf) failed: %s", err)
+ }
+ return buf
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 8edbff964..0f9ed06cd 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -131,8 +131,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
defer c.Cleanup()
if cookieEnabled {
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -158,9 +159,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
checker.PayloadLen(len(data)+header.TCPMinimumSize+12),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(wndSize),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(wndSize),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
checker.TCPTimestampChecker(true, 0, tsVal+1),
),
@@ -180,7 +181,8 @@ func TestTimeStampEnabledAccept(t *testing.T) {
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
+ // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that.
+ {false, 5, 0x4000},
}
for _, tc := range testCases {
timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
@@ -192,8 +194,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
defer c.Cleanup()
if cookieEnabled {
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -217,9 +220,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(wndSize),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(wndSize),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
checker.TCPTimestampChecker(false, 0, 0),
),
@@ -235,7 +238,9 @@ func TestTimeStampDisabledAccept(t *testing.T) {
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
+ // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of
+ // that.
+ {false, 5, 0x4000},
}
for _, tc := range testCases {
timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 1f5340cd0..faf51ef95 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -53,11 +53,11 @@ const (
TestPort = 4096
// StackV6Addr is the IPv6 address assigned to the stack.
- StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
// TestV6Addr is the source address for packets sent to the stack via
// the link layer endpoint.
- TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// StackV4MappedAddr is StackAddr as a mapped v6 address.
StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
@@ -73,6 +73,18 @@ const (
testInitialSequenceNumber = 789
)
+// StackAddrWithPrefix is StackAddr with its associated prefix length.
+var StackAddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackAddr,
+ PrefixLen: 24,
+}
+
+// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length.
+var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackV6Addr,
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+}
+
// Headers is used to represent the TCP header fields when building a
// new packet.
type Headers struct {
@@ -133,32 +145,39 @@ type Context struct {
// WindowScale is the expected window scale in SYN packets sent by
// the stack.
WindowScale uint8
+
+ // RcvdWindowScale is the actual window scale sent by the stack in
+ // SYN/SYN-ACK.
+ RcvdWindowScale uint8
}
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
const sendBufferSize = 1 << 20 // 1 MiB
const recvBufferSize = 1 << 20 // 1 MiB
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err)
}
// Increase minimum RTO in tests to avoid test flakes due to early
// retransmit in case the test executors are overloaded and cause timers
// to fire earlier than expected.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil {
- t.Fatalf("failed to set stack-wide minRTO: %s", err)
+ minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
}
// Some of the congestion control tests send up to 640 packets, we so
@@ -181,12 +200,20 @@ func New(t *testing.T, mtu uint32) *Context {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: StackAddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
}
- if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: StackV6AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -238,18 +265,17 @@ func (c *Context) CheckNoPacket(errMsg string) {
c.CheckNoPacketTimeout(errMsg, 1*time.Second)
}
-// GetPacket reads a packet from the link layer endpoint and verifies
+// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies
// that it is an IPv4 packet with the expected source and destination
-// addresses. It will fail with an error if no packet is received for
-// 2 seconds.
-func (c *Context) GetPacket() []byte {
+// addresses. If no packet is received in the specified timeout it will return
+// nil.
+func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte {
c.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
- c.t.Fatalf("Packet wasn't written out")
return nil
}
@@ -257,6 +283,14 @@ func (c *Context) GetPacket() []byte {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
+ // Just check that the stack set the transport protocol number for outbound
+ // TCP messages.
+ // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
+ // of the headerinfo.
+ if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
+ c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
+ }
+
vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
b := vv.ToView()
@@ -268,6 +302,21 @@ func (c *Context) GetPacket() []byte {
return b
}
+// GetPacket reads a packet from the link layer endpoint and verifies
+// that it is an IPv4 packet with the expected source and destination
+// addresses.
+func (c *Context) GetPacket() []byte {
+ c.t.Helper()
+
+ p := c.GetPacketWithTimeout(5 * time.Second)
+ if p == nil {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ return p
+}
+
// GetPacketNonBlocking reads a packet from the link layer endpoint
// and verifies that it is an IPv4 packet with the expected source
// and destination address. If no packet is available it will return
@@ -284,6 +333,14 @@ func (c *Context) GetPacketNonBlocking() []byte {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
+ // Just check that the stack set the transport protocol number for outbound
+ // TCP messages.
+ // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
+ // of the headerinfo.
+ if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
+ c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
+ }
+
vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
b := vv.ToView()
@@ -447,8 +504,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
checker.PayloadLen(size+header.TCPMinimumSize+optlen),
checker.TCP(
checker.DstPort(TestPort),
- checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -474,8 +531,8 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int
checker.PayloadLen(size+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(TestPort),
- checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -613,6 +670,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
}
tcpHdr := header.TCP(header.IPv4(b).Payload())
+ synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */)
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
c.SendPacket(nil, &Headers{
@@ -630,8 +688,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
checker.TCP(
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
),
)
@@ -648,6 +706,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
}
+ c.RcvdWindowScale = uint8(synOpts.WS)
c.Port = tcpHdr.SourcePort()
}
@@ -719,17 +778,18 @@ func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
}
-// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
-// tsVal.
-func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches
+// the provided tsVal as well as returns the original packet.
+func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte {
+ r.C.t.Helper()
// Read ACK and verify that tsEcr of ACK packet is [1,2,3,4]
ackPacket := r.C.GetPacket()
checker.IPv4(r.C.t, ackPacket,
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
checker.TCPTimestampChecker(true, 0, tsVal),
),
)
@@ -737,19 +797,28 @@ func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
opts := tcpSeg.ParsedOptions()
r.RecentTS = opts.TSVal
+ return ackPacket
+}
+
+// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
+// tsVal.
+func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+ r.C.t.Helper()
+ _ = r.VerifyAndReturnACKWithTS(tsVal)
}
// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK
// matches the provided rcvWnd.
func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
+ r.C.t.Helper()
ackPacket := r.C.GetPacket()
checker.IPv4(r.C.t, ackPacket,
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
- checker.Window(rcvWnd),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
+ checker.TCPWindow(rcvWnd),
),
)
}
@@ -768,8 +837,8 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
checker.TCPSACKBlockChecker(sackBlocks),
),
)
@@ -861,8 +930,8 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
tcpCheckers := []checker.TransportChecker{
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS) + 1),
- checker.AckNum(uint32(iss) + 1),
+ checker.TCPSeqNum(uint32(c.IRS) + 1),
+ checker.TCPAckNum(uint32(iss) + 1),
}
// Verify that tsEcr of ACK packet is wantOptions.TSVal if the
@@ -897,7 +966,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Mark in context that timestamp option is enabled for this endpoint.
c.TimeStampEnabled = true
-
+ c.RcvdWindowScale = uint8(synOptions.WS)
return &RawEndpoint{
C: c,
SrcPort: tcpSeg.DestinationPort(),
@@ -948,12 +1017,12 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
c.t.Fatalf("Accept failed: %v", err)
}
@@ -990,6 +1059,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP
// value of the window scaling option to be sent in the SYN. If synOptions.WS >
// 0 then we send the WindowScale option.
func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ c.t.Helper()
opts := make([]byte, header.TCPOptionsMaximumSize)
offset := 0
offset += header.EncodeMSSOption(uint32(maxPayload), opts)
@@ -1028,13 +1098,14 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// are present.
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
+ rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */)
c.IRS = seqnum.Value(tcp.SequenceNumber())
tcpCheckers := []checker.TransportChecker{
checker.SrcPort(StackPort),
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(iss) + 1),
+ checker.TCPAckNum(uint32(iss) + 1),
checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
}
@@ -1077,6 +1148,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// Send ACK.
c.SendPacket(nil, ackHeaders)
+ c.RcvdWindowScale = uint8(rcvdSynOptions.WS)
c.Port = StackPort
return &RawEndpoint{
@@ -1096,7 +1168,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
// for the Stack in the context.
func (c *Context) SACKEnabled() bool {
- var v tcp.SACKEnabled
+ var v tcpip.TCPSACKEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
return false
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index b5d2d0ba6..c78549424 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index c74bc4d94..d57ed5d79 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -139,7 +139,7 @@ type endpoint struct {
// multicastMemberships that need to be remvoed when the endpoint is
// closed. Protected by the mu mutex.
- multicastMemberships []multicastMembership
+ multicastMemberships map[multicastMembership]struct{}
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -154,6 +154,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
}
// +stateify savable
@@ -182,12 +185,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
// TTL=1.
//
// Linux defaults to TTL=1.
- multicastTTL: 1,
- multicastLoop: true,
- rcvBufSizeMax: 32 * 1024,
- sndBufSizeMax: 32 * 1024,
- state: StateInitial,
- uniqueID: s.UniqueID(),
+ multicastTTL: 1,
+ multicastLoop: true,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
+ multicastMemberships: make(map[multicastMembership]struct{}),
+ state: StateInitial,
+ uniqueID: s.UniqueID(),
}
// Override with stack defaults.
@@ -237,10 +241,10 @@ func (e *endpoint) Close() {
e.boundPortFlags = ports.Flags{}
}
- for _, mem := range e.multicastMemberships {
+ for mem := range e.multicastMemberships {
e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
}
- e.multicastMemberships = nil
+ e.multicastMemberships = make(map[multicastMembership]struct{})
// Close the receive list and drain it.
e.rcvMu.Lock()
@@ -752,17 +756,15 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- for _, mem := range e.multicastMemberships {
- if mem == memToInsert {
- return tcpip.ErrPortInUse
- }
+ if _, ok := e.multicastMemberships[memToInsert]; ok {
+ return tcpip.ErrPortInUse
}
if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
return err
}
- e.multicastMemberships = append(e.multicastMemberships, memToInsert)
+ e.multicastMemberships[memToInsert] = struct{}{}
case *tcpip.RemoveMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
@@ -786,18 +788,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
}
memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
- memToRemoveIndex := -1
e.mu.Lock()
defer e.mu.Unlock()
- for i, mem := range e.multicastMemberships {
- if mem == memToRemove {
- memToRemoveIndex = i
- break
- }
- }
- if memToRemoveIndex == -1 {
+ if _, ok := e.multicastMemberships[memToRemove]; !ok {
return tcpip.ErrBadLocalAddress
}
@@ -805,8 +800,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
return err
}
- e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
- e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+ delete(e.multicastMemberships, memToRemove)
case *tcpip.BindToDeviceOption:
id := tcpip.NICID(*v)
@@ -819,6 +813,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.SocketDetachFilterOption:
return nil
+
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
}
return nil
}
@@ -975,6 +974,11 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
*o = tcpip.BindToDeviceOption(e.bindToDevice)
e.mu.RUnlock()
+ case *tcpip.LingerOption:
+ e.mu.RLock()
+ *o = e.linger
+ e.mu.RUnlock()
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -992,6 +996,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// Initialize the UDP header.
udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ pkt.TransportProtocolNumber = ProtocolNumber
length := uint16(pkt.Size())
udp.Encode(&header.UDPFields{
@@ -1218,7 +1223,7 @@ func (*endpoint) Listen(int) *tcpip.Error {
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
@@ -1392,15 +1397,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
return
}
- // Never receive from a multicast address.
- if header.IsV4MulticastAddress(id.RemoteAddress) ||
- header.IsV6MulticastAddress(id.RemoteAddress) {
- e.stack.Stats().UDP.InvalidSourceAddress.Increment()
- e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment()
- e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
- return
- }
-
if !verifyChecksum(r, hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 851e6b635..858c99a45 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- for _, m := range e.multicastMemberships {
+ for m := range e.multicastMemberships {
if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index f65751dd4..da5b1deb2 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -12,18 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package udp contains the implementation of the UDP transport protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing udp.NewProtocol() as one of the
-// transport protocols when calling stack.New(). Then endpoints can be created
-// by passing udp.ProtocolNumber as the transport protocol number when calling
-// Stack.NewEndpoint().
+// Package udp contains the implementation of the UDP transport protocol.
package udp
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/waiter"
@@ -49,6 +45,7 @@ const (
)
type protocol struct {
+ stack *stack.Stack
}
// Number returns the udp protocol number.
@@ -57,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new udp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw UDP endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid udp packet size.
@@ -79,135 +76,30 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
return h.SourcePort(), h.DestinationPort(), nil
}
-// HandleUnknownDestinationPacket handles packets targeted at this protocol but
-// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+// HandleUnknownDestinationPacket handles packets that are targeted at this
+// protocol but don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
- // Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
- return true
+ return stack.UnknownDestinationPacketMalformed
}
if !verifyChecksum(r, hdr, pkt) {
- // Checksum Error.
r.Stack().Stats().UDP.ChecksumErrors.Increment()
- return true
+ return stack.UnknownDestinationPacketMalformed
}
- // Only send ICMP error if the address is not a multicast/broadcast
- // v4/v6 address or the source is not the unspecified address.
- //
- // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any {
- return true
- }
-
- // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
- // Unreachable messages with code:
- //
- // 2 (Protocol Unreachable), when the designated transport protocol
- // is not supported; or
- //
- // 3 (Port Unreachable), when the designated transport protocol
- // (e.g., UDP) is unable to demultiplex the datagram but has no
- // protocol mechanism to inform the sender.
- switch len(id.LocalAddress) {
- case header.IPv4AddressSize:
- if !r.Stack().AllowICMPMessage() {
- r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment()
- return true
- }
- // As per RFC 1812 Section 4.3.2.3
- //
- // ICMP datagram SHOULD contain as much of the original
- // datagram as possible without the length of the ICMP
- // datagram exceeding 576 bytes
- //
- // NOTE: The above RFC referenced is different from the original
- // recommendation in RFC 1122 where it mentioned that at least 8
- // bytes of the payload must be included. Today linux and other
- // systems implement the] RFC1812 definition and not the original
- // RFC 1122 requirement.
- mtu := int(r.MTU())
- if mtu > header.IPv4MinimumProcessableDatagramSize {
- mtu = header.IPv4MinimumProcessableDatagramSize
- }
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
- available := int(mtu) - headerLen
- payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size()
- if payloadLen > available {
- payloadLen = available
- }
-
- // The buffers used by pkt may be used elsewhere in the system.
- // For example, a raw or packet socket may use what UDP
- // considers an unreachable destination. Thus we deep copy pkt
- // to prevent multiple ownership and SR errors.
- newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...)
- newHeader = append(newHeader, pkt.TransportHeader().View()...)
- payload := newHeader.ToVectorisedView()
- payload.AppendView(pkt.Data.ToView())
- payload.CapLength(payloadLen)
-
- icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: headerLen,
- Data: payload,
- })
- icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt)
-
- case header.IPv6AddressSize:
- if !r.Stack().AllowICMPMessage() {
- r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment()
- return true
- }
-
- // As per RFC 4443 section 2.4
- //
- // (c) Every ICMPv6 error message (type < 128) MUST include
- // as much of the IPv6 offending (invoking) packet (the
- // packet that caused the error) as possible without making
- // the error message packet exceed the minimum IPv6 MTU
- // [IPv6].
- mtu := int(r.MTU())
- if mtu > header.IPv6MinimumMTU {
- mtu = header.IPv6MinimumMTU
- }
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
- available := int(mtu) - headerLen
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
- payloadLen := len(network) + len(transport) + pkt.Data.Size()
- if payloadLen > available {
- payloadLen = available
- }
- payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport})
- payload.Append(pkt.Data)
- payload.CapLength(payloadLen)
-
- icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: headerLen,
- Data: payload,
- })
- icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt)
- }
- return true
+ return stack.UnknownDestinationPacketUnhandled
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
@@ -219,11 +111,10 @@ func (*protocol) Wait() {}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
- _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize)
- return ok
+ return parse.UDP(pkt)
}
// NewProtocol returns a UDP transport protocol.
-func NewProtocol() stack.TransportProtocol {
- return &protocol{}
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 0cbc045d8..4e14a5fc5 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -294,8 +294,8 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
return newDualTestContextWithOptions(t, mtu, stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
}
@@ -388,6 +388,10 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
}
+ if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want {
+ c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want)
+ }
+
vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
b := vv.ToView()
@@ -528,8 +532,8 @@ func newMinPayload(minSize int) []byte {
func TestBindToDeviceOption(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}})
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
@@ -803,8 +807,8 @@ func TestV4ReadSelfSource(t *testing.T) {
} {
t.Run(tt.name, func(t *testing.T) {
c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
HandleLocal: tt.handleLocal,
})
defer c.cleanup()
@@ -924,42 +928,6 @@ func TestReadFromMulticast(t *testing.T) {
}
}
-// TestReadFromMulticaststats checks that a discarded packet
-// that that was sent with multicast SOURCE address increments
-// the correct counters and that a regular packet does not.
-func TestReadFromMulticastStats(t *testing.T) {
- t.Helper()
- for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- c.injectPacket(flow, payload, false)
-
- var want uint64 = 0
- if flow.isReverseMulticast() {
- want = 1
- }
- if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want {
- t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
- }
- if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want {
- t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
- }
- })
- }
-}
-
// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
// and receive broadcast and unicast data.
func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
@@ -1462,6 +1430,30 @@ func TestNoChecksum(t *testing.T) {
}
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct{}
+
+func (*testInterface) ID() tcpip.NICID {
+ return 0
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
+func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
+ return nil
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1479,16 +1471,19 @@ func TestTTL(t *testing.T) {
if flow.isMulticast() {
wantTTL = multicastTTL
} else {
- var p stack.NetworkProtocol
+ var p stack.NetworkProtocolFactory
+ var n tcpip.NetworkProtocolNumber
if flow.isV4() {
- p = ipv4.NewProtocol()
+ p = ipv4.NewProtocol
+ n = ipv4.ProtocolNumber
} else {
- p = ipv6.NewProtocol()
+ p = ipv6.NewProtocol
+ n = ipv6.ProtocolNumber
}
- ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- }))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{p},
+ })
+ ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil, nil, nil)
wantTTL = ep.DefaultTTL()
ep.Close()
}
@@ -1512,18 +1507,6 @@ func TestSetTTL(t *testing.T) {
c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
}
- var p stack.NetworkProtocol
- if flow.isV4() {
- p = ipv4.NewProtocol()
- } else {
- p = ipv6.NewProtocol()
- }
- ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- }))
- ep.Close()
-
testWrite(c, flow, checker.TTL(wantTTL))
})
}
@@ -2343,17 +2326,18 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- requiresBroadcastOpt: true,
+ remoteAddr: remNetSubnetBcast,
+ // TODO(gvisor.dev/issue/3938): Once we support marking a route as
+ // broadcast, this test should require the broadcast option to be set.
+ requiresBroadcastOpt: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
-
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
e := channel.New(0, defaultMTU, "")
if err := s.CreateNIC(nicID1, e); err != nil {
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
index 052b6b99d..64d17f661 100644
--- a/pkg/test/dockerutil/container.go
+++ b/pkg/test/dockerutil/container.go
@@ -22,6 +22,7 @@ import (
"net"
"os"
"path"
+ "path/filepath"
"regexp"
"strconv"
"strings"
@@ -403,10 +404,13 @@ func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) {
return
}
for _, name := range sources {
- src, err := testutil.FindFile(name)
- if err != nil {
- c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err)
- return
+ src := name
+ if !filepath.IsAbs(src) {
+ src, err = testutil.FindFile(name)
+ if err != nil {
+ c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %w", name, err)
+ return
+ }
}
dst := path.Join(dir, path.Base(name))
if err := testutil.Copy(src, dst); err != nil {
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index b7f873392..06fb823f6 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -332,13 +332,13 @@ func PollContext(ctx context.Context, cb func() error) error {
}
// WaitForHTTP tries GET requests on a port until the call succeeds or timeout.
-func WaitForHTTP(port int, timeout time.Duration) error {
+func WaitForHTTP(ip string, port int, timeout time.Duration) error {
cb := func() error {
c := &http.Client{
// Calculate timeout to be able to do minimum 5 attempts.
Timeout: timeout / 5,
}
- url := fmt.Sprintf("http://localhost:%d/", port)
+ url := fmt.Sprintf("http://%s:%d/", ip, port)
resp, err := c.Get(url)
if err != nil {
log.Printf("Waiting %s: %v", url, err)
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 27279b409..9b1e7a085 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -21,7 +21,6 @@ import (
"io"
"strconv"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/safemem"
@@ -184,51 +183,6 @@ func (rw *IOReadWriter) Write(src []byte) (int, error) {
return n, err
}
-// CopyObjectOut copies a fixed-size value or slice of fixed-size values from
-// src to the memory mapped at addr in uio. It returns the number of bytes
-// copied.
-//
-// CopyObjectOut must use reflection to encode src; performance-sensitive
-// clients should do encoding manually and use uio.CopyOut directly.
-//
-// Preconditions: Same as IO.CopyOut.
-func CopyObjectOut(ctx context.Context, uio IO, addr Addr, src interface{}, opts IOOpts) (int, error) {
- w := &IOReadWriter{
- Ctx: ctx,
- IO: uio,
- Addr: addr,
- Opts: opts,
- }
- // Allocate a byte slice the size of the object being marshaled. This
- // adds an extra reflection call, but avoids needing to grow the slice
- // during encoding, which can result in many heap-allocated slices.
- b := make([]byte, 0, binary.Size(src))
- return w.Write(binary.Marshal(b, ByteOrder, src))
-}
-
-// CopyObjectIn copies a fixed-size value or slice of fixed-size values from
-// the memory mapped at addr in uio to dst. It returns the number of bytes
-// copied.
-//
-// CopyObjectIn must use reflection to decode dst; performance-sensitive
-// clients should use uio.CopyIn directly and do decoding manually.
-//
-// Preconditions: Same as IO.CopyIn.
-func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts IOOpts) (int, error) {
- r := &IOReadWriter{
- Ctx: ctx,
- IO: uio,
- Addr: addr,
- Opts: opts,
- }
- buf := make([]byte, binary.Size(dst))
- if _, err := io.ReadFull(r, buf); err != nil {
- return 0, err
- }
- binary.Unmarshal(buf, ByteOrder, dst)
- return int(r.Addr - addr), nil
-}
-
// CopyStringIn tuning parameters, defined outside that function for tests.
const (
copyStringIncrement = 64
diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go
index bf3c5df2b..da60b0cc7 100644
--- a/pkg/usermem/usermem_test.go
+++ b/pkg/usermem/usermem_test.go
@@ -16,7 +16,6 @@ package usermem
import (
"bytes"
- "encoding/binary"
"fmt"
"reflect"
"strings"
@@ -174,23 +173,6 @@ type testStruct struct {
Uint64 uint64
}
-func TestCopyObject(t *testing.T) {
- wantObj := testStruct{1, 2, 3, 4, 5, 6, 7, 8}
- wantN := binary.Size(wantObj)
- b := &BytesIO{make([]byte, wantN)}
- ctx := newContext()
- if n, err := CopyObjectOut(ctx, b, 0, &wantObj, IOOpts{}); n != wantN || err != nil {
- t.Fatalf("CopyObjectOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- var gotObj testStruct
- if n, err := CopyObjectIn(ctx, b, 0, &gotObj, IOOpts{}); n != wantN || err != nil {
- t.Errorf("CopyObjectIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if gotObj != wantObj {
- t.Errorf("CopyObject round trip: got %+v, wanted %+v", gotObj, wantObj)
- }
-}
-
func TestCopyStringInShort(t *testing.T) {
// Tests for string length <= copyStringIncrement.
want := strings.Repeat("A", copyStringIncrement-2)
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 040f6a72d..2d9517f4a 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -26,10 +26,13 @@ go_library(
deps = [
"//pkg/abi",
"//pkg/abi/linux",
+ "//pkg/bpf",
+ "//pkg/cleanup",
"//pkg/context",
"//pkg/control/server",
"//pkg/cpuid",
"//pkg/eventchannel",
+ "//pkg/fd",
"//pkg/fspath",
"//pkg/log",
"//pkg/memutil",
@@ -106,6 +109,7 @@ go_library(
"//runsc/boot/pprof",
"//runsc/config",
"//runsc/specutils",
+ "//runsc/specutils/seccomp",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
@@ -123,6 +127,7 @@ go_test(
library = ":boot",
deps = [
"//pkg/control/server",
+ "//pkg/fd",
"//pkg/fspath",
"//pkg/log",
"//pkg/p9",
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 68a2b45cf..894651519 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -22,6 +22,7 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/control/server"
+ "gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -257,13 +258,20 @@ func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error {
// All validation passed, logs the spec for debugging.
specutils.LogSpec(args.Spec)
- err := cm.l.startContainer(args.Spec, args.Conf, args.CID, args.FilePayload.Files)
+ fds, err := fd.NewFromFiles(args.FilePayload.Files)
if err != nil {
+ return err
+ }
+ defer func() {
+ for _, fd := range fds {
+ _ = fd.Close()
+ }
+ }()
+ if err := cm.l.startContainer(args.Spec, args.Conf, args.CID, fds); err != nil {
log.Debugf("containerManager.Start failed %q: %+v: %v", args.CID, args, err)
return err
}
log.Debugf("Container %q started", args.CID)
-
return nil
}
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index 149eb0b1b..6ac19668f 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -27,41 +27,30 @@ import (
// allowedSyscalls is the set of syscalls executed by the Sentry to the host OS.
var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_CLOCK_GETTIME: {},
- syscall.SYS_CLONE: []seccomp.Rule{
- {
- seccomp.AllowValue(
- syscall.CLONE_VM |
- syscall.CLONE_FS |
- syscall.CLONE_FILES |
- syscall.CLONE_SIGHAND |
- syscall.CLONE_SYSVSEM |
- syscall.CLONE_THREAD),
- },
- },
- syscall.SYS_CLOSE: {},
- syscall.SYS_DUP: {},
+ syscall.SYS_CLOSE: {},
+ syscall.SYS_DUP: {},
syscall.SYS_DUP3: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.O_CLOEXEC),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.O_CLOEXEC),
},
},
syscall.SYS_EPOLL_CREATE1: {},
syscall.SYS_EPOLL_CTL: {},
syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(0),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(0),
},
},
syscall.SYS_EVENTFD2: []seccomp.Rule{
{
- seccomp.AllowValue(0),
- seccomp.AllowValue(0),
+ seccomp.EqualTo(0),
+ seccomp.EqualTo(0),
},
},
syscall.SYS_EXIT: {},
@@ -70,16 +59,16 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_FCHMOD: {},
syscall.SYS_FCNTL: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.F_GETFL),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.F_GETFL),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.F_SETFL),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.F_SETFL),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.F_GETFD),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.F_GETFD),
},
},
syscall.SYS_FSTAT: {},
@@ -87,52 +76,52 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_FTRUNCATE: {},
syscall.SYS_FUTEX: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
- seccomp.AllowAny{},
- seccomp.AllowAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
- seccomp.AllowAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
+ seccomp.MatchAny{},
},
// Non-private variants are included for flipcall support. They are otherwise
// unncessary, as the sentry will use only private futexes internally.
{
- seccomp.AllowAny{},
- seccomp.AllowValue(linux.FUTEX_WAIT),
- seccomp.AllowAny{},
- seccomp.AllowAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(linux.FUTEX_WAIT),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(linux.FUTEX_WAKE),
- seccomp.AllowAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(linux.FUTEX_WAKE),
+ seccomp.MatchAny{},
},
},
syscall.SYS_GETPID: {},
unix.SYS_GETRANDOM: {},
syscall.SYS_GETSOCKOPT: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_DOMAIN),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_DOMAIN),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_TYPE),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_TYPE),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_ERROR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_ERROR),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_SNDBUF),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_SNDBUF),
},
},
syscall.SYS_GETTID: {},
@@ -141,34 +130,34 @@ var allowedSyscalls = seccomp.SyscallRules{
// setting/getting termios and winsize.
syscall.SYS_IOCTL: []seccomp.Rule{
{
- seccomp.AllowAny{}, /* fd */
- seccomp.AllowValue(linux.TCGETS),
- seccomp.AllowAny{}, /* termios struct */
+ seccomp.MatchAny{}, /* fd */
+ seccomp.EqualTo(linux.TCGETS),
+ seccomp.MatchAny{}, /* termios struct */
},
{
- seccomp.AllowAny{}, /* fd */
- seccomp.AllowValue(linux.TCSETS),
- seccomp.AllowAny{}, /* termios struct */
+ seccomp.MatchAny{}, /* fd */
+ seccomp.EqualTo(linux.TCSETS),
+ seccomp.MatchAny{}, /* termios struct */
},
{
- seccomp.AllowAny{}, /* fd */
- seccomp.AllowValue(linux.TCSETSF),
- seccomp.AllowAny{}, /* termios struct */
+ seccomp.MatchAny{}, /* fd */
+ seccomp.EqualTo(linux.TCSETSF),
+ seccomp.MatchAny{}, /* termios struct */
},
{
- seccomp.AllowAny{}, /* fd */
- seccomp.AllowValue(linux.TCSETSW),
- seccomp.AllowAny{}, /* termios struct */
+ seccomp.MatchAny{}, /* fd */
+ seccomp.EqualTo(linux.TCSETSW),
+ seccomp.MatchAny{}, /* termios struct */
},
{
- seccomp.AllowAny{}, /* fd */
- seccomp.AllowValue(linux.TIOCSWINSZ),
- seccomp.AllowAny{}, /* winsize struct */
+ seccomp.MatchAny{}, /* fd */
+ seccomp.EqualTo(linux.TIOCSWINSZ),
+ seccomp.MatchAny{}, /* winsize struct */
},
{
- seccomp.AllowAny{}, /* fd */
- seccomp.AllowValue(linux.TIOCGWINSZ),
- seccomp.AllowAny{}, /* winsize struct */
+ seccomp.MatchAny{}, /* fd */
+ seccomp.EqualTo(linux.TIOCGWINSZ),
+ seccomp.MatchAny{}, /* winsize struct */
},
},
syscall.SYS_LSEEK: {},
@@ -182,46 +171,46 @@ var allowedSyscalls = seccomp.SyscallRules{
// TODO(b/148688965): Remove once this is gone from Go.
syscall.SYS_MLOCK: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(4096),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4096),
},
},
syscall.SYS_MMAP: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MAP_SHARED),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MAP_SHARED),
},
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MAP_PRIVATE),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MAP_PRIVATE),
},
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS),
},
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_STACK),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_STACK),
},
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_NORESERVE),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_NORESERVE),
},
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.PROT_WRITE | syscall.PROT_READ),
- seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.PROT_WRITE | syscall.PROT_READ),
+ seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED),
},
},
syscall.SYS_MPROTECT: {},
@@ -237,32 +226,32 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_READ: {},
syscall.SYS_RECVMSG: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC),
},
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK),
},
},
syscall.SYS_RECVMMSG: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(fdbased.MaxMsgsPerRecv),
- seccomp.AllowValue(syscall.MSG_DONTWAIT),
- seccomp.AllowValue(0),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(fdbased.MaxMsgsPerRecv),
+ seccomp.EqualTo(syscall.MSG_DONTWAIT),
+ seccomp.EqualTo(0),
},
},
unix.SYS_SENDMMSG: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MSG_DONTWAIT),
- seccomp.AllowValue(0),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MSG_DONTWAIT),
+ seccomp.EqualTo(0),
},
},
syscall.SYS_RESTART_SYSCALL: {},
@@ -272,49 +261,49 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_SCHED_YIELD: {},
syscall.SYS_SENDMSG: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL),
},
},
syscall.SYS_SETITIMER: {},
syscall.SYS_SHUTDOWN: []seccomp.Rule{
// Used by fs/host to shutdown host sockets.
- {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RD)},
- {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_WR)},
+ {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RD)},
+ {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_WR)},
// Used by unet to shutdown connections.
- {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)},
+ {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RDWR)},
},
syscall.SYS_SIGALTSTACK: {},
unix.SYS_STATX: {},
syscall.SYS_SYNC_FILE_RANGE: {},
syscall.SYS_TEE: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(1), /* len */
- seccomp.AllowValue(unix.SPLICE_F_NONBLOCK), /* flags */
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(1), /* len */
+ seccomp.EqualTo(unix.SPLICE_F_NONBLOCK), /* flags */
},
},
syscall.SYS_TGKILL: []seccomp.Rule{
{
- seccomp.AllowValue(uint64(os.Getpid())),
+ seccomp.EqualTo(uint64(os.Getpid())),
},
},
syscall.SYS_UTIMENSAT: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(0), /* null pathname */
- seccomp.AllowAny{},
- seccomp.AllowValue(0), /* flags */
+ seccomp.MatchAny{},
+ seccomp.EqualTo(0), /* null pathname */
+ seccomp.MatchAny{},
+ seccomp.EqualTo(0), /* flags */
},
},
syscall.SYS_WRITE: {},
// For rawfile.NonBlockingWriteIovec.
syscall.SYS_WRITEV: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
seccomp.GreaterThan(0),
},
},
@@ -325,10 +314,10 @@ func hostInetFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
syscall.SYS_ACCEPT4: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
},
},
syscall.SYS_BIND: {},
@@ -337,84 +326,84 @@ func hostInetFilters() seccomp.SyscallRules {
syscall.SYS_GETSOCKNAME: {},
syscall.SYS_GETSOCKOPT: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_IP),
- seccomp.AllowValue(syscall.IP_TOS),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_TOS),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_IP),
- seccomp.AllowValue(syscall.IP_RECVTOS),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_RECVTOS),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_IPV6),
- seccomp.AllowValue(syscall.IPV6_TCLASS),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(syscall.IPV6_TCLASS),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_IPV6),
- seccomp.AllowValue(syscall.IPV6_RECVTCLASS),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(syscall.IPV6_RECVTCLASS),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_IPV6),
- seccomp.AllowValue(syscall.IPV6_V6ONLY),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(syscall.IPV6_V6ONLY),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_ERROR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_ERROR),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_KEEPALIVE),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_KEEPALIVE),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_SNDBUF),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_SNDBUF),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_RCVBUF),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_RCVBUF),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_REUSEADDR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_REUSEADDR),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_TYPE),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_TYPE),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_LINGER),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_LINGER),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_TCP),
- seccomp.AllowValue(syscall.TCP_NODELAY),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_TCP),
+ seccomp.EqualTo(syscall.TCP_NODELAY),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_TCP),
- seccomp.AllowValue(syscall.TCP_INFO),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_TCP),
+ seccomp.EqualTo(syscall.TCP_INFO),
},
},
syscall.SYS_IOCTL: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.TIOCOUTQ),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.TIOCOUTQ),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.TIOCINQ),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.TIOCINQ),
},
},
syscall.SYS_LISTEN: {},
@@ -425,103 +414,103 @@ func hostInetFilters() seccomp.SyscallRules {
syscall.SYS_SENDTO: {},
syscall.SYS_SETSOCKOPT: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_IPV6),
- seccomp.AllowValue(syscall.IPV6_V6ONLY),
- seccomp.AllowAny{},
- seccomp.AllowValue(4),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(syscall.IPV6_V6ONLY),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_SNDBUF),
- seccomp.AllowAny{},
- seccomp.AllowValue(4),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_SNDBUF),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_RCVBUF),
- seccomp.AllowAny{},
- seccomp.AllowValue(4),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_RCVBUF),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_REUSEADDR),
- seccomp.AllowAny{},
- seccomp.AllowValue(4),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_REUSEADDR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_TCP),
- seccomp.AllowValue(syscall.TCP_NODELAY),
- seccomp.AllowAny{},
- seccomp.AllowValue(4),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_TCP),
+ seccomp.EqualTo(syscall.TCP_NODELAY),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_IP),
- seccomp.AllowValue(syscall.IP_TOS),
- seccomp.AllowAny{},
- seccomp.AllowValue(4),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_TOS),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_IP),
- seccomp.AllowValue(syscall.IP_RECVTOS),
- seccomp.AllowAny{},
- seccomp.AllowValue(4),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IP),
+ seccomp.EqualTo(syscall.IP_RECVTOS),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_IPV6),
- seccomp.AllowValue(syscall.IPV6_TCLASS),
- seccomp.AllowAny{},
- seccomp.AllowValue(4),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(syscall.IPV6_TCLASS),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_IPV6),
- seccomp.AllowValue(syscall.IPV6_RECVTCLASS),
- seccomp.AllowAny{},
- seccomp.AllowValue(4),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(syscall.IPV6_RECVTCLASS),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4),
},
},
syscall.SYS_SHUTDOWN: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SHUT_RD),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SHUT_RD),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SHUT_WR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SHUT_WR),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SHUT_RDWR),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SHUT_RDWR),
},
},
syscall.SYS_SOCKET: []seccomp.Rule{
{
- seccomp.AllowValue(syscall.AF_INET),
- seccomp.AllowValue(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
- seccomp.AllowValue(0),
+ seccomp.EqualTo(syscall.AF_INET),
+ seccomp.EqualTo(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.EqualTo(0),
},
{
- seccomp.AllowValue(syscall.AF_INET),
- seccomp.AllowValue(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
- seccomp.AllowValue(0),
+ seccomp.EqualTo(syscall.AF_INET),
+ seccomp.EqualTo(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.EqualTo(0),
},
{
- seccomp.AllowValue(syscall.AF_INET6),
- seccomp.AllowValue(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
- seccomp.AllowValue(0),
+ seccomp.EqualTo(syscall.AF_INET6),
+ seccomp.EqualTo(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.EqualTo(0),
},
{
- seccomp.AllowValue(syscall.AF_INET6),
- seccomp.AllowValue(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
- seccomp.AllowValue(0),
+ seccomp.EqualTo(syscall.AF_INET6),
+ seccomp.EqualTo(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.EqualTo(0),
},
},
syscall.SYS_WRITEV: {},
@@ -532,20 +521,20 @@ func controlServerFilters(fd int) seccomp.SyscallRules {
return seccomp.SyscallRules{
syscall.SYS_ACCEPT: []seccomp.Rule{
{
- seccomp.AllowValue(fd),
+ seccomp.EqualTo(fd),
},
},
syscall.SYS_LISTEN: []seccomp.Rule{
{
- seccomp.AllowValue(fd),
- seccomp.AllowValue(16 /* unet.backlog */),
+ seccomp.EqualTo(fd),
+ seccomp.EqualTo(16 /* unet.backlog */),
},
},
syscall.SYS_GETSOCKOPT: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_PEERCRED),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.SOL_SOCKET),
+ seccomp.EqualTo(syscall.SO_PEERCRED),
},
},
}
diff --git a/runsc/boot/filter/config_amd64.go b/runsc/boot/filter/config_amd64.go
index 5335ff82c..cea5613b8 100644
--- a/runsc/boot/filter/config_amd64.go
+++ b/runsc/boot/filter/config_amd64.go
@@ -24,8 +24,41 @@ import (
)
func init() {
- allowedSyscalls[syscall.SYS_ARCH_PRCTL] = append(allowedSyscalls[syscall.SYS_ARCH_PRCTL],
- seccomp.Rule{seccomp.AllowValue(linux.ARCH_GET_FS)},
- seccomp.Rule{seccomp.AllowValue(linux.ARCH_SET_FS)},
- )
+ allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{
+ // TODO(b/168828518): No longer used in Go 1.16+.
+ {seccomp.EqualTo(linux.ARCH_SET_FS)},
+ }
+
+ allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{
+ // parent_tidptr and child_tidptr are always 0 because neither
+ // CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used.
+ {
+ seccomp.EqualTo(
+ syscall.CLONE_VM |
+ syscall.CLONE_FS |
+ syscall.CLONE_FILES |
+ syscall.CLONE_SETTLS |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_SYSVSEM |
+ syscall.CLONE_THREAD),
+ seccomp.MatchAny{}, // newsp
+ seccomp.EqualTo(0), // parent_tidptr
+ seccomp.EqualTo(0), // child_tidptr
+ seccomp.MatchAny{}, // tls
+ },
+ {
+ // TODO(b/168828518): No longer used in Go 1.16+ (on amd64).
+ seccomp.EqualTo(
+ syscall.CLONE_VM |
+ syscall.CLONE_FS |
+ syscall.CLONE_FILES |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_SYSVSEM |
+ syscall.CLONE_THREAD),
+ seccomp.MatchAny{}, // newsp
+ seccomp.EqualTo(0), // parent_tidptr
+ seccomp.EqualTo(0), // child_tidptr
+ seccomp.MatchAny{}, // tls
+ },
+ }
}
diff --git a/runsc/boot/filter/config_arm64.go b/runsc/boot/filter/config_arm64.go
index 7fa9bbda3..37313f97f 100644
--- a/runsc/boot/filter/config_arm64.go
+++ b/runsc/boot/filter/config_arm64.go
@@ -16,6 +16,29 @@
package filter
-// Reserve for future customization.
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
func init() {
+ allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{
+ {
+ seccomp.EqualTo(
+ syscall.CLONE_VM |
+ syscall.CLONE_FS |
+ syscall.CLONE_FILES |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_SYSVSEM |
+ syscall.CLONE_THREAD),
+ seccomp.MatchAny{}, // newsp
+ // These arguments are left uninitialized by the Go
+ // runtime, so they may be anything (and are unused by
+ // the host).
+ seccomp.MatchAny{}, // parent_tidptr
+ seccomp.MatchAny{}, // tls
+ seccomp.MatchAny{}, // child_tidptr
+ },
+ }
}
diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go
index 194952a7b..7b8669595 100644
--- a/runsc/boot/filter/config_profile.go
+++ b/runsc/boot/filter/config_profile.go
@@ -25,9 +25,9 @@ func profileFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
syscall.SYS_OPENAT: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC),
},
},
}
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 163265afe..ddf288456 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -34,6 +34,7 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/gofer"
@@ -253,7 +254,7 @@ func mustFindFilesystem(name string) fs.Filesystem {
// addSubmountOverlay overlays the inode over a ramfs tree containing the given
// paths.
-func addSubmountOverlay(ctx context.Context, inode *fs.Inode, submounts []string) (*fs.Inode, error) {
+func addSubmountOverlay(ctx context.Context, inode *fs.Inode, submounts []string, mf fs.MountSourceFlags) (*fs.Inode, error) {
// Construct a ramfs tree of mount points. The contents never
// change, so this can be fully caching. There's no real
// filesystem backing this tree, so we set the filesystem to
@@ -263,7 +264,7 @@ func addSubmountOverlay(ctx context.Context, inode *fs.Inode, submounts []string
if err != nil {
return nil, fmt.Errorf("creating mount tree: %v", err)
}
- overlayInode, err := fs.NewOverlayRoot(ctx, inode, mountTree, fs.MountSourceFlags{})
+ overlayInode, err := fs.NewOverlayRoot(ctx, inode, mountTree, mf)
if err != nil {
return nil, fmt.Errorf("adding mount overlay: %v", err)
}
@@ -320,14 +321,14 @@ func adjustDirentCache(k *kernel.Kernel) error {
}
type fdDispenser struct {
- fds []int
+ fds []*fd.FD
}
func (f *fdDispenser) remove() int {
if f.empty() {
panic("fdDispenser out of fds")
}
- rv := f.fds[0]
+ rv := f.fds[0].Release()
f.fds = f.fds[1:]
return rv
}
@@ -453,17 +454,17 @@ func (m *mountHint) isSupported() bool {
func (m *mountHint) checkCompatible(mount specs.Mount) error {
// Remove options that don't affect to mount's behavior.
masterOpts := filterUnsupportedOptions(m.mount)
- slaveOpts := filterUnsupportedOptions(mount)
+ replicaOpts := filterUnsupportedOptions(mount)
- if len(masterOpts) != len(slaveOpts) {
- return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts)
+ if len(masterOpts) != len(replicaOpts) {
+ return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, replicaOpts)
}
sort.Strings(masterOpts)
- sort.Strings(slaveOpts)
+ sort.Strings(replicaOpts)
for i, opt := range masterOpts {
- if opt != slaveOpts[i] {
- return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts)
+ if opt != replicaOpts[i] {
+ return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, replicaOpts)
}
}
return nil
@@ -564,7 +565,7 @@ type containerMounter struct {
hints *podMountHints
}
-func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hints *podMountHints) *containerMounter {
+func newContainerMounter(spec *specs.Spec, goferFDs []*fd.FD, k *kernel.Kernel, hints *podMountHints) *containerMounter {
return &containerMounter{
root: spec.Root,
mounts: compileMounts(spec),
@@ -740,7 +741,7 @@ func (c *containerMounter) createRootMount(ctx context.Context, conf *config.Con
// for submount paths. "/dev" "/sys" "/proc" and "/tmp" are always
// mounted even if they are not in the spec.
submounts := append(subtargets("/", c.mounts), "/dev", "/sys", "/proc", "/tmp")
- rootInode, err = addSubmountOverlay(ctx, rootInode, submounts)
+ rootInode, err = addSubmountOverlay(ctx, rootInode, submounts, mf)
if err != nil {
return nil, fmt.Errorf("adding submount overlay: %v", err)
}
@@ -850,7 +851,7 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *config.Confi
submounts := subtargets(m.Destination, c.mounts)
if len(submounts) > 0 {
log.Infof("Adding submount overlay over %q", m.Destination)
- inode, err = addSubmountOverlay(ctx, inode, submounts)
+ inode, err = addSubmountOverlay(ctx, inode, submounts, mf)
if err != nil {
return fmt.Errorf("adding submount overlay: %v", err)
}
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index c3c754046..2e652ddad 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -27,8 +27,10 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/memutil"
"gvisor.dev/gvisor/pkg/rand"
@@ -69,6 +71,7 @@ import (
"gvisor.dev/gvisor/runsc/boot/pprof"
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/specutils"
+ "gvisor.dev/gvisor/runsc/specutils/seccomp"
// Include supported socket providers.
"gvisor.dev/gvisor/pkg/sentry/socket/hostinet"
@@ -89,10 +92,10 @@ type containerInfo struct {
procArgs kernel.CreateProcessArgs
// stdioFDs contains stdin, stdout, and stderr.
- stdioFDs []int
+ stdioFDs []*fd.FD
// goferFDs are the FDs that attach the sandbox to the gofers.
- goferFDs []int
+ goferFDs []*fd.FD
}
// Loader keeps state needed to start the kernel and run the container..
@@ -356,12 +359,17 @@ func New(args Args) (*Loader, error) {
k.SetHostMount(hostMount)
}
+ info := containerInfo{
+ conf: args.Conf,
+ spec: args.Spec,
+ procArgs: procArgs,
+ }
+
// Make host FDs stable between invocations. Host FDs must map to the exact
// same number when the sandbox is restored. Otherwise the wrong FD will be
// used.
- var stdioFDs []int
newfd := startingStdioFD
- for _, fd := range args.StdioFDs {
+ for _, stdioFD := range args.StdioFDs {
// Check that newfd is unused to avoid clobbering over it.
if _, err := unix.FcntlInt(uintptr(newfd), unix.F_GETFD, 0); !errors.Is(err, unix.EBADF) {
if err != nil {
@@ -370,14 +378,17 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("unable to remap stdios, FD %d is already in use", newfd)
}
- err := unix.Dup3(fd, newfd, unix.O_CLOEXEC)
+ err := unix.Dup3(stdioFD, newfd, unix.O_CLOEXEC)
if err != nil {
- return nil, fmt.Errorf("dup3 of stdioFDs failed: %v", err)
+ return nil, fmt.Errorf("dup3 of stdios failed: %w", err)
}
- stdioFDs = append(stdioFDs, newfd)
- _ = unix.Close(fd)
+ info.stdioFDs = append(info.stdioFDs, fd.New(newfd))
+ _ = unix.Close(stdioFD)
newfd++
}
+ for _, goferFD := range args.GoferFDs {
+ info.goferFDs = append(info.goferFDs, fd.New(goferFD))
+ }
eid := execID{cid: args.ID}
l := &Loader{
@@ -386,13 +397,7 @@ func New(args Args) (*Loader, error) {
sandboxID: args.ID,
processes: map[execID]*execProcess{eid: {}},
mountHints: mountHints,
- root: containerInfo{
- conf: args.Conf,
- stdioFDs: stdioFDs,
- goferFDs: args.GoferFDs,
- spec: args.Spec,
- procArgs: procArgs,
- },
+ root: info,
}
// We don't care about child signals; some platforms can generate a
@@ -466,9 +471,14 @@ func (l *Loader) Destroy() {
}
l.watchdog.Stop()
- for i, fd := range l.root.stdioFDs {
- _ = unix.Close(fd)
- l.root.stdioFDs[i] = -1
+ // In the success case, stdioFDs and goferFDs will only contain
+ // released/closed FDs that ownership has been passed over to host FDs and
+ // gofer sessions. Close them here in case on failure.
+ for _, fd := range l.root.stdioFDs {
+ _ = fd.Close()
+ }
+ for _, fd := range l.root.goferFDs {
+ _ = fd.Close()
}
}
@@ -499,6 +509,7 @@ func createMemoryFile() (*pgalloc.MemoryFile, error) {
return mf, nil
}
+// installSeccompFilters installs sandbox seccomp filters with the host.
func (l *Loader) installSeccompFilters() error {
if l.root.conf.DisableSeccomp {
filter.Report("syscall filter is DISABLED. Running in less secure mode.")
@@ -569,6 +580,7 @@ func (l *Loader) run() error {
if _, err := l.createContainerProcess(true, l.sandboxID, &l.root, ep); err != nil {
return err
}
+
}
ep.tg = l.k.GlobalInit()
@@ -598,17 +610,6 @@ func (l *Loader) run() error {
}
})
- // l.stdioFDs are derived from dup() in boot.New() and they are now dup()ed again
- // either in createFDTable() during initial start or in descriptor.initAfterLoad()
- // during restore, we can release l.stdioFDs now. VFS2 takes ownership of the
- // passed FDs, so only close for VFS1.
- if !kernel.VFS2Enabled {
- for i, fd := range l.root.stdioFDs {
- _ = unix.Close(fd)
- l.root.stdioFDs[i] = -1
- }
- }
-
log.Infof("Process should have started...")
l.watchdog.Start()
return l.k.Start()
@@ -628,9 +629,9 @@ func (l *Loader) createContainer(cid string) error {
}
// startContainer starts a child container. It returns the thread group ID of
-// the newly created process. Caller owns 'files' and may close them after
-// this method returns.
-func (l *Loader) startContainer(spec *specs.Spec, conf *config.Config, cid string, files []*os.File) error {
+// the newly created process. Used FDs are either closed or released. It's safe
+// for the caller to close any remaining files upon return.
+func (l *Loader) startContainer(spec *specs.Spec, conf *config.Config, cid string, files []*fd.FD) error {
// Create capabilities.
caps, err := specutils.Capabilities(conf.EnableRaw, spec.Process.Capabilities)
if err != nil {
@@ -681,28 +682,15 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *config.Config, cid strin
}
info := &containerInfo{
- conf: conf,
- spec: spec,
+ conf: conf,
+ spec: spec,
+ stdioFDs: files[:3],
+ goferFDs: files[3:],
}
info.procArgs, err = createProcessArgs(cid, spec, creds, l.k, pidns)
if err != nil {
return fmt.Errorf("creating new process: %v", err)
}
-
- // setupContainerFS() dups stdioFDs, so we don't need to dup them here.
- for _, f := range files[:3] {
- info.stdioFDs = append(info.stdioFDs, int(f.Fd()))
- }
-
- // Can't take ownership away from os.File. dup them to get a new FDs.
- for _, f := range files[3:] {
- fd, err := unix.Dup(int(f.Fd()))
- if err != nil {
- return fmt.Errorf("failed to dup file: %v", err)
- }
- info.goferFDs = append(info.goferFDs, fd)
- }
-
tg, err := l.createContainerProcess(false, cid, info, ep)
if err != nil {
return err
@@ -780,19 +768,44 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
}
}
+ // Install seccomp filters with the new task if there are any.
+ if info.conf.OCISeccomp {
+ if info.spec.Linux != nil && info.spec.Linux.Seccomp != nil {
+ program, err := seccomp.BuildProgram(info.spec.Linux.Seccomp)
+ if err != nil {
+ return nil, fmt.Errorf("building seccomp program: %v", err)
+ }
+
+ if log.IsLogging(log.Debug) {
+ out, _ := bpf.DecodeProgram(program)
+ log.Debugf("Installing OCI seccomp filters\nProgram:\n%s", out)
+ }
+
+ task := tg.Leader()
+ // NOTE: It seems Flags are ignored by runc so we ignore them too.
+ if err := task.AppendSyscallFilter(program, true); err != nil {
+ return nil, fmt.Errorf("appending seccomp filters: %v", err)
+ }
+ }
+ } else {
+ if info.spec.Linux != nil && info.spec.Linux.Seccomp != nil {
+ log.Warningf("Seccomp spec is being ignored")
+ }
+ }
+
return tg, nil
}
// startGoferMonitor runs a goroutine to monitor gofer's health. It polls on
// the gofer FDs looking for disconnects, and destroys the container if a
// disconnect occurs in any of the gofer FDs.
-func (l *Loader) startGoferMonitor(cid string, goferFDs []int) {
+func (l *Loader) startGoferMonitor(cid string, goferFDs []*fd.FD) {
go func() {
log.Debugf("Monitoring gofer health for container %q", cid)
var events []unix.PollFd
- for _, fd := range goferFDs {
+ for _, goferFD := range goferFDs {
events = append(events, unix.PollFd{
- Fd: int32(fd),
+ Fd: int32(goferFD.FD()),
Events: unix.POLLHUP | unix.POLLRDHUP,
})
}
@@ -1046,8 +1059,8 @@ func newRootNetworkNamespace(conf *config.Config, clock tcpip.Clock, uniqueID st
}
func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
- netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
- transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
+ netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol}
+ transProtos := []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4}
s := netstack.Stack{stack.New(stack.Options{
NetworkProtocols: netProtos,
TransportProtocols: transProtos,
@@ -1061,22 +1074,30 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in
})}
// Enable SACK Recovery.
- if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
- return nil, fmt.Errorf("failed to enable SACK: %s", err)
+ {
+ opt := tcpip.TCPSACKEnabled(true)
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
}
// Set default TTLs as required by socket/netstack.
- opt := tcpip.DefaultTTLOption(netstack.DefaultTTL)
- if err := s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, &opt); err != nil {
- return nil, fmt.Errorf("SetNetworkProtocolOption(%d, &%T(%d)): %s", ipv4.ProtocolNumber, opt, opt, err)
- }
- if err := s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, &opt); err != nil {
- return nil, fmt.Errorf("SetNetworkProtocolOption(%d, &%T(%d)): %s", ipv6.ProtocolNumber, opt, opt, err)
+ {
+ opt := tcpip.DefaultTTLOption(netstack.DefaultTTL)
+ if err := s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, &opt); err != nil {
+ return nil, fmt.Errorf("SetNetworkProtocolOption(%d, &%T(%d)): %s", ipv4.ProtocolNumber, opt, opt, err)
+ }
+ if err := s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, &opt); err != nil {
+ return nil, fmt.Errorf("SetNetworkProtocolOption(%d, &%T(%d)): %s", ipv6.ProtocolNumber, opt, opt, err)
+ }
}
// Enable Receive Buffer Auto-Tuning.
- if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPModerateReceiveBufferOption(true)
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
}
return &s, nil
@@ -1272,7 +1293,7 @@ func (l *Loader) ttyFromIDLocked(key execID) (*host.TTYFileOperations, *hostvfs2
return ep.tty, ep.ttyVFS2, nil
}
-func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.FDTable, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
+func createFDTable(ctx context.Context, console bool, stdioFDs []*fd.FD) (*kernel.FDTable, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
if len(stdioFDs) != 3 {
return nil, nil, nil, fmt.Errorf("stdioFDs should contain exactly 3 FDs (stdin, stdout, and stderr), but %d FDs received", len(stdioFDs))
}
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index 2343ce76c..bf9ec5d38 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -26,6 +26,7 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/control/server"
+ "gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
@@ -444,7 +445,7 @@ func TestCreateMountNamespace(t *testing.T) {
}
defer cleanup()
- mntr := newContainerMounter(&tc.spec, []int{sandEnd}, nil, &podMountHints{})
+ mntr := newContainerMounter(&tc.spec, []*fd.FD{fd.New(sandEnd)}, nil, &podMountHints{})
mns, err := mntr.createMountNamespace(ctx, conf)
if err != nil {
t.Fatalf("failed to create mount namespace: %v", err)
@@ -490,9 +491,9 @@ func TestCreateMountNamespaceVFS2(t *testing.T) {
}
ctx := l.k.SupervisorContext()
- mns, err := mntr.setupVFS2(ctx, l.root.conf, &l.root.procArgs)
+ mns, err := mntr.mountAll(l.root.conf, &l.root.procArgs)
if err != nil {
- t.Fatalf("failed to setupVFS2: %v", err)
+ t.Fatalf("mountAll: %v", err)
}
root := mns.Root()
@@ -702,7 +703,11 @@ func TestRestoreEnvironment(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conf := testConfig()
- mntr := newContainerMounter(tc.spec, tc.ioFDs, nil, &podMountHints{})
+ var ioFDs []*fd.FD
+ for _, ioFD := range tc.ioFDs {
+ ioFDs = append(ioFDs, fd.New(ioFD))
+ }
+ mntr := newContainerMounter(tc.spec, ioFDs, nil, &podMountHints{})
actualRenv, err := mntr.createRestoreEnvironment(conf)
if !tc.errorExpected && err != nil {
t.Fatalf("could not create restore environment for test:%s", tc.name)
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
index 66b6cf19b..e36664938 100644
--- a/runsc/boot/vfs.go
+++ b/runsc/boot/vfs.go
@@ -21,6 +21,7 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/log"
@@ -134,7 +135,7 @@ func registerFilesystems(k *kernel.Kernel) error {
}
func setupContainerVFS2(ctx context.Context, conf *config.Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error {
- mns, err := mntr.setupVFS2(ctx, conf, procArgs)
+ mns, err := mntr.mountAll(conf, procArgs)
if err != nil {
return fmt.Errorf("failed to setupFS: %w", err)
}
@@ -149,7 +150,7 @@ func setupContainerVFS2(ctx context.Context, conf *config.Config, mntr *containe
return nil
}
-func (c *containerMounter) setupVFS2(ctx context.Context, conf *config.Config, procArgs *kernel.CreateProcessArgs) (*vfs.MountNamespace, error) {
+func (c *containerMounter) mountAll(conf *config.Config, procArgs *kernel.CreateProcessArgs) (*vfs.MountNamespace, error) {
log.Infof("Configuring container's file system with VFS2")
// Create context with root credentials to mount the filesystem (the current
@@ -168,34 +169,108 @@ func (c *containerMounter) setupVFS2(ctx context.Context, conf *config.Config, p
}
rootProcArgs.MountNamespaceVFS2 = mns
+ root := mns.Root()
+ defer root.DecRef(rootCtx)
+ if root.Mount().ReadOnly() {
+ // Switch to ReadWrite while we setup submounts.
+ if err := c.k.VFS().SetMountReadOnly(root.Mount(), false); err != nil {
+ return nil, fmt.Errorf(`failed to set mount at "/" readwrite: %w`, err)
+ }
+ // Restore back to ReadOnly at the end.
+ defer func() {
+ if err := c.k.VFS().SetMountReadOnly(root.Mount(), true); err != nil {
+ panic(fmt.Sprintf(`failed to restore mount at "/" back to readonly: %v`, err))
+ }
+ }()
+ }
+
// Mount submounts.
if err := c.mountSubmountsVFS2(rootCtx, conf, mns, rootCreds); err != nil {
return nil, fmt.Errorf("mounting submounts vfs2: %w", err)
}
+
return mns, nil
}
+// createMountNamespaceVFS2 creates the container's root mount and namespace.
func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *config.Config, creds *auth.Credentials) (*vfs.MountNamespace, error) {
fd := c.fds.remove()
- opts := p9MountData(fd, conf.FileAccess, true /* vfs2 */)
+ data := p9MountData(fd, conf.FileAccess, true /* vfs2 */)
if conf.OverlayfsStaleRead {
// We can't check for overlayfs here because sandbox is chroot'ed and gofer
// can only send mount options for specs.Mounts (specs.Root is missing
// Options field). So assume root is always on top of overlayfs.
- opts = append(opts, "overlayfs_stale_read")
+ data = append(data, "overlayfs_stale_read")
}
log.Infof("Mounting root over 9P, ioFD: %d", fd)
- mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", gofer.Name, &vfs.GetFilesystemOptions{
- Data: strings.Join(opts, ","),
- })
+ opts := &vfs.MountOptions{
+ ReadOnly: c.root.Readonly,
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ Data: strings.Join(data, ","),
+ },
+ InternalMount: true,
+ }
+
+ fsName := gofer.Name
+ if conf.Overlay && !c.root.Readonly {
+ log.Infof("Adding overlay on top of root")
+ var err error
+ var cleanup func()
+ opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName)
+ if err != nil {
+ return nil, fmt.Errorf("mounting root with overlay: %w", err)
+ }
+ defer cleanup()
+ fsName = overlay.Name
+ }
+
+ mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", fsName, opts)
if err != nil {
return nil, fmt.Errorf("setting up mount namespace: %w", err)
}
return mns, nil
}
+// configureOverlay mounts the lower layer using "lowerOpts", mounts the upper
+// layer using tmpfs, and return overlay mount options. "cleanup" must be called
+// after the options have been used to mount the overlay, to release refs on
+// lower and upper mounts.
+func (c *containerMounter) configureOverlay(ctx context.Context, creds *auth.Credentials, lowerOpts *vfs.MountOptions, lowerFSName string) (*vfs.MountOptions, func(), error) {
+ // First copy options from lower layer to upper layer and overlay. Clear
+ // filesystem specific options.
+ upperOpts := *lowerOpts
+ upperOpts.GetFilesystemOptions = vfs.GetFilesystemOptions{}
+
+ overlayOpts := *lowerOpts
+ overlayOpts.GetFilesystemOptions = vfs.GetFilesystemOptions{}
+
+ // Next mount upper and lower. Upper is a tmpfs mount to keep all
+ // modifications inside the sandbox.
+ upper, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, tmpfs.Name, &upperOpts)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to create upper layer for overlay, opts: %+v: %v", upperOpts, err)
+ }
+ cu := cleanup.Make(func() { upper.DecRef(ctx) })
+ defer cu.Clean()
+
+ // All writes go to the upper layer, be paranoid and make lower readonly.
+ lowerOpts.ReadOnly = true
+ lower, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, lowerFSName, lowerOpts)
+ if err != nil {
+ return nil, nil, err
+ }
+ cu.Add(func() { lower.DecRef(ctx) })
+
+ // Configure overlay with both layers.
+ overlayOpts.GetFilesystemOptions.InternalData = overlay.FilesystemOptions{
+ UpperRoot: vfs.MakeVirtualDentry(upper, upper.Root()),
+ LowerRoots: []vfs.VirtualDentry{vfs.MakeVirtualDentry(lower, lower.Root())},
+ }
+ return &overlayOpts, cu.Release(), nil
+}
+
func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials) error {
mounts, err := c.prepareMountsVFS2()
if err != nil {
@@ -225,8 +300,9 @@ func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *config.
if mnt != nil && mnt.ReadOnly() {
// Switch to ReadWrite while we setup submounts.
if err := c.k.VFS().SetMountReadOnly(mnt, false); err != nil {
- return fmt.Errorf("failed to set mount at %q readwrite: %v", submount.Destination, err)
+ return fmt.Errorf("failed to set mount at %q readwrite: %w", submount.Destination, err)
}
+ // Restore back to ReadOnly at the end.
defer func() {
if err := c.k.VFS().SetMountReadOnly(mnt, true); err != nil {
panic(fmt.Sprintf("failed to restore mount at %q back to readonly: %v", submount.Destination, err))
@@ -276,14 +352,7 @@ func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) {
}
func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *mountAndFD) (*vfs.Mount, error) {
- root := mns.Root()
- defer root.DecRef(ctx)
- target := &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(submount.Destination),
- }
- fsName, opts, err := c.getMountNameAndOptionsVFS2(conf, submount)
+ fsName, opts, useOverlay, err := c.getMountNameAndOptionsVFS2(conf, submount)
if err != nil {
return nil, fmt.Errorf("mountOptions failed: %w", err)
}
@@ -292,8 +361,27 @@ func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *config.C
return nil, nil
}
- if err := c.k.VFS().MakeSyntheticMountpoint(ctx, submount.Destination, root, creds); err != nil {
- return nil, err
+ if err := c.makeMountPoint(ctx, creds, mns, submount.Destination); err != nil {
+ return nil, fmt.Errorf("creating mount point %q: %w", submount.Destination, err)
+ }
+
+ if useOverlay {
+ log.Infof("Adding overlay on top of mount %q", submount.Destination)
+ var cleanup func()
+ opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName)
+ if err != nil {
+ return nil, fmt.Errorf("mounting volume with overlay at %q: %w", submount.Destination, err)
+ }
+ defer cleanup()
+ fsName = overlay.Name
+ }
+
+ root := mns.Root()
+ defer root.DecRef(ctx)
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(submount.Destination),
}
mnt, err := c.k.VFS().MountAt(ctx, creds, "", target, fsName, opts)
if err != nil {
@@ -305,8 +393,9 @@ func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *config.C
// getMountNameAndOptionsVFS2 retrieves the fsName, opts, and useOverlay values
// used for mounts.
-func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mountAndFD) (string, *vfs.MountOptions, error) {
+func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mountAndFD) (string, *vfs.MountOptions, bool, error) {
fsName := m.Type
+ useOverlay := false
var data []string
// Find filesystem name and FS specific data field.
@@ -321,7 +410,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
var err error
data, err = parseAndFilterOptions(m.Options, tmpfsAllowedData...)
if err != nil {
- return "", nil, err
+ return "", nil, false, err
}
case bind:
@@ -329,13 +418,16 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
if m.fd == 0 {
// Check that an FD was provided to fails fast. Technically FD=0 is valid,
// but unlikely to be correct in this context.
- return "", nil, fmt.Errorf("9P mount requires a connection FD")
+ return "", nil, false, fmt.Errorf("9P mount requires a connection FD")
}
data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */)
+ // If configured, add overlay to all writable mounts.
+ useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
+
default:
log.Warningf("ignoring unknown filesystem type %q", m.Type)
- return "", nil, nil
+ return "", nil, false, nil
}
opts := &vfs.MountOptions{
@@ -360,11 +452,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
}
}
- if conf.Overlay {
- // All writes go to upper, be paranoid and make lower readonly.
- opts.ReadOnly = true
- }
- return fsName, opts, nil
+ return fsName, opts, useOverlay, nil
}
// mountTmpVFS2 mounts an internal tmpfs at '/tmp' if it's safe to do so.
@@ -467,13 +555,25 @@ func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *conf
// Map mount type to filesystem name, and parse out the options that we are
// capable of dealing with.
mntFD := &mountAndFD{Mount: hint.mount}
- fsName, opts, err := c.getMountNameAndOptionsVFS2(conf, mntFD)
+ fsName, opts, useOverlay, err := c.getMountNameAndOptionsVFS2(conf, mntFD)
if err != nil {
return nil, err
}
if len(fsName) == 0 {
return nil, fmt.Errorf("mount type not supported %q", hint.mount.Type)
}
+
+ if useOverlay {
+ log.Infof("Adding overlay on top of shared mount %q", mntFD.Destination)
+ var cleanup func()
+ opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName)
+ if err != nil {
+ return nil, fmt.Errorf("mounting shared volume with overlay at %q: %w", mntFD.Destination, err)
+ }
+ defer cleanup()
+ fsName = overlay.Name
+ }
+
return c.k.VFS().MountDisconnected(ctx, creds, "", fsName, opts)
}
@@ -484,7 +584,9 @@ func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *co
return nil, err
}
- _, opts, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{Mount: mount})
+ // Ignore data and useOverlay because these were already applied to
+ // the master mount.
+ _, opts, _, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{Mount: mount})
if err != nil {
return nil, err
}
@@ -496,18 +598,39 @@ func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *co
root := mns.Root()
defer root.DecRef(ctx)
- if err := c.k.VFS().MakeSyntheticMountpoint(ctx, mount.Destination, root, creds); err != nil {
- return nil, err
- }
-
target := &vfs.PathOperation{
Root: root,
Start: root,
Path: fspath.Parse(mount.Destination),
}
+
+ if err := c.makeMountPoint(ctx, creds, mns, mount.Destination); err != nil {
+ return nil, fmt.Errorf("creating mount point %q: %w", mount.Destination, err)
+ }
+
if err := c.k.VFS().ConnectMountAt(ctx, creds, newMnt, target); err != nil {
return nil, err
}
log.Infof("Mounted %q type shared bind to %q", mount.Destination, source.name)
return newMnt, nil
}
+
+func (c *containerMounter) makeMountPoint(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, dest string) error {
+ root := mns.Root()
+ defer root.DecRef(ctx)
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(dest),
+ }
+ // First check if mount point exists. When overlay is enabled, gofer doesn't
+ // allow changes to the FS, making MakeSytheticMountpoint() ineffective
+ // because MkdirAt fails with EROFS even if file exists.
+ vd, err := c.k.VFS().GetDentryAt(ctx, creds, target, &vfs.GetDentryOptions{})
+ if err == nil {
+ // File exists, we're done.
+ vd.DecRef(ctx)
+ return nil
+ }
+ return c.k.VFS().MakeSyntheticMountpoint(ctx, dest, root, creds)
+}
diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go
index 357f46517..cd419e1aa 100644
--- a/runsc/cmd/boot.go
+++ b/runsc/cmd/boot.go
@@ -168,7 +168,7 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Get the spec from the specFD.
specFile := os.NewFile(uintptr(b.specFD), "spec file")
defer specFile.Close()
- spec, err := specutils.ReadSpecFromFile(b.bundleDir, specFile)
+ spec, err := specutils.ReadSpecFromFile(b.bundleDir, specFile, conf)
if err != nil {
Fatalf("reading spec: %v", err)
}
diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go
index db46d509f..8fe0c427a 100644
--- a/runsc/cmd/checkpoint.go
+++ b/runsc/cmd/checkpoint.go
@@ -118,7 +118,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa
Fatalf("setting bundleDir")
}
- spec, err := specutils.ReadSpec(bundleDir)
+ spec, err := specutils.ReadSpec(bundleDir, conf)
if err != nil {
Fatalf("reading spec: %v", err)
}
diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go
index 4d9085244..e76f7ba1d 100644
--- a/runsc/cmd/create.go
+++ b/runsc/cmd/create.go
@@ -91,7 +91,7 @@ func (c *Create) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}
if bundleDir == "" {
bundleDir = getwdOrDie()
}
- spec, err := specutils.ReadSpec(bundleDir)
+ spec, err := specutils.ReadSpec(bundleDir, conf)
if err != nil {
return Errorf("reading spec: %v", err)
}
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index 600876a27..775ed4b43 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -220,7 +220,7 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi
cmd.Stderr = os.Stderr
// If the console control socket file is provided, then create a new
- // pty master/slave pair and set the TTY on the sandbox process.
+ // pty master/replica pair and set the TTY on the sandbox process.
if ex.consoleSocket != "" {
// Create a new TTY pair and send the master on the provided socket.
tty, err := console.NewWithSocket(ex.consoleSocket)
@@ -229,7 +229,7 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi
}
defer tty.Close()
- // Set stdio to the new TTY slave.
+ // Set stdio to the new TTY replica.
cmd.Stdin = tty
cmd.Stdout = tty
cmd.Stderr = tty
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 7da02c3af..371fcc0ae 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -62,9 +62,8 @@ type Gofer struct {
applyCaps bool
setUpRoot bool
- panicOnWrite bool
- specFD int
- mountsFD int
+ specFD int
+ mountsFD int
}
// Name implements subcommands.Command.
@@ -87,7 +86,6 @@ func (g *Gofer) SetFlags(f *flag.FlagSet) {
f.StringVar(&g.bundleDir, "bundle", "", "path to the root of the bundle directory, defaults to the current directory")
f.Var(&g.ioFDs, "io-fds", "list of FDs to connect 9P servers. They must follow this order: root first, then mounts as defined in the spec")
f.BoolVar(&g.applyCaps, "apply-caps", true, "if true, apply capabilities to restrict what the Gofer process can do")
- f.BoolVar(&g.panicOnWrite, "panic-on-write", false, "if true, panics on attempts to write to RO mounts. RW mounts are unnaffected")
f.BoolVar(&g.setUpRoot, "setup-root", true, "if true, set up an empty root for the process")
f.IntVar(&g.specFD, "spec-fd", -1, "required fd with the container spec")
f.IntVar(&g.mountsFD, "mounts-fd", -1, "mountsFD is the file descriptor to write list of mounts after they have been resolved (direct paths, no symlinks).")
@@ -100,15 +98,15 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
return subcommands.ExitUsageError
}
+ conf := args[0].(*config.Config)
+
specFile := os.NewFile(uintptr(g.specFD), "spec file")
defer specFile.Close()
- spec, err := specutils.ReadSpecFromFile(g.bundleDir, specFile)
+ spec, err := specutils.ReadSpecFromFile(g.bundleDir, specFile, conf)
if err != nil {
Fatalf("reading spec: %v", err)
}
- conf := args[0].(*config.Config)
-
if g.setUpRoot {
if err := setupRootFS(spec, conf); err != nil {
Fatalf("Error setting up root FS: %v", err)
@@ -168,8 +166,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Start with root mount, then add any other additional mount as needed.
ats := make([]p9.Attacher, 0, len(spec.Mounts)+1)
ap, err := fsgofer.NewAttachPoint("/", fsgofer.Config{
- ROMount: spec.Root.Readonly || conf.Overlay,
- PanicOnWrite: g.panicOnWrite,
+ ROMount: spec.Root.Readonly || conf.Overlay,
})
if err != nil {
Fatalf("creating attach point: %v", err)
@@ -181,9 +178,8 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
for _, m := range spec.Mounts {
if specutils.Is9PMount(m) {
cfg := fsgofer.Config{
- ROMount: isReadonlyMount(m.Options) || conf.Overlay,
- PanicOnWrite: g.panicOnWrite,
- HostUDS: conf.FSGoferHostUDS,
+ ROMount: isReadonlyMount(m.Options) || conf.Overlay,
+ HostUDS: conf.FSGoferHostUDS,
}
ap, err := fsgofer.NewAttachPoint(m.Destination, cfg)
if err != nil {
@@ -316,6 +312,7 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error {
if err != nil {
return fmt.Errorf("resolving symlinks to %q: %v", spec.Process.Cwd, err)
}
+ log.Infof("Create working directory %q if needed", spec.Process.Cwd)
if err := os.MkdirAll(dst, 0755); err != nil {
return fmt.Errorf("creating working directory %q: %v", spec.Process.Cwd, err)
}
diff --git a/runsc/cmd/restore.go b/runsc/cmd/restore.go
index b16975804..096ec814c 100644
--- a/runsc/cmd/restore.go
+++ b/runsc/cmd/restore.go
@@ -88,7 +88,7 @@ func (r *Restore) Execute(_ context.Context, f *flag.FlagSet, args ...interface{
if bundleDir == "" {
bundleDir = getwdOrDie()
}
- spec, err := specutils.ReadSpec(bundleDir)
+ spec, err := specutils.ReadSpec(bundleDir, conf)
if err != nil {
return Errorf("reading spec: %v", err)
}
diff --git a/runsc/cmd/run.go b/runsc/cmd/run.go
index 1161de67a..c48cbe4cd 100644
--- a/runsc/cmd/run.go
+++ b/runsc/cmd/run.go
@@ -75,7 +75,7 @@ func (r *Run) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s
if bundleDir == "" {
bundleDir = getwdOrDie()
}
- spec, err := specutils.ReadSpec(bundleDir)
+ spec, err := specutils.ReadSpec(bundleDir, conf)
if err != nil {
return Errorf("reading spec: %v", err)
}
diff --git a/runsc/config/config.go b/runsc/config/config.go
index bca27ebf1..f30f79f68 100644
--- a/runsc/config/config.go
+++ b/runsc/config/config.go
@@ -157,6 +157,12 @@ type Config struct {
// Enables FUSE usage.
FUSE bool `flag:"fuse"`
+ // Allows overriding of flags in OCI annotations.
+ AllowFlagOverride bool `flag:"allow-flag-override"`
+
+ // Enables seccomp inside the sandbox.
+ OCISeccomp bool `flag:"oci-seccomp"`
+
// TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
// tests. It allows runsc to start the sandbox process as the current
// user, and without chrooting the sandbox process. This can be
diff --git a/runsc/config/config_test.go b/runsc/config/config_test.go
index af7867a2a..fb162b7eb 100644
--- a/runsc/config/config_test.go
+++ b/runsc/config/config_test.go
@@ -183,3 +183,90 @@ func TestValidationFail(t *testing.T) {
})
}
}
+
+func TestOverride(t *testing.T) {
+ c, err := NewFromFlags()
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.AllowFlagOverride = true
+
+ t.Run("string", func(t *testing.T) {
+ c.RootDir = "foobar"
+ if err := c.Override("root", "bar"); err != nil {
+ t.Fatalf("Override(root, bar) failed: %v", err)
+ }
+ defer setDefault("root")
+ if c.RootDir != "bar" {
+ t.Errorf("Override(root, bar) didn't work: %+v", c)
+ }
+ })
+
+ t.Run("bool", func(t *testing.T) {
+ c.Debug = true
+ if err := c.Override("debug", "false"); err != nil {
+ t.Fatalf("Override(debug, false) failed: %v", err)
+ }
+ defer setDefault("debug")
+ if c.Debug {
+ t.Errorf("Override(debug, false) didn't work: %+v", c)
+ }
+ })
+
+ t.Run("enum", func(t *testing.T) {
+ c.FileAccess = FileAccessShared
+ if err := c.Override("file-access", "exclusive"); err != nil {
+ t.Fatalf("Override(file-access, exclusive) failed: %v", err)
+ }
+ defer setDefault("file-access")
+ if c.FileAccess != FileAccessExclusive {
+ t.Errorf("Override(file-access, exclusive) didn't work: %+v", c)
+ }
+ })
+}
+
+func TestOverrideDisabled(t *testing.T) {
+ c, err := NewFromFlags()
+ if err != nil {
+ t.Fatal(err)
+ }
+ const errMsg = "flag override disabled"
+ if err := c.Override("root", "path"); err == nil || !strings.Contains(err.Error(), errMsg) {
+ t.Errorf("Override() wrong error: %v", err)
+ }
+}
+
+func TestOverrideError(t *testing.T) {
+ c, err := NewFromFlags()
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.AllowFlagOverride = true
+ for _, tc := range []struct {
+ name string
+ value string
+ error string
+ }{
+ {
+ name: "invalid",
+ value: "valid",
+ error: `flag "invalid" not found`,
+ },
+ {
+ name: "debug",
+ value: "invalid",
+ error: "error setting flag debug",
+ },
+ {
+ name: "file-access",
+ value: "invalid",
+ error: "invalid file access type",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ if err := c.Override(tc.name, tc.value); err == nil || !strings.Contains(err.Error(), tc.error) {
+ t.Errorf("Override(%q, %q) wrong error: %v", tc.name, tc.value, err)
+ }
+ })
+ }
+}
diff --git a/runsc/config/flags.go b/runsc/config/flags.go
index 488a4b9fb..a5f25cfa2 100644
--- a/runsc/config/flags.go
+++ b/runsc/config/flags.go
@@ -48,6 +48,7 @@ func RegisterFlags() {
flag.Bool("log-packets", false, "enable network packet logging.")
flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.")
flag.Bool("alsologtostderr", false, "send log messages to stderr.")
+ flag.Bool("allow-flag-override", false, "allow OCI annotations (dev.gvisor.flag.<name>) to override flags for debugging.")
// Debugging flags: strace related
flag.Bool("strace", false, "enable strace.")
@@ -62,6 +63,7 @@ func RegisterFlags() {
flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.")
flag.Var(leakModePtr(refs.NoLeakChecking), "ref-leak-mode", "sets reference leak check mode: disabled (default), log-names, log-traces.")
flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)")
+ flag.Bool("oci-seccomp", false, "Enables loading OCI seccomp filters inside the sandbox.")
// Flags that control sandbox runtime behavior: FS related.
flag.Var(fileAccessTypePtr(FileAccessExclusive), "file-access", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.")
@@ -149,6 +151,41 @@ func (c *Config) ToFlags() []string {
return rv
}
+// Override writes a new value to a flag.
+func (c *Config) Override(name string, value string) error {
+ if !c.AllowFlagOverride {
+ return fmt.Errorf("flag override disabled, use --allow-flag-override to enable it")
+ }
+
+ obj := reflect.ValueOf(c).Elem()
+ st := obj.Type()
+ for i := 0; i < st.NumField(); i++ {
+ f := st.Field(i)
+ fieldName, ok := f.Tag.Lookup("flag")
+ if !ok || fieldName != name {
+ // Not a flag field, or flag name doesn't match.
+ continue
+ }
+ fl := flag.CommandLine.Lookup(name)
+ if fl == nil {
+ // Flag must exist if there is a field match above.
+ panic(fmt.Sprintf("Flag %q not found", name))
+ }
+
+ // Use flag to convert the string value to the underlying flag type, using
+ // the same rules as the command-line for consistency.
+ if err := fl.Value.Set(value); err != nil {
+ return fmt.Errorf("error setting flag %s=%q: %w", name, value, err)
+ }
+ x := reflect.ValueOf(flag.Get(fl.Value))
+ obj.Field(i).Set(x)
+
+ // Validates the config again to ensure it's left in a consistent state.
+ return c.validate()
+ }
+ return fmt.Errorf("flag %q not found. Cannot set it to %q", name, value)
+}
+
func getVal(field reflect.Value) string {
if str, ok := field.Addr().Interface().(fmt.Stringer); ok {
return str.String()
diff --git a/runsc/console/console.go b/runsc/console/console.go
index 64b23639a..dbb88e117 100644
--- a/runsc/console/console.go
+++ b/runsc/console/console.go
@@ -24,11 +24,11 @@ import (
"golang.org/x/sys/unix"
)
-// NewWithSocket creates pty master/slave pair, sends the master FD over the given
-// socket, and returns the slave.
+// NewWithSocket creates pty master/replica pair, sends the master FD over the given
+// socket, and returns the replica.
func NewWithSocket(socketPath string) (*os.File, error) {
- // Create a new pty master and slave.
- ptyMaster, ptySlave, err := pty.Open()
+ // Create a new pty master and replica.
+ ptyMaster, ptyReplica, err := pty.Open()
if err != nil {
return nil, fmt.Errorf("opening pty: %v", err)
}
@@ -37,18 +37,18 @@ func NewWithSocket(socketPath string) (*os.File, error) {
// Get a connection to the socket path.
conn, err := net.Dial("unix", socketPath)
if err != nil {
- ptySlave.Close()
+ ptyReplica.Close()
return nil, fmt.Errorf("dialing socket %q: %v", socketPath, err)
}
defer conn.Close()
uc, ok := conn.(*net.UnixConn)
if !ok {
- ptySlave.Close()
+ ptyReplica.Close()
return nil, fmt.Errorf("connection is not a UnixConn: %T", conn)
}
socket, err := uc.File()
if err != nil {
- ptySlave.Close()
+ ptyReplica.Close()
return nil, fmt.Errorf("getting file for unix socket %v: %v", uc, err)
}
defer socket.Close()
@@ -56,8 +56,8 @@ func NewWithSocket(socketPath string) (*os.File, error) {
// Send the master FD over the connection.
msg := unix.UnixRights(int(ptyMaster.Fd()))
if err := unix.Sendmsg(int(socket.Fd()), []byte("pty-master"), msg, nil, 0); err != nil {
- ptySlave.Close()
+ ptyReplica.Close()
return nil, fmt.Errorf("sending console over unix socket %q: %v", socketPath, err)
}
- return ptySlave, nil
+ return ptyReplica, nil
}
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index 995d4e267..4228399b8 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -185,14 +185,14 @@ func TestJobControlSignalExec(t *testing.T) {
t.Fatalf("error starting container: %v", err)
}
- // Create a pty master/slave. The slave will be passed to the exec
+ // Create a pty master/replica. The replica will be passed to the exec
// process.
- ptyMaster, ptySlave, err := pty.Open()
+ ptyMaster, ptyReplica, err := pty.Open()
if err != nil {
t.Fatalf("error opening pty: %v", err)
}
defer ptyMaster.Close()
- defer ptySlave.Close()
+ defer ptyReplica.Close()
// Exec bash and attach a terminal. Note that occasionally /bin/sh
// may be a different shell or have a different configuration (such
@@ -203,9 +203,9 @@ func TestJobControlSignalExec(t *testing.T) {
// Don't let bash execute from profile or rc files, otherwise
// our PID counts get messed up.
Argv: []string{"/bin/bash", "--noprofile", "--norc"},
- // Pass the pty slave as FD 0, 1, and 2.
+ // Pass the pty replica as FD 0, 1, and 2.
FilePayload: urpc.FilePayload{
- Files: []*os.File{ptySlave, ptySlave, ptySlave},
+ Files: []*os.File{ptyReplica, ptyReplica, ptyReplica},
},
StdioIsPty: true,
}
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 6e1d6a568..63478ba8c 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -902,9 +902,6 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bu
}
args = append(args, "gofer", "--bundle", bundleDir)
- if conf.Overlay {
- args = append(args, "--panic-on-write=true")
- }
// Open the spec file to donate to the sandbox.
specFile, err := specutils.OpenSpec(bundleDir)
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 6082068c7..548c68087 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -41,6 +41,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/pkg/urpc"
"gvisor.dev/gvisor/runsc/boot/platforms"
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/specutils"
@@ -292,22 +293,20 @@ var (
func configs(t *testing.T, opts ...configOption) map[string]*config.Config {
// Always load the default config.
cs := make(map[string]*config.Config)
+ testutil.TestConfig(t)
for _, o := range opts {
+ c := testutil.TestConfig(t)
switch o {
case overlay:
- c := testutil.TestConfig(t)
c.Overlay = true
cs["overlay"] = c
case ptrace:
- c := testutil.TestConfig(t)
c.Platform = platforms.Ptrace
cs["ptrace"] = c
case kvm:
- c := testutil.TestConfig(t)
c.Platform = platforms.KVM
cs["kvm"] = c
case nonExclusiveFS:
- c := testutil.TestConfig(t)
c.FileAccess = config.FileAccessShared
cs["non-exclusive"] = c
default:
@@ -318,22 +317,12 @@ func configs(t *testing.T, opts ...configOption) map[string]*config.Config {
}
func configsWithVFS2(t *testing.T, opts ...configOption) map[string]*config.Config {
- vfs1 := configs(t, opts...)
-
- var optsVFS2 []configOption
- for _, opt := range opts {
- // TODO(gvisor.dev/issue/1487): Enable overlay tests.
- if opt != overlay {
- optsVFS2 = append(optsVFS2, opt)
- }
- }
-
- for key, value := range configs(t, optsVFS2...) {
+ all := configs(t, opts...)
+ for key, value := range configs(t, opts...) {
value.VFS2 = true
- vfs1[key+"VFS2"] = value
+ all[key+"VFS2"] = value
}
-
- return vfs1
+ return all
}
// TestLifecycle tests the basic Create/Start/Signal/Destroy container lifecycle.
@@ -512,7 +501,7 @@ func TestExePath(t *testing.T) {
t.Fatalf("error making directory: %v", err)
}
- for name, conf := range configsWithVFS2(t, overlay) {
+ for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
for _, test := range []struct {
path string
@@ -837,7 +826,7 @@ func TestExecProcList(t *testing.T) {
// TestKillPid verifies that we can signal individual exec'd processes.
func TestKillPid(t *testing.T) {
- for name, conf := range configsWithVFS2(t, overlay) {
+ for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
app, err := testutil.FindFile("test/cmd/test_app/test_app")
if err != nil {
@@ -1470,7 +1459,7 @@ func TestRunNonRoot(t *testing.T) {
// TestMountNewDir checks that runsc will create destination directory if it
// doesn't exit.
func TestMountNewDir(t *testing.T) {
- for name, conf := range configsWithVFS2(t, overlay) {
+ for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
root, err := ioutil.TempDir(testutil.TmpDir(), "root")
if err != nil {
@@ -1490,6 +1479,8 @@ func TestMountNewDir(t *testing.T) {
Source: srcDir,
Type: "bind",
})
+ // Extra points for creating the mount with a readonly root.
+ spec.Root.Readonly = true
if err := run(spec, conf); err != nil {
t.Fatalf("error running sandbox: %v", err)
@@ -1499,17 +1490,17 @@ func TestMountNewDir(t *testing.T) {
}
func TestReadonlyRoot(t *testing.T) {
- for name, conf := range configsWithVFS2(t, overlay) {
+ for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
- spec := testutil.NewSpecWithArgs("/bin/touch", "/foo")
+ spec := testutil.NewSpecWithArgs("sleep", "100")
spec.Root.Readonly = true
+
_, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
defer cleanup()
- // Create, start and wait for the container.
args := Args{
ID: testutil.RandomContainerID(),
Spec: spec,
@@ -1524,12 +1515,82 @@ func TestReadonlyRoot(t *testing.T) {
t.Fatalf("error starting container: %v", err)
}
- ws, err := c.Wait()
+ // Read mounts to check that root is readonly.
+ out, ws, err := executeCombinedOutput(c, "/bin/sh", "-c", "mount | grep ' / '")
+ if err != nil || ws != 0 {
+ t.Fatalf("exec failed, ws: %v, err: %v", ws, err)
+ }
+ t.Logf("root mount: %q", out)
+ if !strings.Contains(string(out), "(ro)") {
+ t.Errorf("root not mounted readonly: %q", out)
+ }
+
+ // Check that file cannot be created.
+ ws, err = execute(c, "/bin/touch", "/foo")
if err != nil {
- t.Fatalf("error waiting on container: %v", err)
+ t.Fatalf("touch file in ro mount: %v", err)
}
if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM {
- t.Fatalf("container failed, waitStatus: %v", ws)
+ t.Fatalf("wrong waitStatus: %v", ws)
+ }
+ })
+ }
+}
+
+func TestReadonlyMount(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ spec := testutil.NewSpecWithArgs("sleep", "100")
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: dir,
+ Source: dir,
+ Type: "bind",
+ Options: []string{"ro"},
+ })
+ spec.Root.Readonly = false
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Read mounts to check that volume is readonly.
+ cmd := fmt.Sprintf("mount | grep ' %s '", dir)
+ out, ws, err := executeCombinedOutput(c, "/bin/sh", "-c", cmd)
+ if err != nil || ws != 0 {
+ t.Fatalf("exec failed, ws: %v, err: %v", ws, err)
+ }
+ t.Logf("mount: %q", out)
+ if !strings.Contains(string(out), "(ro)") {
+ t.Errorf("volume not mounted readonly: %q", out)
+ }
+
+ // Check that file cannot be created.
+ ws, err = execute(c, "/bin/touch", path.Join(dir, "file"))
+ if err != nil {
+ t.Fatalf("touch file in ro mount: %v", err)
+ }
+ if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM {
+ t.Fatalf("wrong WaitStatus: %v", ws)
}
})
}
@@ -1616,54 +1677,6 @@ func TestUIDMap(t *testing.T) {
}
}
-func TestReadonlyMount(t *testing.T) {
- for name, conf := range configsWithVFS2(t, overlay) {
- t.Run(name, func(t *testing.T) {
- dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount")
- spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file"))
- if err != nil {
- t.Fatalf("ioutil.TempDir() failed: %v", err)
- }
- spec.Mounts = append(spec.Mounts, specs.Mount{
- Destination: dir,
- Source: dir,
- Type: "bind",
- Options: []string{"ro"},
- })
- spec.Root.Readonly = false
-
- _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer cleanup()
-
- // Create, start and wait for the container.
- args := Args{
- ID: testutil.RandomContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- c, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer c.Destroy()
- if err := c.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
-
- ws, err := c.Wait()
- if err != nil {
- t.Fatalf("error waiting on container: %v", err)
- }
- if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM {
- t.Fatalf("container failed, waitStatus: %v", ws)
- }
- })
- }
-}
-
// TestAbbreviatedIDs checks that runsc supports using abbreviated container
// IDs in place of full IDs.
func TestAbbreviatedIDs(t *testing.T) {
@@ -1830,8 +1843,9 @@ func TestUserLog(t *testing.T) {
t.Fatal("error finding test_app:", err)
}
- // sched_rr_get_interval = 148 - not implemented in gvisor.
- spec := testutil.NewSpecWithArgs(app, "syscall", "--syscall=148")
+ // sched_rr_get_interval - not implemented in gvisor.
+ num := strconv.Itoa(syscall.SYS_SCHED_RR_GET_INTERVAL)
+ spec := testutil.NewSpecWithArgs(app, "syscall", "--syscall="+num)
conf := testutil.TestConfig(t)
_, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
@@ -2013,7 +2027,7 @@ func doDestroyStartingTest(t *testing.T, vfs2 bool) {
}
func TestCreateWorkingDir(t *testing.T) {
- for name, conf := range configsWithVFS2(t, overlay) {
+ for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "cwd-create")
if err != nil {
@@ -2116,27 +2130,19 @@ func TestMountPropagation(t *testing.T) {
// Check that mount didn't propagate to private mount.
privFile := filepath.Join(priv, "mnt", "file")
- execArgs := &control.ExecArgs{
- Filename: "/usr/bin/test",
- Argv: []string{"test", "!", "-f", privFile},
- }
- if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 {
+ if ws, err := execute(cont, "/usr/bin/test", "!", "-f", privFile); err != nil || ws != 0 {
t.Fatalf("exec: test ! -f %q, ws: %v, err: %v", privFile, ws, err)
}
// Check that mount propagated to slave mount.
slaveFile := filepath.Join(slave, "mnt", "file")
- execArgs = &control.ExecArgs{
- Filename: "/usr/bin/test",
- Argv: []string{"test", "-f", slaveFile},
- }
- if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 {
+ if ws, err := execute(cont, "/usr/bin/test", "-f", slaveFile); err != nil || ws != 0 {
t.Fatalf("exec: test -f %q, ws: %v, err: %v", privFile, ws, err)
}
}
func TestMountSymlink(t *testing.T) {
- for name, conf := range configsWithVFS2(t, overlay) {
+ for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
dir, err := ioutil.TempDir(testutil.TmpDir(), "mount-symlink")
if err != nil {
@@ -2196,11 +2202,7 @@ func TestMountSymlink(t *testing.T) {
// Check that symlink was resolved and mount was created where the symlink
// is pointing to.
file := path.Join(target, "file")
- execArgs := &control.ExecArgs{
- Filename: "/usr/bin/test",
- Argv: []string{"test", "-f", file},
- }
- if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 {
+ if ws, err := execute(cont, "/usr/bin/test", "-f", file); err != nil || ws != 0 {
t.Fatalf("exec: test -f %q, ws: %v, err: %v", file, ws, err)
}
})
@@ -2326,6 +2328,35 @@ func TestTTYField(t *testing.T) {
}
}
+func execute(cont *Container, name string, arg ...string) (syscall.WaitStatus, error) {
+ args := &control.ExecArgs{
+ Filename: name,
+ Argv: append([]string{name}, arg...),
+ }
+ return cont.executeSync(args)
+}
+
+func executeCombinedOutput(cont *Container, name string, arg ...string) ([]byte, syscall.WaitStatus, error) {
+ r, w, err := os.Pipe()
+ if err != nil {
+ return nil, 0, err
+ }
+ defer r.Close()
+
+ args := &control.ExecArgs{
+ Filename: name,
+ Argv: append([]string{name}, arg...),
+ FilePayload: urpc.FilePayload{Files: []*os.File{os.Stdin, w, w}},
+ }
+ ws, err := cont.executeSync(args)
+ w.Close()
+ if err != nil {
+ return nil, 0, err
+ }
+ out, err := ioutil.ReadAll(r)
+ return out, ws, err
+}
+
// executeSync synchronously executes a new process.
func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) {
pid, err := cont.Execute(args)
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index 1beea123f..952215ec1 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -169,7 +169,7 @@ func TestMultiContainerSanity(t *testing.T) {
// TestMultiPIDNS checks that it is possible to run 2 dead-simple
// containers in the same sandbox with different pidns.
func TestMultiPIDNS(t *testing.T) {
- for name, conf := range configs(t, all...) {
+ for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
@@ -214,7 +214,7 @@ func TestMultiPIDNS(t *testing.T) {
// TestMultiPIDNSPath checks the pidns path.
func TestMultiPIDNSPath(t *testing.T) {
- for name, conf := range configs(t, all...) {
+ for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
@@ -580,7 +580,7 @@ func TestMultiContainerDestroy(t *testing.T) {
t.Fatal("error finding test_app:", err)
}
- for name, conf := range configs(t, all...) {
+ for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
@@ -1252,8 +1252,7 @@ func TestMultiContainerSharedMountReadonly(t *testing.T) {
// Test that shared pod mounts continue to work after container is restarted.
func TestMultiContainerSharedMountRestart(t *testing.T) {
- //TODO(gvisor.dev/issue/1487): This is failing with VFS2.
- for name, conf := range configs(t, all...) {
+ for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
@@ -1360,7 +1359,7 @@ func TestMultiContainerSharedMountRestart(t *testing.T) {
}
// Test that unsupported pod mounts options are ignored when matching master and
-// slave mounts.
+// replica mounts.
func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) {
for name, conf := range configsWithVFS2(t, all...) {
t.Run(name, func(t *testing.T) {
@@ -1518,8 +1517,7 @@ func TestMultiContainerGoferKilled(t *testing.T) {
}
// Check that container isn't running anymore.
- args := &control.ExecArgs{Argv: []string{"/bin/true"}}
- if _, err := c.executeSync(args); err == nil {
+ if _, err := execute(c, "/bin/true"); err == nil {
t.Fatalf("Container %q was not stopped after gofer death", c.ID)
}
@@ -1534,8 +1532,7 @@ func TestMultiContainerGoferKilled(t *testing.T) {
if err := waitForProcessList(c, pl); err != nil {
t.Errorf("Container %q was affected by another container: %v", c.ID, err)
}
- args := &control.ExecArgs{Argv: []string{"/bin/true"}}
- if _, err := c.executeSync(args); err != nil {
+ if _, err := execute(c, "/bin/true"); err != nil {
t.Fatalf("Container %q was affected by another container: %v", c.ID, err)
}
}
@@ -1557,8 +1554,7 @@ func TestMultiContainerGoferKilled(t *testing.T) {
// Check that entire sandbox isn't running anymore.
for _, c := range containers {
- args := &control.ExecArgs{Argv: []string{"/bin/true"}}
- if _, err := c.executeSync(args); err == nil {
+ if _, err := execute(c, "/bin/true"); err == nil {
t.Fatalf("Container %q was not stopped after gofer death", c.ID)
}
}
@@ -1720,12 +1716,11 @@ func TestMultiContainerHomeEnvDir(t *testing.T) {
homeDirs[name] = homeFile
}
- // We will sleep in the root container in order to ensure that
- // the root container doesn't terminate before sub containers can be
- // created.
+ // We will sleep in the root container in order to ensure that the root
+ //container doesn't terminate before sub containers can be created.
rootCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s; sleep 1000", homeDirs["root"].Name())}
subCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["sub"].Name())}
- execCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["exec"].Name())}
+ execCmd := fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["exec"].Name())
// Setup the containers, a root container and sub container.
specConfig, ids := createSpecs(rootCmd, subCmd)
@@ -1736,9 +1731,8 @@ func TestMultiContainerHomeEnvDir(t *testing.T) {
defer cleanup()
// Exec into the root container synchronously.
- args := &control.ExecArgs{Argv: execCmd}
- if _, err := containers[0].executeSync(args); err != nil {
- t.Errorf("error executing %+v: %v", args, err)
+ if _, err := execute(containers[0], "/bin/sh", "-c", execCmd); err != nil {
+ t.Errorf("error executing %+v: %v", execCmd, err)
}
// Wait for the subcontainer to finish.
diff --git a/runsc/container/shared_volume_test.go b/runsc/container/shared_volume_test.go
index 4ea8fefee..cb5bffb89 100644
--- a/runsc/container/shared_volume_test.go
+++ b/runsc/container/shared_volume_test.go
@@ -168,11 +168,7 @@ func TestSharedVolume(t *testing.T) {
func checkFile(c *Container, filename string, want []byte) error {
cpy := filename + ".copy"
- argsCp := &control.ExecArgs{
- Filename: "/bin/cp",
- Argv: []string{"cp", "-f", filename, cpy},
- }
- if _, err := c.executeSync(argsCp); err != nil {
+ if _, err := execute(c, "/bin/cp", "-f", filename, cpy); err != nil {
return fmt.Errorf("unexpected error copying file %q to %q: %v", filename, cpy, err)
}
got, err := ioutil.ReadFile(cpy)
@@ -235,11 +231,7 @@ func TestSharedVolumeFile(t *testing.T) {
}
// Append to file inside the container and check that content is not lost.
- argsAppend := &control.ExecArgs{
- Filename: "/bin/bash",
- Argv: []string{"bash", "-c", "echo -n sandbox- >> " + filename},
- }
- if _, err := c.executeSync(argsAppend); err != nil {
+ if _, err := execute(c, "/bin/bash", "-c", "echo -n sandbox- >> "+filename); err != nil {
t.Fatalf("unexpected error appending file %q: %v", filename, err)
}
want = []byte("host-sandbox-")
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
index 88814b83c..39b8a0b1e 100644
--- a/runsc/fsgofer/filter/config.go
+++ b/runsc/fsgofer/filter/config.go
@@ -27,62 +27,51 @@ import (
var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_ACCEPT: {},
syscall.SYS_CLOCK_GETTIME: {},
- syscall.SYS_CLONE: []seccomp.Rule{
- {
- seccomp.AllowValue(
- syscall.CLONE_VM |
- syscall.CLONE_FS |
- syscall.CLONE_FILES |
- syscall.CLONE_SIGHAND |
- syscall.CLONE_SYSVSEM |
- syscall.CLONE_THREAD),
- },
- },
- syscall.SYS_CLOSE: {},
- syscall.SYS_DUP: {},
- syscall.SYS_EPOLL_CTL: {},
+ syscall.SYS_CLOSE: {},
+ syscall.SYS_DUP: {},
+ syscall.SYS_EPOLL_CTL: {},
syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(0),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(0),
},
},
syscall.SYS_EVENTFD2: []seccomp.Rule{
{
- seccomp.AllowValue(0),
- seccomp.AllowValue(0),
+ seccomp.EqualTo(0),
+ seccomp.EqualTo(0),
},
},
syscall.SYS_EXIT: {},
syscall.SYS_EXIT_GROUP: {},
syscall.SYS_FALLOCATE: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(0),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(0),
},
},
syscall.SYS_FCHMOD: {},
syscall.SYS_FCHOWNAT: {},
syscall.SYS_FCNTL: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.F_GETFL),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.F_GETFL),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.F_SETFL),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.F_SETFL),
},
{
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.F_GETFD),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.F_GETFD),
},
// Used by flipcall.PacketWindowAllocator.Init().
{
- seccomp.AllowAny{},
- seccomp.AllowValue(unix.F_ADD_SEALS),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(unix.F_ADD_SEALS),
},
},
syscall.SYS_FSTAT: {},
@@ -91,31 +80,31 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_FTRUNCATE: {},
syscall.SYS_FUTEX: {
seccomp.Rule{
- seccomp.AllowAny{},
- seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(0),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(0),
},
seccomp.Rule{
- seccomp.AllowAny{},
- seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(0),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(0),
},
// Non-private futex used for flipcall.
seccomp.Rule{
- seccomp.AllowAny{},
- seccomp.AllowValue(linux.FUTEX_WAIT),
- seccomp.AllowAny{},
- seccomp.AllowAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(linux.FUTEX_WAIT),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
},
seccomp.Rule{
- seccomp.AllowAny{},
- seccomp.AllowValue(linux.FUTEX_WAKE),
- seccomp.AllowAny{},
- seccomp.AllowAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(linux.FUTEX_WAKE),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
},
},
syscall.SYS_GETDENTS64: {},
@@ -137,28 +126,28 @@ var allowedSyscalls = seccomp.SyscallRules{
// TODO(b/148688965): Remove once this is gone from Go.
syscall.SYS_MLOCK: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowValue(4096),
+ seccomp.MatchAny{},
+ seccomp.EqualTo(4096),
},
},
syscall.SYS_MMAP: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MAP_SHARED),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MAP_SHARED),
},
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS),
},
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED),
},
},
syscall.SYS_MPROTECT: {},
@@ -172,14 +161,14 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_READLINKAT: {},
syscall.SYS_RECVMSG: []seccomp.Rule{
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC),
},
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK),
},
},
syscall.SYS_RENAMEAT: {},
@@ -190,33 +179,33 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_SENDMSG: []seccomp.Rule{
// Used by fdchannel.Endpoint.SendFD().
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(0),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(0),
},
// Used by unet.SocketWriter.WriteVec().
{
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL),
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL),
},
},
syscall.SYS_SHUTDOWN: []seccomp.Rule{
- {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)},
+ {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RDWR)},
},
syscall.SYS_SIGALTSTACK: {},
// Used by fdchannel.NewConnectedSockets().
syscall.SYS_SOCKETPAIR: {
{
- seccomp.AllowValue(syscall.AF_UNIX),
- seccomp.AllowValue(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC),
- seccomp.AllowValue(0),
+ seccomp.EqualTo(syscall.AF_UNIX),
+ seccomp.EqualTo(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC),
+ seccomp.EqualTo(0),
},
},
syscall.SYS_SYMLINKAT: {},
syscall.SYS_TGKILL: []seccomp.Rule{
{
- seccomp.AllowValue(uint64(os.Getpid())),
+ seccomp.EqualTo(uint64(os.Getpid())),
},
},
syscall.SYS_UNLINKAT: {},
@@ -227,24 +216,24 @@ var allowedSyscalls = seccomp.SyscallRules{
var udsSyscalls = seccomp.SyscallRules{
syscall.SYS_SOCKET: []seccomp.Rule{
{
- seccomp.AllowValue(syscall.AF_UNIX),
- seccomp.AllowValue(syscall.SOCK_STREAM),
- seccomp.AllowValue(0),
+ seccomp.EqualTo(syscall.AF_UNIX),
+ seccomp.EqualTo(syscall.SOCK_STREAM),
+ seccomp.EqualTo(0),
},
{
- seccomp.AllowValue(syscall.AF_UNIX),
- seccomp.AllowValue(syscall.SOCK_DGRAM),
- seccomp.AllowValue(0),
+ seccomp.EqualTo(syscall.AF_UNIX),
+ seccomp.EqualTo(syscall.SOCK_DGRAM),
+ seccomp.EqualTo(0),
},
{
- seccomp.AllowValue(syscall.AF_UNIX),
- seccomp.AllowValue(syscall.SOCK_SEQPACKET),
- seccomp.AllowValue(0),
+ seccomp.EqualTo(syscall.AF_UNIX),
+ seccomp.EqualTo(syscall.SOCK_SEQPACKET),
+ seccomp.EqualTo(0),
},
},
syscall.SYS_CONNECT: []seccomp.Rule{
{
- seccomp.AllowAny{},
+ seccomp.MatchAny{},
},
},
}
diff --git a/runsc/fsgofer/filter/config_amd64.go b/runsc/fsgofer/filter/config_amd64.go
index a4b28cb8b..686753d96 100644
--- a/runsc/fsgofer/filter/config_amd64.go
+++ b/runsc/fsgofer/filter/config_amd64.go
@@ -25,8 +25,41 @@ import (
func init() {
allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{
- {seccomp.AllowValue(linux.ARCH_GET_FS)},
- {seccomp.AllowValue(linux.ARCH_SET_FS)},
+ // TODO(b/168828518): No longer used in Go 1.16+.
+ {seccomp.EqualTo(linux.ARCH_SET_FS)},
+ }
+
+ allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{
+ // parent_tidptr and child_tidptr are always 0 because neither
+ // CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used.
+ {
+ seccomp.EqualTo(
+ syscall.CLONE_VM |
+ syscall.CLONE_FS |
+ syscall.CLONE_FILES |
+ syscall.CLONE_SETTLS |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_SYSVSEM |
+ syscall.CLONE_THREAD),
+ seccomp.MatchAny{}, // newsp
+ seccomp.EqualTo(0), // parent_tidptr
+ seccomp.EqualTo(0), // child_tidptr
+ seccomp.MatchAny{}, // tls
+ },
+ {
+ // TODO(b/168828518): No longer used in Go 1.16+ (on amd64).
+ seccomp.EqualTo(
+ syscall.CLONE_VM |
+ syscall.CLONE_FS |
+ syscall.CLONE_FILES |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_SYSVSEM |
+ syscall.CLONE_THREAD),
+ seccomp.MatchAny{}, // newsp
+ seccomp.EqualTo(0), // parent_tidptr
+ seccomp.EqualTo(0), // child_tidptr
+ seccomp.MatchAny{}, // tls
+ },
}
allowedSyscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{}
diff --git a/runsc/fsgofer/filter/config_arm64.go b/runsc/fsgofer/filter/config_arm64.go
index d2697deb7..ff0cf77a0 100644
--- a/runsc/fsgofer/filter/config_arm64.go
+++ b/runsc/fsgofer/filter/config_arm64.go
@@ -23,5 +23,26 @@ import (
)
func init() {
+ allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{
+ // parent_tidptr and child_tidptr are always 0 because neither
+ // CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used.
+ {
+ seccomp.EqualTo(
+ syscall.CLONE_VM |
+ syscall.CLONE_FS |
+ syscall.CLONE_FILES |
+ syscall.CLONE_SIGHAND |
+ syscall.CLONE_SYSVSEM |
+ syscall.CLONE_THREAD),
+ seccomp.MatchAny{}, // newsp
+ // These arguments are left uninitialized by the Go
+ // runtime, so they may be anything (and are unused by
+ // the host).
+ seccomp.MatchAny{}, // parent_tidptr
+ seccomp.MatchAny{}, // tls
+ seccomp.MatchAny{}, // child_tidptr
+ },
+ }
+
allowedSyscalls[syscall.SYS_FSTATAT] = []seccomp.Rule{}
}
diff --git a/runsc/fsgofer/filter/extra_filters_race.go b/runsc/fsgofer/filter/extra_filters_race.go
index 885c92f7a..20a0732be 100644
--- a/runsc/fsgofer/filter/extra_filters_race.go
+++ b/runsc/fsgofer/filter/extra_filters_race.go
@@ -35,6 +35,7 @@ func instrumentationFilters() seccomp.SyscallRules {
syscall.SYS_MUNLOCK: {},
syscall.SYS_NANOSLEEP: {},
syscall.SYS_OPEN: {},
+ syscall.SYS_OPENAT: {},
syscall.SYS_SET_ROBUST_LIST: {},
// Used within glibc's malloc.
syscall.SYS_TIME: {},
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index b0788bd23..0b628c8ce 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -181,6 +181,8 @@ func (a *attachPoint) makeQID(stat unix.Stat_t) p9.QID {
// The few exceptions where it cannot be done are: utimensat on symlinks, and
// Connect() for the socket address.
type localFile struct {
+ p9.DisallowClientCalls
+
// attachPoint is the attachPoint that serves this localFile.
attachPoint *attachPoint
@@ -1179,9 +1181,6 @@ func extractErrno(err error) unix.Errno {
func (l *localFile) checkROMount() error {
if conf := l.attachPoint.conf; conf.ROMount {
- if conf.PanicOnWrite {
- panic("attempt to write to RO mount")
- }
return unix.EROFS
}
return nil
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index 0e4945b3d..a84206686 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -553,29 +553,6 @@ func TestROMountChecks(t *testing.T) {
})
}
-func TestROMountPanics(t *testing.T) {
- conf := Config{ROMount: true, PanicOnWrite: true}
- uid := p9.UID(os.Getuid())
- gid := p9.GID(os.Getgid())
-
- runCustom(t, allTypes, []Config{conf}, func(t *testing.T, s state) {
- if s.fileType != unix.S_IFLNK {
- assertPanic(t, func() { s.file.Open(p9.WriteOnly) })
- }
- assertPanic(t, func() { s.file.Create("some_file", p9.ReadWrite, 0777, uid, gid) })
- assertPanic(t, func() { s.file.Mkdir("some_dir", 0777, uid, gid) })
- assertPanic(t, func() { s.file.RenameAt("some_file", s.file, "other_file") })
- assertPanic(t, func() { s.file.Symlink("some_place", "some_symlink", uid, gid) })
- assertPanic(t, func() { s.file.UnlinkAt("some_file", 0) })
- assertPanic(t, func() { s.file.Link(s.file, "some_link") })
- assertPanic(t, func() { s.file.Mknod("some-nod", 0777, 1, 2, uid, gid) })
-
- valid := p9.SetAttrMask{Size: true}
- attr := p9.SetAttr{Size: 0}
- assertPanic(t, func() { s.file.SetAttr(valid, attr) })
- })
-}
-
func TestWalkNotFound(t *testing.T) {
runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) {
if _, _, err := s.file.Walk([]string{"nobody-here"}); err != unix.ENOENT {
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index a339937fb..a8f4f64a5 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -478,10 +478,10 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
cmd.Stderr = nil
// If the console control socket file is provided, then create a new
- // pty master/slave pair and set the TTY on the sandbox process.
+ // pty master/replica pair and set the TTY on the sandbox process.
if args.Spec.Process.Terminal && args.ConsoleSocket != "" {
// console.NewWithSocket will send the master on the given
- // socket, and return the slave.
+ // socket, and return the replica.
tty, err := console.NewWithSocket(args.ConsoleSocket)
if err != nil {
return fmt.Errorf("setting up console with socket %q: %v", args.ConsoleSocket, err)
diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD
index 43851a22f..679d8bc8e 100644
--- a/runsc/specutils/BUILD
+++ b/runsc/specutils/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/bits",
"//pkg/log",
"//pkg/sentry/kernel/auth",
+ "//runsc/config",
"@com_github_cenkalti_backoff//:go_default_library",
"@com_github_mohae_deepcopy//:go_default_library",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
diff --git a/runsc/specutils/seccomp/BUILD b/runsc/specutils/seccomp/BUILD
new file mode 100644
index 000000000..3520f2d6d
--- /dev/null
+++ b/runsc/specutils/seccomp/BUILD
@@ -0,0 +1,34 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "seccomp",
+ srcs = [
+ "audit_amd64.go",
+ "audit_arm64.go",
+ "seccomp.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/bpf",
+ "//pkg/log",
+ "//pkg/seccomp",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/syscalls/linux",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ ],
+)
+
+go_test(
+ name = "seccomp_test",
+ size = "small",
+ srcs = ["seccomp_test.go"],
+ library = ":seccomp",
+ deps = [
+ "//pkg/binary",
+ "//pkg/bpf",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ ],
+)
diff --git a/runsc/specutils/seccomp/audit_amd64.go b/runsc/specutils/seccomp/audit_amd64.go
new file mode 100644
index 000000000..417cf4a7a
--- /dev/null
+++ b/runsc/specutils/seccomp/audit_amd64.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package seccomp
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+const (
+ nativeArchAuditNo = linux.AUDIT_ARCH_X86_64
+)
diff --git a/runsc/specutils/seccomp/audit_arm64.go b/runsc/specutils/seccomp/audit_arm64.go
new file mode 100644
index 000000000..b727ceff2
--- /dev/null
+++ b/runsc/specutils/seccomp/audit_arm64.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package seccomp
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+const (
+ nativeArchAuditNo = linux.AUDIT_ARCH_AARCH64
+)
diff --git a/runsc/specutils/seccomp/seccomp.go b/runsc/specutils/seccomp/seccomp.go
new file mode 100644
index 000000000..5932f7a41
--- /dev/null
+++ b/runsc/specutils/seccomp/seccomp.go
@@ -0,0 +1,229 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package seccomp implements some features of libseccomp in order to support
+// OCI.
+package seccomp
+
+import (
+ "fmt"
+ "syscall"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/seccomp"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+)
+
+var (
+ killThreadAction = linux.SECCOMP_RET_KILL_THREAD
+ trapAction = linux.SECCOMP_RET_TRAP
+ // runc always returns EPERM as the errorcode for SECCOMP_RET_ERRNO
+ errnoAction = linux.SECCOMP_RET_ERRNO.WithReturnCode(uint16(syscall.EPERM))
+ // runc always returns EPERM as the errorcode for SECCOMP_RET_TRACE
+ traceAction = linux.SECCOMP_RET_TRACE.WithReturnCode(uint16(syscall.EPERM))
+ allowAction = linux.SECCOMP_RET_ALLOW
+)
+
+// BuildProgram generates a bpf program based on the given OCI seccomp
+// config.
+func BuildProgram(s *specs.LinuxSeccomp) (bpf.Program, error) {
+ defaultAction, err := convertAction(s.DefaultAction)
+ if err != nil {
+ return bpf.Program{}, fmt.Errorf("secomp default action: %w", err)
+ }
+ ruleset, err := convertRules(s)
+ if err != nil {
+ return bpf.Program{}, fmt.Errorf("invalid seccomp rules: %w", err)
+ }
+
+ instrs, err := seccomp.BuildProgram(ruleset, defaultAction, killThreadAction)
+ if err != nil {
+ return bpf.Program{}, fmt.Errorf("building seccomp program: %w", err)
+ }
+
+ program, err := bpf.Compile(instrs)
+ if err != nil {
+ return bpf.Program{}, fmt.Errorf("compiling seccomp program: %w", err)
+ }
+
+ return program, nil
+}
+
+// lookupSyscallNo gets the syscall number for the syscall with the given name
+// for the given architecture.
+func lookupSyscallNo(arch uint32, name string) (uint32, error) {
+ var table *kernel.SyscallTable
+ switch arch {
+ case linux.AUDIT_ARCH_X86_64:
+ table = slinux.AMD64
+ case linux.AUDIT_ARCH_AARCH64:
+ table = slinux.ARM64
+ }
+ if table == nil {
+ return 0, fmt.Errorf("unsupported architecture: %d", arch)
+ }
+ n, err := table.LookupNo(name)
+ if err != nil {
+ return 0, err
+ }
+ return uint32(n), nil
+}
+
+// convertAction converts a LinuxSeccompAction to BPFAction
+func convertAction(act specs.LinuxSeccompAction) (linux.BPFAction, error) {
+ // TODO(gvisor.dev/issue/3124): Update specs package to include ActLog and ActKillProcess.
+ switch act {
+ case specs.ActKill:
+ return killThreadAction, nil
+ case specs.ActTrap:
+ return trapAction, nil
+ case specs.ActErrno:
+ return errnoAction, nil
+ case specs.ActTrace:
+ return traceAction, nil
+ case specs.ActAllow:
+ return allowAction, nil
+ default:
+ return 0, fmt.Errorf("invalid action: %v", act)
+ }
+}
+
+// convertRules converts OCI linux seccomp rules into RuleSets that can be used by
+// the seccomp package to build a seccomp program.
+func convertRules(s *specs.LinuxSeccomp) ([]seccomp.RuleSet, error) {
+ // NOTE: Architectures are only really relevant when calling 32bit syscalls
+ // on a 64bit system. Since we don't support that in gVisor anyway, we
+ // ignore Architectures and only test against the native architecture.
+
+ ruleset := []seccomp.RuleSet{}
+
+ for _, syscall := range s.Syscalls {
+ sysRules := seccomp.NewSyscallRules()
+
+ action, err := convertAction(syscall.Action)
+ if err != nil {
+ return nil, err
+ }
+
+ // Args
+ rules, err := convertArgs(syscall.Args)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, name := range syscall.Names {
+ syscallNo, err := lookupSyscallNo(nativeArchAuditNo, name)
+ if err != nil {
+ // If there is an error looking up the syscall number, assume it is
+ // not supported on this architecture and ignore it. This is, for
+ // better or worse, what runc does.
+ log.Warningf("OCI seccomp: ignoring syscall %q", name)
+ continue
+ }
+
+ for _, rule := range rules {
+ sysRules.AddRule(uintptr(syscallNo), rule)
+ }
+ }
+
+ ruleset = append(ruleset, seccomp.RuleSet{
+ Rules: sysRules,
+ Action: action,
+ })
+ }
+
+ return ruleset, nil
+}
+
+// convertArgs converts an OCI seccomp argument rule to a list of seccomp.Rule.
+func convertArgs(args []specs.LinuxSeccompArg) ([]seccomp.Rule, error) {
+ argCounts := make([]uint, 6)
+
+ for _, arg := range args {
+ if arg.Index > 6 {
+ return nil, fmt.Errorf("invalid index: %d", arg.Index)
+ }
+
+ argCounts[arg.Index]++
+ }
+
+ // NOTE: If multiple rules apply to the same argument (same index) the
+ // action is triggered if any one of the rules matches (OR). If not, then
+ // all rules much match in order to trigger the action (AND). This appears to
+ // be some kind of legacy behavior of runc that nevertheless needs to be
+ // supported to maintain compatibility.
+
+ hasMultipleArgs := false
+ for _, count := range argCounts {
+ if count > 1 {
+ hasMultipleArgs = true
+ break
+ }
+ }
+
+ if hasMultipleArgs {
+ rules := []seccomp.Rule{}
+
+ // Old runc behavior - do this for compatibility.
+ // Add rules as ORs by adding separate Rules.
+ for _, arg := range args {
+ rule := seccomp.Rule{nil, nil, nil, nil, nil, nil}
+
+ if err := convertRule(arg, &rule); err != nil {
+ return nil, err
+ }
+
+ rules = append(rules, rule)
+ }
+
+ return rules, nil
+ }
+
+ // Add rules as ANDs by adding to the same Rule.
+ rule := seccomp.Rule{nil, nil, nil, nil, nil, nil}
+ for _, arg := range args {
+ if err := convertRule(arg, &rule); err != nil {
+ return nil, err
+ }
+ }
+
+ return []seccomp.Rule{rule}, nil
+}
+
+// convertRule converts and adds the arg to a rule.
+func convertRule(arg specs.LinuxSeccompArg, rule *seccomp.Rule) error {
+ switch arg.Op {
+ case specs.OpEqualTo:
+ rule[arg.Index] = seccomp.EqualTo(arg.Value)
+ case specs.OpNotEqual:
+ rule[arg.Index] = seccomp.NotEqual(arg.Value)
+ case specs.OpGreaterThan:
+ rule[arg.Index] = seccomp.GreaterThan(arg.Value)
+ case specs.OpGreaterEqual:
+ rule[arg.Index] = seccomp.GreaterThanOrEqual(arg.Value)
+ case specs.OpLessThan:
+ rule[arg.Index] = seccomp.LessThan(arg.Value)
+ case specs.OpLessEqual:
+ rule[arg.Index] = seccomp.LessThanOrEqual(arg.Value)
+ case specs.OpMaskedEqual:
+ rule[arg.Index] = seccomp.MaskedEqual(uintptr(arg.Value), uintptr(arg.ValueTwo))
+ default:
+ return fmt.Errorf("unsupported operand: %q", arg.Op)
+ }
+ return nil
+}
diff --git a/runsc/specutils/seccomp/seccomp_test.go b/runsc/specutils/seccomp/seccomp_test.go
new file mode 100644
index 000000000..850c237ba
--- /dev/null
+++ b/runsc/specutils/seccomp/seccomp_test.go
@@ -0,0 +1,414 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccomp
+
+import (
+ "fmt"
+ "syscall"
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/bpf"
+)
+
+type seccompData struct {
+ nr uint32
+ arch uint32
+ instructionPointer uint64
+ args [6]uint64
+}
+
+// asInput converts a seccompData to a bpf.Input.
+func asInput(d seccompData) bpf.Input {
+ return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian}
+}
+
+// testInput creates an Input struct with given seccomp input values.
+func testInput(arch uint32, syscallName string, args *[6]uint64) bpf.Input {
+ syscallNo, err := lookupSyscallNo(arch, syscallName)
+ if err != nil {
+ // Assume tests set valid syscall names.
+ panic(err)
+ }
+
+ if args == nil {
+ argArray := [6]uint64{0, 0, 0, 0, 0, 0}
+ args = &argArray
+ }
+
+ data := seccompData{
+ nr: syscallNo,
+ arch: arch,
+ args: *args,
+ }
+
+ return asInput(data)
+}
+
+// testCase holds a seccomp test case.
+type testCase struct {
+ name string
+ config specs.LinuxSeccomp
+ input bpf.Input
+ expected uint32
+}
+
+var (
+ // seccompTests is a list of speccomp test cases.
+ seccompTests = []testCase{
+ {
+ name: "default_allow",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ },
+ input: testInput(nativeArchAuditNo, "read", nil),
+ expected: uint32(allowAction),
+ },
+ {
+ name: "default_deny",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActErrno,
+ },
+ input: testInput(nativeArchAuditNo, "read", nil),
+ expected: uint32(errnoAction),
+ },
+ {
+ name: "deny_arch",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "getcwd",
+ },
+ Action: specs.ActErrno,
+ },
+ },
+ },
+ // Syscall matches but the arch is AUDIT_ARCH_X86 so the return
+ // value is the bad arch action.
+ input: asInput(seccompData{nr: 183, arch: 0x40000003}), //
+ expected: uint32(killThreadAction),
+ },
+ {
+ name: "match_name_errno",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "getcwd",
+ "chmod",
+ },
+ Action: specs.ActErrno,
+ },
+ {
+ Names: []string{
+ "write",
+ },
+ Action: specs.ActTrace,
+ },
+ },
+ },
+ input: testInput(nativeArchAuditNo, "getcwd", nil),
+ expected: uint32(errnoAction),
+ },
+ {
+ name: "match_name_trace",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "getcwd",
+ "chmod",
+ },
+ Action: specs.ActErrno,
+ },
+ {
+ Names: []string{
+ "write",
+ },
+ Action: specs.ActTrace,
+ },
+ },
+ },
+ input: testInput(nativeArchAuditNo, "write", nil),
+ expected: uint32(traceAction),
+ },
+ {
+ name: "no_match_name_allow",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "getcwd",
+ "chmod",
+ },
+ Action: specs.ActErrno,
+ },
+ {
+ Names: []string{
+ "write",
+ },
+ Action: specs.ActTrace,
+ },
+ },
+ },
+ input: testInput(nativeArchAuditNo, "openat", nil),
+ expected: uint32(allowAction),
+ },
+ {
+ name: "simple_match_args",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "clone",
+ },
+ Args: []specs.LinuxSeccompArg{
+ {
+ Index: 0,
+ Value: syscall.CLONE_FS,
+ Op: specs.OpEqualTo,
+ },
+ },
+ Action: specs.ActErrno,
+ },
+ },
+ },
+ input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}),
+ expected: uint32(errnoAction),
+ },
+ {
+ name: "match_args_or",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "clone",
+ },
+ Args: []specs.LinuxSeccompArg{
+ {
+ Index: 0,
+ Value: syscall.CLONE_FS,
+ Op: specs.OpEqualTo,
+ },
+ {
+ Index: 0,
+ Value: syscall.CLONE_VM,
+ Op: specs.OpEqualTo,
+ },
+ },
+ Action: specs.ActErrno,
+ },
+ },
+ },
+ input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}),
+ expected: uint32(errnoAction),
+ },
+ {
+ name: "match_args_and",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "getsockopt",
+ },
+ Args: []specs.LinuxSeccompArg{
+ {
+ Index: 1,
+ Value: syscall.SOL_SOCKET,
+ Op: specs.OpEqualTo,
+ },
+ {
+ Index: 2,
+ Value: syscall.SO_PEERCRED,
+ Op: specs.OpEqualTo,
+ },
+ },
+ Action: specs.ActErrno,
+ },
+ },
+ },
+ input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, syscall.SOL_SOCKET, syscall.SO_PEERCRED}),
+ expected: uint32(errnoAction),
+ },
+ {
+ name: "no_match_args_and",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "getsockopt",
+ },
+ Args: []specs.LinuxSeccompArg{
+ {
+ Index: 1,
+ Value: syscall.SOL_SOCKET,
+ Op: specs.OpEqualTo,
+ },
+ {
+ Index: 2,
+ Value: syscall.SO_PEERCRED,
+ Op: specs.OpEqualTo,
+ },
+ },
+ Action: specs.ActErrno,
+ },
+ },
+ },
+ input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, syscall.SOL_SOCKET}),
+ expected: uint32(allowAction),
+ },
+ {
+ name: "Simple args (no match)",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "clone",
+ },
+ Args: []specs.LinuxSeccompArg{
+ {
+ Index: 0,
+ Value: syscall.CLONE_FS,
+ Op: specs.OpEqualTo,
+ },
+ },
+ Action: specs.ActErrno,
+ },
+ },
+ },
+ input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_VM}),
+ expected: uint32(allowAction),
+ },
+ {
+ name: "OpMaskedEqual (match)",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "clone",
+ },
+ Args: []specs.LinuxSeccompArg{
+ {
+ Index: 0,
+ Value: syscall.CLONE_FS,
+ ValueTwo: syscall.CLONE_FS,
+ Op: specs.OpMaskedEqual,
+ },
+ },
+ Action: specs.ActErrno,
+ },
+ },
+ },
+ input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS | syscall.CLONE_VM}),
+ expected: uint32(errnoAction),
+ },
+ {
+ name: "OpMaskedEqual (no match)",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActAllow,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "clone",
+ },
+ Args: []specs.LinuxSeccompArg{
+ {
+ Index: 0,
+ Value: syscall.CLONE_FS | syscall.CLONE_VM,
+ ValueTwo: syscall.CLONE_FS | syscall.CLONE_VM,
+ Op: specs.OpMaskedEqual,
+ },
+ },
+ Action: specs.ActErrno,
+ },
+ },
+ },
+ input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}),
+ expected: uint32(allowAction),
+ },
+ {
+ name: "OpMaskedEqual (clone)",
+ config: specs.LinuxSeccomp{
+ DefaultAction: specs.ActErrno,
+ Syscalls: []specs.LinuxSyscall{
+ {
+ Names: []string{
+ "clone",
+ },
+ // This comes from the Docker default seccomp
+ // profile for clone.
+ Args: []specs.LinuxSeccompArg{
+ {
+ Index: 0,
+ Value: 0x7e020000,
+ ValueTwo: 0x0,
+ Op: specs.OpMaskedEqual,
+ },
+ },
+ Action: specs.ActAllow,
+ },
+ },
+ },
+ input: testInput(nativeArchAuditNo, "clone", &[6]uint64{0x50f00}),
+ expected: uint32(allowAction),
+ },
+ }
+)
+
+// TestRunscSeccomp generates seccomp programs from OCI config and executes
+// them using runsc's library, comparing against expected results.
+func TestRunscSeccomp(t *testing.T) {
+ for _, tc := range seccompTests {
+ t.Run(tc.name, func(t *testing.T) {
+ runscProgram, err := BuildProgram(&tc.config)
+ if err != nil {
+ t.Fatalf("generating runsc BPF: %v", err)
+ }
+
+ if err := checkProgram(runscProgram, tc.input, tc.expected); err != nil {
+ t.Fatalf("running runsc BPF: %v", err)
+ }
+ })
+ }
+}
+
+// checkProgram runs the given program over the given input and checks the
+// result against the expected output.
+func checkProgram(p bpf.Program, in bpf.Input, expected uint32) error {
+ result, err := bpf.Exec(p, in)
+ if err != nil {
+ return err
+ }
+
+ if result != expected {
+ // Include a decoded version of the program in output for debugging purposes.
+ decoded, _ := bpf.DecodeProgram(p)
+ return fmt.Errorf("Unexpected result: got: %d, expected: %d\nBPF Program\n%s", result, expected, decoded)
+ }
+
+ return nil
+}
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index 5015c3a84..0392e3e83 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/runsc/config"
)
// ExePath must point to runsc binary, which is normally the same binary. It's
@@ -110,11 +111,6 @@ func ValidateSpec(spec *specs.Spec) error {
log.Warningf("noNewPrivileges ignored. PR_SET_NO_NEW_PRIVS is assumed to always be set.")
}
- // TODO(gvisor.dev/issue/510): Apply seccomp to application inside sandbox.
- if spec.Linux != nil && spec.Linux.Seccomp != nil {
- log.Warningf("Seccomp spec is being ignored")
- }
-
if spec.Linux != nil && spec.Linux.RootfsPropagation != "" {
if err := validateRootfsPropagation(spec.Linux.RootfsPropagation); err != nil {
return err
@@ -161,18 +157,18 @@ func OpenSpec(bundleDir string) (*os.File, error) {
// ReadSpec reads an OCI runtime spec from the given bundle directory.
// ReadSpec also normalizes all potential relative paths into absolute
// path, e.g. spec.Root.Path, mount.Source.
-func ReadSpec(bundleDir string) (*specs.Spec, error) {
+func ReadSpec(bundleDir string, conf *config.Config) (*specs.Spec, error) {
specFile, err := OpenSpec(bundleDir)
if err != nil {
return nil, fmt.Errorf("error opening spec file %q: %v", filepath.Join(bundleDir, "config.json"), err)
}
defer specFile.Close()
- return ReadSpecFromFile(bundleDir, specFile)
+ return ReadSpecFromFile(bundleDir, specFile, conf)
}
// ReadSpecFromFile reads an OCI runtime spec from the given File, and
// normalizes all relative paths into absolute by prepending the bundle dir.
-func ReadSpecFromFile(bundleDir string, specFile *os.File) (*specs.Spec, error) {
+func ReadSpecFromFile(bundleDir string, specFile *os.File, conf *config.Config) (*specs.Spec, error) {
if _, err := specFile.Seek(0, os.SEEK_SET); err != nil {
return nil, fmt.Errorf("error seeking to beginning of file %q: %v", specFile.Name(), err)
}
@@ -195,6 +191,20 @@ func ReadSpecFromFile(bundleDir string, specFile *os.File) (*specs.Spec, error)
m.Source = absPath(bundleDir, m.Source)
}
}
+
+ // Override flags using annotation to allow customization per sandbox
+ // instance.
+ for annotation, val := range spec.Annotations {
+ const flagPrefix = "dev.gvisor.flag."
+ if strings.HasPrefix(annotation, flagPrefix) {
+ name := annotation[len(flagPrefix):]
+ log.Infof("Overriding flag: %s=%q", name, val)
+ if err := conf.Override(name, val); err != nil {
+ return nil, err
+ }
+ }
+ }
+
return &spec, nil
}
diff --git a/scripts/common.sh b/scripts/common.sh
deleted file mode 100755
index 3ca699e4a..000000000
--- a/scripts/common.sh
+++ /dev/null
@@ -1,86 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeou pipefail
-
-# Get the path to the directory this script lives in.
-# If this script is being called with `source`, $0 will be the path of the
-# *sourcing* script, so we can't use `dirname $0` to find scripts in this
-# directory.
-if [[ -v BASH_SOURCE && "$0" != "$BASH_SOURCE" ]]; then
- declare -r script_dir="$(dirname "$BASH_SOURCE")"
-else
- declare -r script_dir="$(dirname "$0")"
-fi
-
-source "${script_dir}/common_build.sh"
-
-# Ensure it attempts to collect logs in all cases.
-trap collect_logs EXIT
-
-function set_runtime() {
- RUNTIME=${1:-runsc}
- RUNSC_BIN=/tmp/"${RUNTIME}"/runsc
- RUNSC_LOGS_DIR="$(dirname ${RUNSC_BIN})"/logs
- RUNSC_LOGS="${RUNSC_LOGS_DIR}"/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%
-}
-
-function test_runsc() {
- test --test_arg=--runtime=${RUNTIME} "$@"
-}
-
-function install_runsc_for_test() {
- local -r test_name=$1
- shift
- if [[ -z "${test_name}" ]]; then
- echo "Missing mandatory test name"
- exit 1
- fi
-
- # Add test to the name, so it doesn't conflict with other runtimes.
- set_runtime $(find_branch_name)_"${test_name}"
-
- # ${RUNSC_TEST_NAME} is set by tests (see dockerutil) to pass the test name
- # down to the runtime.
- install_runsc "${RUNTIME}" \
- --TESTONLY-test-name-env=RUNSC_TEST_NAME \
- --debug \
- --strace \
- --log-packets \
- "$@"
-}
-
-# Installs the runsc with given runtime name. set_runtime must have been called
-# to set runtime and logs location.
-function install_runsc() {
- local -r runtime=$1
- shift
-
- # Prepare the runtime binary.
- local -r output=$(build //runsc)
- mkdir -p "$(dirname ${RUNSC_BIN})"
- cp -f "${output}" "${RUNSC_BIN}"
- chmod 0755 "${RUNSC_BIN}"
-
- # Install the runtime.
- sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@"
-
- # Clear old logs files that may exist.
- sudo rm -f "${RUNSC_LOGS_DIR}"/'*'
-
- # Restart docker to pick up the new runtime configuration.
- sudo systemctl restart docker
-}
diff --git a/scripts/common_build.sh b/scripts/common_build.sh
deleted file mode 100755
index d4a6c4908..000000000
--- a/scripts/common_build.sh
+++ /dev/null
@@ -1,116 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-which bazel
-bazel version
-
-# Switch into the workspace; only necessary if run with kokoro.
-if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then
- cd git/repo
-elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then
- cd github/repo
-fi
-
-# Set the standard bazel flags.
-declare -a BAZEL_FLAGS=(
- "--show_timestamps"
- "--test_output=errors"
- "--keep_going"
- "--verbose_failures=true"
-)
-# If running via kokoro, use the remote config.
-if [[ -v KOKORO_ARTIFACTS_DIR ]]; then
- BAZEL_FLAGS+=(
- "--config=remote"
- )
-fi
-declare -r BAZEL_FLAGS
-
-# Wrap bazel.
-function build() {
- bazel build "${BAZEL_FLAGS[@]}" "$@" 2>&1 \
- | tee /dev/fd/2 \
- | grep -E '^ bazel-bin/' \
- | awk '{ print $1; }'
-}
-
-function test() {
- bazel test "${BAZEL_FLAGS[@]}" "$@"
-}
-
-function run() {
- local binary=$1
- shift
- bazel run "${binary}" -- "$@"
-}
-
-function run_as_root() {
- local binary=$1
- shift
- bazel run --run_under="sudo" "${binary}" -- "$@"
-}
-
-function query() {
- bazel query "$@"
-}
-
-function collect_logs() {
- # Zip out everything into a convenient form.
- if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then
- # Merge results files of all shards for each test suite.
- for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do
- junitparser merge `find $d -name test.xml` $d/test.xml
- cat $d/shard_*_of_*/test.log > $d/test.log
- if ls -ld $d/shard_*_of_*/test.outputs 2>/dev/null; then
- zip -r -1 "$d/outputs.zip" $d/shard_*_of_*/test.outputs
- fi
- done
- find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf
- # Move test logs to Kokoro directory. tar is used to conveniently perform
- # renames while moving files.
- find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" |
- tar --create --files-from - --transform 's/test\./sponge_log./' |
- tar --extract --directory ${KOKORO_ARTIFACTS_DIR}
-
- # Collect sentry logs, if any.
- if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then
- # Check if the directory is empty or not (only the first line it needed).
- local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1)
- if [[ "${logs}" ]]; then
- local -r archive=runsc_logs_"${RUNTIME}".tar.gz
- if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then
- echo "runsc logs will be uploaded to:"
- echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp"
- echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}"
- fi
- time tar \
- --verbose \
- --create \
- --gzip \
- --file="${KOKORO_ARTIFACTS_DIR}/${archive}" \
- --directory "${RUNSC_LOGS_DIR}" \
- .
- fi
- fi
- fi
-}
-
-function find_branch_name() {
- git branch --show-current \
- || git rev-parse HEAD \
- || bazel info workspace \
- | xargs basename
-}
diff --git a/scripts/dev.sh b/scripts/dev.sh
deleted file mode 100755
index a9107f33e..000000000
--- a/scripts/dev.sh
+++ /dev/null
@@ -1,75 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-# common.sh sets '-x', but it's annoying to see so much output.
-set +x
-
-# Defaults
-declare -i REFRESH=0
-declare NAME=$(find_branch_name)
-
-while [[ $# -gt 0 ]]; do
- case "$1" in
- --refresh)
- REFRESH=1
- ;;
- --help)
- echo "Use this script to build and install runsc with Docker."
- echo
- echo "usage: $0 [--refresh] [runtime_name]"
- exit 1
- ;;
- *)
- NAME=$1
- ;;
- esac
- shift
-done
-
-set_runtime "${NAME}"
-echo
-echo "Using runtime=${RUNTIME}"
-echo
-
-echo Building runsc...
-# Build first and fail on error. $() prevents "set -e" from reporting errors.
-build //runsc
-declare OUTPUT="$(build //runsc)"
-
-if [[ ${REFRESH} -eq 0 ]]; then
- install_runsc "${RUNTIME}" --net-raw
- install_runsc "${RUNTIME}-d" --net-raw --debug --strace --log-packets
- install_runsc "${RUNTIME}-p" --net-raw --profile
-
- echo
- echo "Runtimes ${RUNTIME}, ${RUNTIME}-d (debug enabled), and ${RUNTIME}-p installed."
- echo "Use --runtime="${RUNTIME}" with your Docker command."
- echo " docker run --rm --runtime="${RUNTIME}" hello-world"
- echo
- echo "If you rebuild, use $0 --refresh."
-
-else
- mkdir -p "$(dirname ${RUNSC_BIN})"
- cp -f ${OUTPUT} "${RUNSC_BIN}"
- chmod a+rx "${RUNSC_BIN}"
-
- echo
- echo "Runtime ${RUNTIME} refreshed."
-fi
-
-echo "Logs are in: ${RUNSC_LOGS_DIR}"
diff --git a/scripts/do_tests.sh b/scripts/do_tests.sh
deleted file mode 100755
index a3a387c37..000000000
--- a/scripts/do_tests.sh
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-# Build runsc.
-build //runsc
-
-# run runsc do without root privileges.
-run //runsc --rootless do true
-run //runsc --rootless --network=none do true
-
-# run runsc do with root privileges.
-run_as_root //runsc do true
diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh
deleted file mode 100755
index 4f3867d05..000000000
--- a/scripts/docker_tests.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-make load-all-images
-
-install_runsc_for_test docker
-test_runsc //test/image:image_test //test/e2e:integration_test
-
-install_runsc_for_test docker --vfs2
-test_runsc //test/e2e:integration_test //test/image:image_test
diff --git a/scripts/go.sh b/scripts/go.sh
deleted file mode 100755
index 626ed8fa4..000000000
--- a/scripts/go.sh
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-# Build the go path.
-build :gopath
-
-# Build the synthetic branch.
-tools/go_branch.sh
-
-# Checkout the new branch.
-git checkout go && git clean -f
-
-go version
-
-# Build everything.
-go build ./...
-
-# Push, if required.
-if [[ -v KOKORO_GO_PUSH ]] && [[ "${KOKORO_GO_PUSH}" == "true" ]]; then
- if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then
- git config --global credential.helper cache
- git credential approve <<EOF
-protocol=https
-host=github.com
-username=$(cat "${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}")
-password=x-oauth-basic
-EOF
- fi
- git push origin go:go
-fi
diff --git a/scripts/hostnet_tests.sh b/scripts/hostnet_tests.sh
deleted file mode 100755
index 992db50dd..000000000
--- a/scripts/hostnet_tests.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-make load-all-images
-
-# Install the runtime and perform basic tests.
-install_runsc_for_test hostnet --network=host
-test_runsc --test_arg=-checkpoint=false //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/iptables_tests.sh b/scripts/iptables_tests.sh
deleted file mode 100755
index 8299a7c8b..000000000
--- a/scripts/iptables_tests.sh
+++ /dev/null
@@ -1,26 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-make load-iptables
-
-# Needed by ip6tables.
-sudo modprobe ip6table_filter
-
-install_runsc_for_test iptables --net-raw
-test //test/iptables:iptables_test "--test_arg=--runtime=runc"
-test //test/iptables:iptables_test "--test_arg=--runtime=${RUNTIME}"
diff --git a/scripts/kvm_tests.sh b/scripts/kvm_tests.sh
deleted file mode 100755
index 619571c74..000000000
--- a/scripts/kvm_tests.sh
+++ /dev/null
@@ -1,30 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-make load-all-images
-
-# Ensure that KVM is loaded, and we can use it.
-(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm
-sudo chmod a+rw /dev/kvm
-
-# Run all KVM platform tests (locally).
-run_as_root //pkg/sentry/platform/kvm:kvm_test
-
-# Install the KVM runtime and run all integration tests.
-install_runsc_for_test kvm --platform=kvm
-test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/make_tests.sh b/scripts/make_tests.sh
deleted file mode 100755
index dbf1bba77..000000000
--- a/scripts/make_tests.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-make runsc
-make bazel-shutdown
diff --git a/scripts/overlay_tests.sh b/scripts/overlay_tests.sh
deleted file mode 100755
index 448864953..000000000
--- a/scripts/overlay_tests.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-make load-all-images
-
-# Install the runtime and perform basic tests.
-install_runsc_for_test overlay --overlay
-test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/packetdrill_tests.sh b/scripts/packetdrill_tests.sh
deleted file mode 100755
index 1a8181ac8..000000000
--- a/scripts/packetdrill_tests.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-make load-packetdrill
-
-install_runsc_for_test runsc-d
-QUERY_RESULT=$(query "attr(tags, manual, tests(//test/packetdrill/...))")
-test_runsc $QUERY_RESULT
diff --git a/scripts/packetimpact_tests.sh b/scripts/packetimpact_tests.sh
deleted file mode 100755
index 77fb84bc3..000000000
--- a/scripts/packetimpact_tests.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-make load-packetimpact
-
-install_runsc_for_test runsc-d
-QUERY_RESULT=$(query "attr(tags, packetimpact, tests(//test/packetimpact/...))")
-test_runsc $QUERY_RESULT
diff --git a/scripts/root_tests.sh b/scripts/root_tests.sh
deleted file mode 100755
index 3eb735e62..000000000
--- a/scripts/root_tests.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-make load-all-images
-CONTAINERD_VERSION=1.3.4 make sudo TARGETS="tools/installers:containerd"
-make sudo TARGETS="tools/installers:shim"
-
-# Run the tests that require root.
-install_runsc_for_test root
-run_as_root //test/root:root_test --runtime=${RUNTIME}
diff --git a/scripts/runtime_tests.sh b/scripts/runtime_tests.sh
deleted file mode 100755
index 85e95d45d..000000000
--- a/scripts/runtime_tests.sh
+++ /dev/null
@@ -1,29 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-# Check that a runtime is provided.
-if [ ! -v RUNTIME_TEST_NAME ]; then
- echo "Must set $RUNTIME_TEST_NAME" >&2
- exit 1
-fi
-
-# Download language runtime image.
-make -C images/ "load-runtimes_${RUNTIME_TEST_NAME}"
-
-install_runsc_for_test runtimes
-test_runsc "//test/runtimes:${RUNTIME_TEST_NAME}"
diff --git a/scripts/simple_tests.sh b/scripts/simple_tests.sh
deleted file mode 100755
index 585216aae..000000000
--- a/scripts/simple_tests.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-# Run all simple tests (locally).
-test //pkg/... //runsc/... //tools/...
diff --git a/scripts/swgso_tests.sh b/scripts/swgso_tests.sh
deleted file mode 100755
index c67f2fe5c..000000000
--- a/scripts/swgso_tests.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-make load-all-images
-
-# Install the runtime and perform basic tests.
-install_runsc_for_test swgso --software-gso=true --gso=false
-test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/syscall_kvm_tests.sh b/scripts/syscall_kvm_tests.sh
deleted file mode 100755
index 0e5d86727..000000000
--- a/scripts/syscall_kvm_tests.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-# Run all ptrace-variants of the system call tests.
-test --test_tag_filters=runsc_kvm //test/syscalls/...
diff --git a/scripts/syscall_tests.sh b/scripts/syscall_tests.sh
deleted file mode 100755
index a131b2d50..000000000
--- a/scripts/syscall_tests.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-# Run all ptrace-variants of the system call tests.
-test --test_tag_filters=runsc_ptrace //test/syscalls/...
diff --git a/test/README.md b/test/README.md
index 02bbf42ff..15b0f4c33 100644
--- a/test/README.md
+++ b/test/README.md
@@ -24,11 +24,11 @@ also used to run these tests in `kokoro`.
To run image and integration tests, run:
-`./scripts/docker_tests.sh`
+`make docker-tests`
To run root tests, run:
-`./scripts/root_tests.sh`
+`make root-tests`
There are a few other interesting variations for image and integration tests:
diff --git a/test/benchmarks/base/size_test.go b/test/benchmarks/base/size_test.go
index 3c1364faf..7d3877459 100644
--- a/test/benchmarks/base/size_test.go
+++ b/test/benchmarks/base/size_test.go
@@ -105,6 +105,7 @@ func BenchmarkSizeNginx(b *testing.B) {
machine: machine,
port: port,
runOpts: runOpts,
+ cmd: []string{"nginx", "-c", "/etc/nginx/nginx_gofer.conf"},
})
defer cleanUpContainers(ctx, servers)
diff --git a/test/benchmarks/base/startup_test.go b/test/benchmarks/base/startup_test.go
index 4628a0a41..c36a544db 100644
--- a/test/benchmarks/base/startup_test.go
+++ b/test/benchmarks/base/startup_test.go
@@ -64,6 +64,7 @@ func BenchmarkStartupNginx(b *testing.B) {
machine: machine,
runOpts: runOpts,
port: 80,
+ cmd: []string{"nginx", "-c", "/etc/nginx/nginx_gofer.conf"},
})
}
@@ -123,8 +124,6 @@ func redisInstance(ctx context.Context, b *testing.B, machine harness.Machine) (
// runServerWorkload runs a server workload defined by 'runOpts' and 'cmd'.
// 'clientMachine' is used to connect to the server on 'serverMachine'.
func runServerWorkload(ctx context.Context, b *testing.B, args serverArgs) {
- b.Helper()
-
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := func() error {
diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD
index bd3f6245c..472b5c387 100644
--- a/test/benchmarks/network/BUILD
+++ b/test/benchmarks/network/BUILD
@@ -5,8 +5,15 @@ package(licenses = ["notice"])
go_library(
name = "network",
testonly = 1,
- srcs = ["network.go"],
- deps = ["//test/benchmarks/harness"],
+ srcs = [
+ "network.go",
+ "static_server.go",
+ ],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
+ ],
)
go_test(
diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go
index 336e04c91..369ab326e 100644
--- a/test/benchmarks/network/httpd_test.go
+++ b/test/benchmarks/network/httpd_test.go
@@ -14,22 +14,19 @@
package network
import (
- "context"
"fmt"
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
- "gvisor.dev/gvisor/test/benchmarks/harness"
"gvisor.dev/gvisor/test/benchmarks/tools"
)
// see Dockerfile '//images/benchmarks/httpd'.
-var docs = map[string]string{
+var httpdDocs = map[string]string{
"notfound": "notfound",
"1Kb": "latin1k.txt",
"10Kb": "latin10k.txt",
"100Kb": "latin100k.txt",
- "1000Kb": "latin1000k.txt",
"1Mb": "latin1024k.txt",
"10Mb": "latin10240k.txt",
}
@@ -37,30 +34,17 @@ var docs = map[string]string{
// BenchmarkHttpdConcurrency iterates the concurrency argument and tests
// how well the runtime under test handles requests in parallel.
func BenchmarkHttpdConcurrency(b *testing.B) {
- // Grab a machine for the client and server.
- clientMachine, err := h.GetMachine()
- if err != nil {
- b.Fatalf("failed to get client: %v", err)
- }
- defer clientMachine.CleanUp()
-
- serverMachine, err := h.GetMachine()
- if err != nil {
- b.Fatalf("failed to get server: %v", err)
- }
- defer serverMachine.CleanUp()
-
// The test iterates over client concurrency, so set other parameters.
concurrency := []int{1, 25, 50, 100, 1000}
for _, c := range concurrency {
b.Run(fmt.Sprintf("%d", c), func(b *testing.B) {
hey := &tools.Hey{
- Requests: 10000,
+ Requests: c * b.N,
Concurrency: c,
- Doc: docs["10Kb"],
+ Doc: httpdDocs["10Kb"],
}
- runHttpd(b, clientMachine, serverMachine, hey, false /* reverse */)
+ runHttpd(b, hey, false /* reverse */)
})
}
}
@@ -77,57 +61,30 @@ func BenchmarkReverseHttpdDocSize(b *testing.B) {
benchmarkHttpdDocSize(b, true /* reverse */)
}
+// benchmarkHttpdDocSize iterates through all doc sizes, running subbenchmarks
+// for each size.
func benchmarkHttpdDocSize(b *testing.B, reverse bool) {
b.Helper()
-
- clientMachine, err := h.GetMachine()
- if err != nil {
- b.Fatalf("failed to get machine: %v", err)
- }
- defer clientMachine.CleanUp()
-
- serverMachine, err := h.GetMachine()
- if err != nil {
- b.Fatalf("failed to get machine: %v", err)
- }
- defer serverMachine.CleanUp()
-
- for name, filename := range docs {
+ for name, filename := range httpdDocs {
concurrency := []int{1, 25, 50, 100, 1000}
for _, c := range concurrency {
b.Run(fmt.Sprintf("%s_%d", name, c), func(b *testing.B) {
hey := &tools.Hey{
- Requests: 10000,
+ Requests: c * b.N,
Concurrency: c,
Doc: filename,
}
- runHttpd(b, clientMachine, serverMachine, hey, reverse)
+ runHttpd(b, hey, reverse)
})
}
}
}
-// runHttpd runs a single test run.
-func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, hey *tools.Hey, reverse bool) {
- b.Helper()
-
- // Grab a container from the server.
- ctx := context.Background()
- var server *dockerutil.Container
- if reverse {
- server = serverMachine.GetNativeContainer(ctx, b)
- } else {
- server = serverMachine.GetContainer(ctx, b)
- }
-
- defer server.CleanUp(ctx)
-
- // Copy the docs to /tmp and serve from there.
- cmd := "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X"
+// runHttpd configures the static serving methods to run httpd.
+func runHttpd(b *testing.B, hey *tools.Hey, reverse bool) {
+ // httpd runs on port 80.
port := 80
-
- // Start the server.
- if err := server.Spawn(ctx, dockerutil.RunOpts{
+ httpdRunOpts := dockerutil.RunOpts{
Image: "benchmarks/httpd",
Ports: []int{port},
Env: []string{
@@ -138,44 +95,7 @@ func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, hey *t
"APACHE_LOG_DIR=/tmp",
"APACHE_PID_FILE=/tmp/apache.pid",
},
- }, "sh", "-c", cmd); err != nil {
- b.Fatalf("failed to start server: %v", err)
- }
-
- ip, err := serverMachine.IPAddress()
- if err != nil {
- b.Fatalf("failed to find server ip: %v", err)
- }
-
- servingPort, err := server.FindPort(ctx, port)
- if err != nil {
- b.Fatalf("failed to find server port %d: %v", port, err)
- }
-
- // Check the server is serving.
- harness.WaitUntilServing(ctx, clientMachine, ip, servingPort)
-
- var client *dockerutil.Container
- // Grab a client.
- if reverse {
- client = clientMachine.GetContainer(ctx, b)
- } else {
- client = clientMachine.GetNativeContainer(ctx, b)
- }
- defer client.CleanUp(ctx)
-
- b.ResetTimer()
- server.RestartProfiles()
- for i := 0; i < b.N; i++ {
- out, err := client.Run(ctx, dockerutil.RunOpts{
- Image: "benchmarks/hey",
- }, hey.MakeCmd(ip, servingPort)...)
- if err != nil {
- b.Fatalf("run failed with: %v", err)
- }
-
- b.StopTimer()
- hey.Report(b, out)
- b.StartTimer()
}
+ httpdCmd := []string{"sh", "-c", "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X"}
+ runStaticServer(b, httpdRunOpts, httpdCmd, port, hey, reverse)
}
diff --git a/test/benchmarks/network/nginx_test.go b/test/benchmarks/network/nginx_test.go
index 2bf1a3624..9ec70369b 100644
--- a/test/benchmarks/network/nginx_test.go
+++ b/test/benchmarks/network/nginx_test.go
@@ -14,91 +14,97 @@
package network
import (
- "context"
"fmt"
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
- "gvisor.dev/gvisor/test/benchmarks/harness"
"gvisor.dev/gvisor/test/benchmarks/tools"
)
+// see Dockerfile '//images/benchmarks/nginx'.
+var nginxDocs = map[string]string{
+ "notfound": "notfound",
+ "1Kb": "latin1k.txt",
+ "10Kb": "latin10k.txt",
+ "100Kb": "latin100k.txt",
+ "1Mb": "latin1024k.txt",
+ "10Mb": "latin10240k.txt",
+}
+
// BenchmarkNginxConcurrency iterates the concurrency argument and tests
// how well the runtime under test handles requests in parallel.
-// TODO(gvisor.dev/issue/3536): Update with different doc sizes like Httpd.
func BenchmarkNginxConcurrency(b *testing.B) {
- // Grab a machine for the client and server.
- clientMachine, err := h.GetMachine()
- if err != nil {
- b.Fatalf("failed to get client: %v", err)
- }
- defer clientMachine.CleanUp()
-
- serverMachine, err := h.GetMachine()
- if err != nil {
- b.Fatalf("failed to get server: %v", err)
- }
- defer serverMachine.CleanUp()
-
- concurrency := []int{1, 5, 10, 25}
+ concurrency := []int{1, 25, 100, 1000}
for _, c := range concurrency {
- b.Run(fmt.Sprintf("%d", c), func(b *testing.B) {
- hey := &tools.Hey{
- Requests: 10000,
- Concurrency: c,
+ for _, tmpfs := range []bool{true, false} {
+ fs := "Gofer"
+ if tmpfs {
+ fs = "Tmpfs"
}
- runNginx(b, clientMachine, serverMachine, hey)
- })
+ name := fmt.Sprintf("%d_%s", c, fs)
+ b.Run(name, func(b *testing.B) {
+ hey := &tools.Hey{
+ Requests: c * b.N,
+ Concurrency: c,
+ Doc: nginxDocs["10kb"], // see Dockerfile '//images/benchmarks/nginx' and httpd_test.
+ }
+ runNginx(b, hey, false /* reverse */, tmpfs /* tmpfs */)
+ })
+ }
+
}
}
-// runHttpd runs a single test run.
-func runNginx(b *testing.B, clientMachine, serverMachine harness.Machine, hey *tools.Hey) {
- b.Helper()
+// BenchmarkNginxDocSize iterates over different sized payloads, testing how
+// well the runtime handles sending different payload sizes.
+func BenchmarkNginxDocSize(b *testing.B) {
+ benchmarkNginxDocSize(b, false /* reverse */, true /* tmpfs */)
+ benchmarkNginxDocSize(b, false /* reverse */, false /* tmpfs */)
+}
- // Grab a container from the server.
- ctx := context.Background()
- server := serverMachine.GetContainer(ctx, b)
- defer server.CleanUp(ctx)
+// BenchmarkReverseNginxDocSize iterates over different sized payloads, testing
+// how well the runtime handles receiving different payload sizes.
+func BenchmarkReverseNginxDocSize(b *testing.B) {
+ benchmarkNginxDocSize(b, true /* reverse */, true /* tmpfs */)
+}
- port := 80
- // Start the server.
- if err := server.Spawn(ctx,
- dockerutil.RunOpts{
- Image: "benchmarks/nginx",
- Ports: []int{port},
- }); err != nil {
- b.Fatalf("server failed to start: %v", err)
+// benchmarkNginxDocSize iterates through all doc sizes, running subbenchmarks
+// for each size.
+func benchmarkNginxDocSize(b *testing.B, reverse, tmpfs bool) {
+ for name, filename := range nginxDocs {
+ concurrency := []int{1, 25, 50, 100, 1000}
+ for _, c := range concurrency {
+ fs := "Gofer"
+ if tmpfs {
+ fs = "Tmpfs"
+ }
+ benchName := fmt.Sprintf("%s_%d_%s", name, c, fs)
+ b.Run(benchName, func(b *testing.B) {
+ hey := &tools.Hey{
+ Requests: c * b.N,
+ Concurrency: c,
+ Doc: filename,
+ }
+ runNginx(b, hey, reverse, tmpfs)
+ })
+ }
}
+}
- ip, err := serverMachine.IPAddress()
- if err != nil {
- b.Fatalf("failed to find server ip: %v", err)
+// runNginx configures the static serving methods to run httpd.
+func runNginx(b *testing.B, hey *tools.Hey, reverse, tmpfs bool) {
+ // nginx runs on port 80.
+ port := 80
+ nginxRunOpts := dockerutil.RunOpts{
+ Image: "benchmarks/nginx",
+ Ports: []int{port},
}
- servingPort, err := server.FindPort(ctx, port)
- if err != nil {
- b.Fatalf("failed to find server port %d: %v", port, err)
+ nginxCmd := []string{"nginx", "-c", "/etc/nginx/nginx_gofer.conf"}
+ if tmpfs {
+ nginxCmd = []string{"sh", "-c", "mkdir -p /tmp/html && cp -a /local/* /tmp/html && nginx -c /etc/nginx/nginx.conf"}
}
- // Check the server is serving.
- harness.WaitUntilServing(ctx, clientMachine, ip, servingPort)
-
- // Grab a client.
- client := clientMachine.GetNativeContainer(ctx, b)
- defer client.CleanUp(ctx)
-
- b.ResetTimer()
- server.RestartProfiles()
- for i := 0; i < b.N; i++ {
- out, err := client.Run(ctx, dockerutil.RunOpts{
- Image: "benchmarks/hey",
- }, hey.MakeCmd(ip, servingPort)...)
- if err != nil {
- b.Fatalf("run failed with: %v", err)
- }
- b.StopTimer()
- hey.Report(b, out)
- b.StartTimer()
- }
+ // Command copies nginxDocs to tmpfs serving directory and runs nginx.
+ runStaticServer(b, nginxRunOpts, nginxCmd, port, hey, reverse)
}
diff --git a/test/benchmarks/network/static_server.go b/test/benchmarks/network/static_server.go
new file mode 100644
index 000000000..e747a1395
--- /dev/null
+++ b/test/benchmarks/network/static_server.go
@@ -0,0 +1,87 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network
+
+import (
+ "context"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
+)
+
+// runStaticServer runs static serving workloads (httpd, nginx).
+func runStaticServer(b *testing.B, serverOpts dockerutil.RunOpts, serverCmd []string, port int, hey *tools.Hey, reverse bool) {
+ ctx := context.Background()
+
+ // Get two machines: a client and server.
+ clientMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer clientMachine.CleanUp()
+
+ serverMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer serverMachine.CleanUp()
+
+ // Make the containers. 'reverse=true' specifies that the client should use the
+ // runtime under test.
+ var client, server *dockerutil.Container
+ if reverse {
+ client = clientMachine.GetContainer(ctx, b)
+ server = serverMachine.GetNativeContainer(ctx, b)
+ } else {
+ client = clientMachine.GetNativeContainer(ctx, b)
+ server = serverMachine.GetContainer(ctx, b)
+ }
+ defer client.CleanUp(ctx)
+ defer server.CleanUp(ctx)
+
+ // Start the server.
+ if err := server.Spawn(ctx, serverOpts, serverCmd...); err != nil {
+ b.Fatalf("failed to start server: %v", err)
+ }
+
+ // Get its IP.
+ ip, err := serverMachine.IPAddress()
+ if err != nil {
+ b.Fatalf("failed to find server ip: %v", err)
+ }
+
+ // Get the published port.
+ servingPort, err := server.FindPort(ctx, port)
+ if err != nil {
+ b.Fatalf("failed to find server port %d: %v", port, err)
+ }
+
+ // Make sure the server is serving.
+ harness.WaitUntilServing(ctx, clientMachine, ip, servingPort)
+ b.ResetTimer()
+ server.RestartProfiles()
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/hey",
+ }, hey.MakeCmd(ip, servingPort)...)
+ if err != nil {
+ b.Fatalf("run failed with: %v", err)
+ }
+
+ b.StopTimer()
+ hey.Report(b, out)
+ b.StartTimer()
+}
diff --git a/test/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go
index 4b7ca7a14..5afe10f69 100644
--- a/test/benchmarks/tcp/tcp_proxy.go
+++ b/test/benchmarks/tcp/tcp_proxy.go
@@ -174,8 +174,8 @@ func newNetstackImpl(mode string) (impl, error) {
}
// Create a new network stack.
- netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}
- transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}
+ netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}
+ transProtos := []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}
s := stack.New(stack.Options{
NetworkProtocols: netProtos,
TransportProtocols: transProtos,
@@ -228,19 +228,26 @@ func newNetstackImpl(mode string) (impl, error) {
})
// Set protocol options.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil {
- return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %s", err)
+ {
+ opt := tcpip.TCPSACKEnabled(*sack)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
}
// Enable Receive Buffer Auto-Tuning.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(*moderateRecvBuf)); err != nil {
- return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPModerateReceiveBufferOption(*moderateRecvBuf)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
}
// Set Congestion Control to cubic if requested.
if *cubic {
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic")); err != nil {
- return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %s", err)
+ opt := tcpip.CongestionControlOption("cubic")
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index 809244bab..8425abecb 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -64,9 +64,10 @@ func TestLifeCycle(t *testing.T) {
defer d.CleanUp(ctx)
// Start the container.
+ port := 80
if err := d.Create(ctx, dockerutil.RunOpts{
Image: "basic/nginx",
- Ports: []int{80},
+ Ports: []int{port},
}); err != nil {
t.Fatalf("docker create failed: %v", err)
}
@@ -74,16 +75,15 @@ func TestLifeCycle(t *testing.T) {
t.Fatalf("docker start failed: %v", err)
}
- // Test that container is working.
- port, err := d.FindPort(ctx, 80)
+ ip, err := d.FindIP(ctx, false)
if err != nil {
- t.Fatalf("docker.FindPort(80) failed: %v", err)
+ t.Fatalf("docker.FindIP failed: %v", err)
}
- if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
+ if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil {
t.Fatalf("WaitForHTTP() timeout: %v", err)
}
client := http.Client{Timeout: defaultWait}
- if err := httpRequestSucceeds(client, "localhost", port); err != nil {
+ if err := httpRequestSucceeds(client, ip.String(), port); err != nil {
t.Errorf("http request failed: %v", err)
}
@@ -105,27 +105,28 @@ func TestPauseResume(t *testing.T) {
defer d.CleanUp(ctx)
// Start the container.
+ port := 8080
if err := d.Spawn(ctx, dockerutil.RunOpts{
Image: "basic/python",
- Ports: []int{8080}, // See Dockerfile.
+ Ports: []int{port}, // See Dockerfile.
}); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- // Find where port 8080 is mapped to.
- port, err := d.FindPort(ctx, 8080)
+ // Find container IP address.
+ ip, err := d.FindIP(ctx, false)
if err != nil {
- t.Fatalf("docker.FindPort(8080) failed: %v", err)
+ t.Fatalf("docker.FindIP failed: %v", err)
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
+ if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil {
t.Fatalf("WaitForHTTP() timeout: %v", err)
}
// Check that container is working.
client := http.Client{Timeout: defaultWait}
- if err := httpRequestSucceeds(client, "localhost", port); err != nil {
+ if err := httpRequestSucceeds(client, ip.String(), port); err != nil {
t.Error("http request failed:", err)
}
@@ -135,7 +136,7 @@ func TestPauseResume(t *testing.T) {
// Check if container is paused.
client = http.Client{Timeout: 10 * time.Millisecond} // Don't wait a minute.
- switch _, err := client.Get(fmt.Sprintf("http://localhost:%d", port)); v := err.(type) {
+ switch _, err := client.Get(fmt.Sprintf("http://%s:%d", ip.String(), port)); v := err.(type) {
case nil:
t.Errorf("http req expected to fail but it succeeded")
case net.Error:
@@ -151,13 +152,13 @@ func TestPauseResume(t *testing.T) {
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
+ if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil {
t.Fatalf("WaitForHTTP() timeout: %v", err)
}
// Check if container is working again.
client = http.Client{Timeout: defaultWait}
- if err := httpRequestSucceeds(client, "localhost", port); err != nil {
+ if err := httpRequestSucceeds(client, ip.String(), port); err != nil {
t.Error("http request failed:", err)
}
}
@@ -179,9 +180,10 @@ func TestCheckpointRestore(t *testing.T) {
defer d.CleanUp(ctx)
// Start the container.
+ port := 8080
if err := d.Spawn(ctx, dockerutil.RunOpts{
Image: "basic/python",
- Ports: []int{8080}, // See Dockerfile.
+ Ports: []int{port}, // See Dockerfile.
}); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -199,20 +201,20 @@ func TestCheckpointRestore(t *testing.T) {
t.Fatalf("docker restore failed: %v", err)
}
- // Find where port 8080 is mapped to.
- port, err := d.FindPort(ctx, 8080)
+ // Find container IP address.
+ ip, err := d.FindIP(ctx, false)
if err != nil {
- t.Fatalf("docker.FindPort(8080) failed: %v", err)
+ t.Fatalf("docker.FindIP failed: %v", err)
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
+ if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil {
t.Fatalf("WaitForHTTP() timeout: %v", err)
}
// Check if container is working again.
client := http.Client{Timeout: defaultWait}
- if err := httpRequestSucceeds(client, "localhost", port); err != nil {
+ if err := httpRequestSucceeds(client, ip.String(), port); err != nil {
t.Error("http request failed:", err)
}
}
diff --git a/test/fuse/BUILD b/test/fuse/BUILD
index 56157c96b..8e31fdd41 100644
--- a/test/fuse/BUILD
+++ b/test/fuse/BUILD
@@ -5,5 +5,69 @@ package(licenses = ["notice"])
syscall_test(
fuse = "True",
test = "//test/fuse/linux:stat_test",
- vfs2 = "True",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:open_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:release_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:mknod_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:symlink_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:readlink_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:mkdir_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:read_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:write_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:rmdir_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:readdir_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:create_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:unlink_test",
+)
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:setstat_test",
)
diff --git a/test/fuse/README.md b/test/fuse/README.md
index 734c3a4e3..65add57e2 100644
--- a/test/fuse/README.md
+++ b/test/fuse/README.md
@@ -1,65 +1,114 @@
# gVisor FUSE Test Suite
-This is an integration test suite for fuse(4) filesystem. It runs under both
-gVisor and Linux, and ensures compatibility between the two. This test suite is
-based on system calls test.
+This is an integration test suite for fuse(4) filesystem. It runs under gVisor
+sandbox container with VFS2 and FUSE function enabled.
-This document describes the framework of fuse integration test and the
-guidelines that should be followed when adding new fuse tests.
+This document describes the framework of FUSE integration test, how to use it,
+and the guidelines that should be followed when adding new testing features.
## Integration Test Framework
-Please refer to the figure below. `>` is entering the function, `<` is leaving
-the function, and `=` indicates sequentially entering and leaving.
+By inheriting the `FuseTest` class defined in `linux/fuse_base.h`, every test
+fixture can run in an environment with `mount_point_` mounted by a fake FUSE
+server. It creates a `socketpair(2)` to send and receive control commands and
+data between the client and the server. Because the FUSE server runs in the
+background thread, gTest cannot catch its assertion failure immediately. Thus,
+`TearDown()` function sends command to the FUSE server to check if all gTest
+assertion in the server are successful and all requests and preset responses are
+consumed.
+
+## Communication Diagram
+
+Diagram below describes how a testing thread communicates with the FUSE server
+to achieve integration test.
+
+For the following diagram, `>` means entering the function, `<` is leaving the
+function, and `=` indicates sequentially entering and leaving. Not necessarily
+follow exactly the below diagram due to the nature of a multi-threaded system,
+however, it is still helpful to know when the client waits for the server to
+complete a command and when the server awaits the next instruction.
```
- | Client (Test Main Process) | Server (FUSE Daemon)
- | |
- | >TEST_F() |
- | >SetUp() |
- | =MountFuse() |
- | >SetUpFuseServer() |
- | [create communication pipes] |
- | =fork() | =fork()
- | >WaitCompleted() |
- | [wait for MarkDone()] |
- | | =ConsumeFuseInit()
- | | =MarkDone()
- | <WaitCompleted() |
- | <SetUpFuseServer() |
- | <SetUp() |
- | >SetExpected() |
- | [construct expected reaction] |
- | | >FuseLoop()
- | | >ReceiveExpected()
- | | [wait data from pipe]
- | [write data to pipe] |
- | [wait for MarkDone()] |
- | | [save data to memory]
- | | =MarkDone()
- | <SetExpected() |
- | | <ReceiveExpected()
- | | >read()
- | | [wait for fs operation]
- | >[Do fs operation] |
- | [wait for fs response] |
- | | <read()
- | | =CompareRequest()
- | | =write() [write fs response]
- | <[Do fs operation] |
- | =[Test fs operation result] |
- | =[wait for MarkDone()] |
- | | =MarkDone()
- | >TearDown() |
- | =UnmountFuse() |
- | <TearDown() |
- | <TEST_F() |
+| Client (Testing Thread) | Server (FUSE Server Thread)
+| |
+| >TEST_F() |
+| >SetUp() |
+| =MountFuse() |
+| >SetUpFuseServer() |
+| [create communication socket]|
+| =fork() | =fork()
+| [wait server complete] |
+| | =ServerConsumeFuseInit()
+| | =ServerCompleteWith()
+| <SetUpFuseServer() |
+| <SetUp() |
+| [testing main] |
+| | >ServerFuseLoop()
+| | [poll on socket and fd]
+| >SetServerResponse() |
+| [write data to socket] |
+| [wait server complete] |
+| | [socket event occurs]
+| | >ServerHandleCommand()
+| | >ServerReceiveResponse()
+| | [read data from socket]
+| | [save data to memory]
+| | <ServerReceiveResponse()
+| | =ServerCompleteWith()
+| <SetServerResponse() |
+| | <ServerHandleCommand()
+| >[Do fs operation] |
+| [wait for fs response] |
+| | [fd event occurs]
+| | >ServerProcessFuseRequest()
+| | =[read fs request]
+| | =[save fs request to memory]
+| | =[write fs response]
+| <[Do fs operation] |
+| | <ServerProcessFuseRequest()
+| |
+| =[Test fs operation result] |
+| |
+| >GetServerActualRequest() |
+| [write data to socket] |
+| [wait data from server] |
+| | [socket event occurs]
+| | >ServerHandleCommand()
+| | >ServerSendReceivedRequest()
+| | [write data to socket]
+| [read data from socket] |
+| [wait server complete] |
+| | <ServerSendReceivedRequest()
+| | =ServerCompleteWith()
+| <GetServerActualRequest() |
+| | <ServerHandleCommand()
+| |
+| =[Test actual request] |
+| |
+| >TearDown() |
+| ... |
+| >GetServerNumUnsentResponses() |
+| [write data to socket] |
+| [wait server complete] |
+| | [socket event arrive]
+| | >ServerHandleCommand()
+| | >ServerSendData()
+| | [write data to socket]
+| | <ServerSendData()
+| | =ServerCompleteWith()
+| [read data from socket] |
+| [test if all succeeded] |
+| <GetServerNumUnsentResponses() |
+| | <ServerHandleCommand()
+| =UnmountFuse() |
+| <TearDown() |
+| <TEST_F() |
```
## Running the tests
-Based on syscall tests, fuse tests can run in different environments. To enable
-fuse testing environment, the test targets should be appended with `_fuse`.
+Based on syscall tests, FUSE tests generate targets only with vfs2 and fuse
+enabled. The corresponding targets end in `_fuse`.
For example, to run fuse test in `stat_test.cc`:
@@ -77,17 +126,16 @@ $ bazel test --test_tag_filters=fuse //test/fuse/...
1. Add test targets in `BUILD` and `linux/BUILD`.
2. Inherit your test from `FuseTest` base class. It allows you to:
- - Run a fake FUSE server in background during each test setup.
- - Create pipes for communication and provide utility functions.
- - Stop FUSE server after test completes.
-3. Customize your comparison function for request assessment in FUSE server.
-4. Add the mapping of the size of structs if you are working on new FUSE
- opcode.
- - Please update `FuseTest::GetPayloadSize()` for each new FUSE opcode.
-5. Build the expected request-response pair of your FUSE operation.
-6. Call `SetExpected()` function to inject the expected reaction.
-7. Check the response and/or errors.
-8. Finally call `WaitCompleted()` to ensure the FUSE server acts correctly.
+ - Fork a fake FUSE server in background during each test setup.
+ - Create a pair of sockets for communication and provide utility
+ functions.
+ - Stop FUSE server and check if error occurs in it after test completes.
+3. Build the expected opcode-response pairs of your FUSE operation.
+4. Call `SetServerResponse()` to preset the next expected opcode and response.
+5. Do real filesystem operations (FUSE is mounted at `mount_point_`).
+6. Check FUSE response and/or errors.
+7. Retrieve FUSE request by `GetServerActualRequest()`.
+8. Check if the request is as expected.
A few customized matchers used in syscalls test are encouraged to test the
outcome of filesystem operations. Such as:
@@ -101,3 +149,40 @@ SyscallFailsWithErrno(...)
Please refer to [test/syscalls/README.md](../syscalls/README.md) for further
details.
+
+## Writing a new FuseTestCmd
+
+A `FuseTestCmd` is a control protocol used in the communication between the
+testing thread and the FUSE server. Such commands are sent from the testing
+thread to the FUSE server to set up, control, or inspect the behavior of the
+FUSE server in response to a sequence of FUSE requests.
+
+The lifecycle of a command contains following steps:
+
+1. The testing thread sends a `FuseTestCmd` via socket and waits for
+ completion.
+2. The FUSE server receives the command and does corresponding action.
+3. (Optional) The testing thread reads data from socket.
+4. The FUSE server sends a success indicator via socket after processing.
+5. The testing thread gets the success signal and continues testing.
+
+The success indicator, i.e. `WaitServerComplete()`, is crucial at the end of
+each `FuseTestCmd` sent from the testing thread. Because we don't want to begin
+filesystem operation if the requests have not been completely set up. Also, to
+test FUSE interactions in a sequential manner, concurrent requests are not
+supported now.
+
+To add a new `FuseTestCmd`, one must comply with following format:
+
+1. Add a new `FuseTestCmd` enum class item defined in `linux/fuse_base.h`
+2. Add a `SetServerXXX()` or `GetServerXXX()` public function in `FuseTest`.
+ This is how the testing thread will call to send control message. Define how
+ many bytes you want to send along with the command and what you will expect
+ to receive. Finally it should block and wait for a success indicator from
+ the FUSE server.
+3. Add a handler logic in the switch condition of `ServerHandleCommand()`. Use
+ `ServerSendData()` or declare a new private function such as
+ `ServerReceiveXXX()` or `ServerSendXXX()`. It is mandatory to set it private
+ since only the FUSE server (forked from `FuseTest` base class) can call it.
+ This is the server part of the specific `FuseTestCmd` and the format of the
+ data should be consistent with what the client expects in the previous step.
diff --git a/test/fuse/linux/BUILD b/test/fuse/linux/BUILD
index 4871bb531..7673252ec 100644
--- a/test/fuse/linux/BUILD
+++ b/test/fuse/linux/BUILD
@@ -11,7 +11,134 @@ cc_binary(
srcs = ["stat_test.cc"],
deps = [
gtest,
+ ":fuse_fd_util",
+ "//test/util:cleanup",
+ "//test/util:fs_util",
+ "//test/util:fuse_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "open_test",
+ testonly = 1,
+ srcs = ["open_test.cc"],
+ deps = [
+ gtest,
":fuse_base",
+ "//test/util:fuse_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "release_test",
+ testonly = 1,
+ srcs = ["release_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:fuse_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "mknod_test",
+ testonly = 1,
+ srcs = ["mknod_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:fuse_util",
+ "//test/util:temp_umask",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "symlink_test",
+ testonly = 1,
+ srcs = ["symlink_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:fuse_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "readlink_test",
+ testonly = 1,
+ srcs = ["readlink_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:fuse_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "mkdir_test",
+ testonly = 1,
+ srcs = ["mkdir_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:fuse_util",
+ "//test/util:temp_umask",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "setstat_test",
+ testonly = 1,
+ srcs = ["setstat_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_fd_util",
+ "//test/util:cleanup",
+ "//test/util:fs_util",
+ "//test/util:fuse_util",
+ "//test/util:temp_umask",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "rmdir_test",
+ testonly = 1,
+ srcs = ["rmdir_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:fs_util",
+ "//test/util:fuse_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "readdir_test",
+ testonly = 1,
+ srcs = ["readdir_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:fs_util",
+ "//test/util:fuse_util",
"//test/util:test_main",
"//test/util:test_util",
],
@@ -24,9 +151,80 @@ cc_library(
hdrs = ["fuse_base.h"],
deps = [
gtest,
+ "//test/util:fuse_util",
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_util",
"@com_google_absl//absl/strings:str_format",
],
)
+
+cc_library(
+ name = "fuse_fd_util",
+ testonly = 1,
+ srcs = ["fuse_fd_util.cc"],
+ hdrs = ["fuse_fd_util.h"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:fuse_util",
+ "//test/util:posix_error",
+ ],
+)
+
+cc_binary(
+ name = "read_test",
+ testonly = 1,
+ srcs = ["read_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:fuse_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "write_test",
+ testonly = 1,
+ srcs = ["write_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:fuse_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "create_test",
+ testonly = 1,
+ srcs = ["create_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:fs_util",
+ "//test/util:fuse_util",
+ "//test/util:temp_umask",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "unlink_test",
+ testonly = 1,
+ srcs = ["unlink_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:fuse_util",
+ "//test/util:temp_umask",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
diff --git a/test/fuse/linux/create_test.cc b/test/fuse/linux/create_test.cc
new file mode 100644
index 000000000..9a0219a58
--- /dev/null
+++ b/test/fuse/linux/create_test.cc
@@ -0,0 +1,128 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fs_util.h"
+#include "test/util/fuse_util.h"
+#include "test/util/temp_umask.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class CreateTest : public FuseTest {
+ protected:
+ const std::string test_file_name_ = "test_file";
+ const mode_t mode = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO;
+};
+
+TEST_F(CreateTest, CreateFile) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_name_);
+
+ // Ensure the file doesn't exist.
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header),
+ .error = -ENOENT,
+ };
+ auto iov_out = FuseGenerateIovecs(out_header);
+ SetServerResponse(FUSE_LOOKUP, iov_out);
+
+ // creat(2) is equal to open(2) with open_flags O_CREAT | O_WRONLY | O_TRUNC.
+ const mode_t new_mask = S_IWGRP | S_IWOTH;
+ const int open_flags = O_CREAT | O_WRONLY | O_TRUNC;
+ out_header.error = 0;
+ out_header.len = sizeof(struct fuse_out_header) +
+ sizeof(struct fuse_entry_out) + sizeof(struct fuse_open_out);
+ struct fuse_entry_out entry_payload = DefaultEntryOut(mode & ~new_mask, 2);
+ struct fuse_open_out out_payload = {
+ .fh = 1,
+ .open_flags = open_flags,
+ };
+ iov_out = FuseGenerateIovecs(out_header, entry_payload, out_payload);
+ SetServerResponse(FUSE_CREATE, iov_out);
+
+ // kernfs generates a successive FUSE_OPEN after the file is created. Linux's
+ // fuse kernel module will not send this FUSE_OPEN after creat(2).
+ out_header.len =
+ sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out);
+ iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_OPEN, iov_out);
+
+ int fd;
+ TempUmask mask(new_mask);
+ EXPECT_THAT(fd = creat(test_file_path.c_str(), mode), SyscallSucceeds());
+ EXPECT_THAT(fcntl(fd, F_GETFL),
+ SyscallSucceedsWithValue(open_flags & O_ACCMODE));
+
+ struct fuse_in_header in_header;
+ struct fuse_create_in in_payload;
+ std::vector<char> name(test_file_name_.size() + 1);
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload, name);
+
+ // Skip the request of FUSE_LOOKUP.
+ SkipServerActualRequest();
+
+ // Get the first FUSE_CREATE.
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload) +
+ test_file_name_.size() + 1);
+ EXPECT_EQ(in_header.opcode, FUSE_CREATE);
+ EXPECT_EQ(in_payload.flags, open_flags);
+ EXPECT_EQ(in_payload.mode, mode & ~new_mask);
+ EXPECT_EQ(in_payload.umask, new_mask);
+ EXPECT_EQ(std::string(name.data()), test_file_name_);
+
+ // Get the successive FUSE_OPEN.
+ struct fuse_open_in in_payload_open;
+ iov_in = FuseGenerateIovecs(in_header, in_payload_open);
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload_open));
+ EXPECT_EQ(in_header.opcode, FUSE_OPEN);
+ EXPECT_EQ(in_payload_open.flags, open_flags & O_ACCMODE);
+
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+ // Skip the FUSE_RELEASE.
+ SkipServerActualRequest();
+}
+
+TEST_F(CreateTest, CreateFileAlreadyExists) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_name_);
+
+ const int open_flags = O_CREAT | O_EXCL;
+
+ SetServerInodeLookup(test_file_name_);
+
+ EXPECT_THAT(open(test_file_path.c_str(), mode, open_flags),
+ SyscallFailsWithErrno(EEXIST));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/fuse_base.cc b/test/fuse/linux/fuse_base.cc
index 9c3124472..5b45804e1 100644
--- a/test/fuse/linux/fuse_base.cc
+++ b/test/fuse/linux/fuse_base.cc
@@ -16,17 +16,17 @@
#include <fcntl.h>
#include <linux/fuse.h>
-#include <string.h>
+#include <poll.h>
#include <sys/mount.h>
+#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/uio.h>
#include <unistd.h>
-#include <iostream>
-
#include "gtest/gtest.h"
#include "absl/strings/str_format.h"
+#include "test/util/fuse_util.h"
#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -39,49 +39,123 @@ void FuseTest::SetUp() {
SetUpFuseServer();
}
-void FuseTest::TearDown() { UnmountFuse(); }
-
-// Since CompareRequest is running in background thread, gTest assertions and
-// expectations won't directly reflect the test result. However, the FUSE
-// background server still connects to the same standard I/O as testing main
-// thread. So EXPECT_XX can still be used to show different results. To
-// ensure failed testing result is observable, return false and the result
-// will be sent to test main thread via pipe.
-bool FuseTest::CompareRequest(void* expected_mem, size_t expected_len,
- void* real_mem, size_t real_len) {
- if (expected_len != real_len) return false;
- return memcmp(expected_mem, real_mem, expected_len) == 0;
+void FuseTest::TearDown() {
+ EXPECT_EQ(GetServerNumUnconsumedRequests(), 0);
+ EXPECT_EQ(GetServerNumUnsentResponses(), 0);
+ UnmountFuse();
}
-// SetExpected is called by the testing main thread to set expected request-
-// response pair of a single FUSE operation.
-void FuseTest::SetExpected(struct iovec* iov_in, int iov_in_cnt,
- struct iovec* iov_out, int iov_out_cnt) {
- EXPECT_THAT(RetryEINTR(writev)(set_expected_[1], iov_in, iov_in_cnt),
- SyscallSucceedsWithValue(::testing::Gt(0)));
- WaitCompleted();
+// Sends 3 parts of data to the FUSE server:
+// 1. The `kSetResponse` command
+// 2. The expected opcode
+// 3. The fake FUSE response
+// Then waits for the FUSE server to notify its completion.
+void FuseTest::SetServerResponse(uint32_t opcode,
+ std::vector<struct iovec>& iovecs) {
+ uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSetResponse);
+ EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)),
+ SyscallSucceedsWithValue(sizeof(cmd)));
+
+ EXPECT_THAT(RetryEINTR(write)(sock_[0], &opcode, sizeof(opcode)),
+ SyscallSucceedsWithValue(sizeof(opcode)));
- EXPECT_THAT(RetryEINTR(writev)(set_expected_[1], iov_out, iov_out_cnt),
- SyscallSucceedsWithValue(::testing::Gt(0)));
- WaitCompleted();
+ EXPECT_THAT(RetryEINTR(writev)(sock_[0], iovecs.data(), iovecs.size()),
+ SyscallSucceeds());
+
+ WaitServerComplete();
}
-// WaitCompleted waits for the FUSE server to finish its job and check if it
+// Waits for the FUSE server to finish its blocking job and check if it
// completes without errors.
-void FuseTest::WaitCompleted() {
- char success;
- EXPECT_THAT(RetryEINTR(read)(done_[0], &success, sizeof(success)),
- SyscallSucceedsWithValue(1));
+void FuseTest::WaitServerComplete() {
+ uint32_t success;
+ EXPECT_THAT(RetryEINTR(read)(sock_[0], &success, sizeof(success)),
+ SyscallSucceedsWithValue(sizeof(success)));
+ ASSERT_EQ(success, 1);
+}
+
+// Sends the `kGetRequest` command to the FUSE server, then reads the next
+// request into iovec struct. The order of calling this function should be
+// the same as the one of SetServerResponse().
+void FuseTest::GetServerActualRequest(std::vector<struct iovec>& iovecs) {
+ uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kGetRequest);
+ EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)),
+ SyscallSucceedsWithValue(sizeof(cmd)));
+
+ EXPECT_THAT(RetryEINTR(readv)(sock_[0], iovecs.data(), iovecs.size()),
+ SyscallSucceeds());
+
+ WaitServerComplete();
+}
+
+// Sends a FuseTestCmd command to the FUSE server, reads from the socket, and
+// returns the corresponding data.
+uint32_t FuseTest::GetServerData(uint32_t cmd) {
+ uint32_t data;
+ EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)),
+ SyscallSucceedsWithValue(sizeof(cmd)));
+
+ EXPECT_THAT(RetryEINTR(read)(sock_[0], &data, sizeof(data)),
+ SyscallSucceedsWithValue(sizeof(data)));
+
+ WaitServerComplete();
+ return data;
+}
+
+uint32_t FuseTest::GetServerNumUnconsumedRequests() {
+ return GetServerData(
+ static_cast<uint32_t>(FuseTestCmd::kGetNumUnconsumedRequests));
+}
+
+uint32_t FuseTest::GetServerNumUnsentResponses() {
+ return GetServerData(
+ static_cast<uint32_t>(FuseTestCmd::kGetNumUnsentResponses));
+}
+
+uint32_t FuseTest::GetServerTotalReceivedBytes() {
+ return GetServerData(
+ static_cast<uint32_t>(FuseTestCmd::kGetTotalReceivedBytes));
+}
+
+// Sends the `kSkipRequest` command to the FUSE server, which would skip
+// current stored request data.
+void FuseTest::SkipServerActualRequest() {
+ uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSkipRequest);
+ EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)),
+ SyscallSucceedsWithValue(sizeof(cmd)));
+
+ WaitServerComplete();
}
-void FuseTest::MountFuse() {
+// Sends the `kSetInodeLookup` command, expected mode, and the path of the
+// inode to create under the mount point.
+void FuseTest::SetServerInodeLookup(const std::string& path, mode_t mode,
+ uint64_t size) {
+ uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSetInodeLookup);
+ EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)),
+ SyscallSucceedsWithValue(sizeof(cmd)));
+
+ EXPECT_THAT(RetryEINTR(write)(sock_[0], &mode, sizeof(mode)),
+ SyscallSucceedsWithValue(sizeof(mode)));
+
+ EXPECT_THAT(RetryEINTR(write)(sock_[0], &size, sizeof(size)),
+ SyscallSucceedsWithValue(sizeof(size)));
+
+ // Pad 1 byte for null-terminate c-string.
+ EXPECT_THAT(RetryEINTR(write)(sock_[0], path.c_str(), path.size() + 1),
+ SyscallSucceedsWithValue(path.size() + 1));
+
+ WaitServerComplete();
+}
+
+void FuseTest::MountFuse(const char* mountOpts) {
EXPECT_THAT(dev_fd_ = open("/dev/fuse", O_RDWR), SyscallSucceeds());
- std::string mount_opts = absl::StrFormat("fd=%d,%s", dev_fd_, kMountOpts);
+ std::string mount_opts = absl::StrFormat("fd=%d,%s", dev_fd_, mountOpts);
mount_point_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
EXPECT_THAT(mount("fuse", mount_point_.path().c_str(), "fuse",
MS_NODEV | MS_NOSUID, mount_opts.c_str()),
- SyscallSucceedsWithValue(0));
+ SyscallSucceeds());
}
void FuseTest::UnmountFuse() {
@@ -89,13 +163,13 @@ void FuseTest::UnmountFuse() {
// TODO(gvisor.dev/issue/3330): ensure the process is terminated successfully.
}
-// ConsumeFuseInit consumes the first FUSE request and returns the
-// corresponding PosixError.
-PosixError FuseTest::ConsumeFuseInit() {
+// Consumes the first FUSE request and returns the corresponding PosixError.
+PosixError FuseTest::ServerConsumeFuseInit(
+ const struct fuse_init_out* out_payload) {
+ std::vector<char> buf(FUSE_MIN_READ_BUFFER);
RETURN_ERROR_IF_SYSCALL_FAIL(
- RetryEINTR(read)(dev_fd_, buf_.data(), buf_.size()));
+ RetryEINTR(read)(dev_fd_, buf.data(), buf.size()));
- struct iovec iov_out[2];
struct fuse_out_header out_header = {
.len = sizeof(struct fuse_out_header) + sizeof(struct fuse_init_out),
.error = 0,
@@ -103,72 +177,74 @@ PosixError FuseTest::ConsumeFuseInit() {
};
// Returns a fake fuse_init_out with 7.0 version to avoid ECONNREFUSED
// error in the initialization of FUSE connection.
- struct fuse_init_out out_payload = {
- .major = 7,
- };
- iov_out[0].iov_len = sizeof(out_header);
- iov_out[0].iov_base = &out_header;
- iov_out[1].iov_len = sizeof(out_payload);
- iov_out[1].iov_base = &out_payload;
+ auto iov_out = FuseGenerateIovecs(
+ out_header, *const_cast<struct fuse_init_out*>(out_payload));
- RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(writev)(dev_fd_, iov_out, 2));
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ RetryEINTR(writev)(dev_fd_, iov_out.data(), iov_out.size()));
return NoError();
}
-// ReceiveExpected reads 1 pair of expected fuse request-response `iovec`s
-// from pipe and save them into member variables of this testing instance.
-void FuseTest::ReceiveExpected() {
- // Set expected fuse_in request.
- EXPECT_THAT(len_in_ = RetryEINTR(read)(set_expected_[0], mem_in_.data(),
- mem_in_.size()),
- SyscallSucceedsWithValue(::testing::Gt(0)));
- MarkDone(len_in_ > 0);
+// Reads 1 expected opcode and a fake response from socket and save them into
+// the serial buffer of this testing instance.
+void FuseTest::ServerReceiveResponse() {
+ ssize_t len;
+ uint32_t opcode;
+ std::vector<char> buf(FUSE_MIN_READ_BUFFER);
+ EXPECT_THAT(RetryEINTR(read)(sock_[1], &opcode, sizeof(opcode)),
+ SyscallSucceedsWithValue(sizeof(opcode)));
- // Set expected fuse_out response.
- EXPECT_THAT(len_out_ = RetryEINTR(read)(set_expected_[0], mem_out_.data(),
- mem_out_.size()),
- SyscallSucceedsWithValue(::testing::Gt(0)));
- MarkDone(len_out_ > 0);
+ EXPECT_THAT(len = RetryEINTR(read)(sock_[1], buf.data(), buf.size()),
+ SyscallSucceeds());
+
+ responses_.AddMemBlock(opcode, buf.data(), len);
}
-// MarkDone writes 1 byte of success indicator through pipe.
-void FuseTest::MarkDone(bool success) {
- char data = success ? 1 : 0;
- EXPECT_THAT(RetryEINTR(write)(done_[1], &data, sizeof(data)),
- SyscallSucceedsWithValue(1));
+// Writes 1 byte of success indicator through socket.
+void FuseTest::ServerCompleteWith(bool success) {
+ uint32_t data = success ? 1 : 0;
+ ServerSendData(data);
}
-// FuseLoop is the implementation of the fake FUSE server. Read from /dev/fuse,
-// compare the request by CompareRequest (use derived function if specified),
-// and write the expected response to /dev/fuse.
-void FuseTest::FuseLoop() {
- bool success = true;
- ssize_t len = 0;
+// ServerFuseLoop is the implementation of the fake FUSE server. Monitors 2
+// file descriptors: /dev/fuse and sock_[1]. Events from /dev/fuse are FUSE
+// requests and events from sock_[1] are FUSE testing commands, leading by
+// a FuseTestCmd data to indicate the command.
+void FuseTest::ServerFuseLoop() {
+ const int nfds = 2;
+ struct pollfd fds[nfds] = {
+ {
+ .fd = dev_fd_,
+ .events = POLL_IN | POLLHUP | POLLERR | POLLNVAL,
+ },
+ {
+ .fd = sock_[1],
+ .events = POLL_IN | POLLHUP | POLLERR | POLLNVAL,
+ },
+ };
+
while (true) {
- ReceiveExpected();
+ ASSERT_THAT(poll(fds, nfds, -1), SyscallSucceeds());
- EXPECT_THAT(len = RetryEINTR(read)(dev_fd_, buf_.data(), buf_.size()),
- SyscallSucceedsWithValue(len_in_));
- if (len != len_in_) success = false;
+ for (int fd_idx = 0; fd_idx < nfds; ++fd_idx) {
+ if (fds[fd_idx].revents == 0) continue;
- if (!CompareRequest(buf_.data(), len_in_, mem_in_.data(), len_in_)) {
- std::cerr << "the FUSE request is not expected" << std::endl;
- success = false;
+ ASSERT_EQ(fds[fd_idx].revents, POLL_IN);
+ if (fds[fd_idx].fd == sock_[1]) {
+ ServerHandleCommand();
+ } else if (fds[fd_idx].fd == dev_fd_) {
+ ServerProcessFuseRequest();
+ }
}
-
- EXPECT_THAT(len = RetryEINTR(write)(dev_fd_, mem_out_.data(), len_out_),
- SyscallSucceedsWithValue(len_out_));
- if (len != len_out_) success = false;
- MarkDone(success);
}
}
-// SetUpFuseServer creates 2 pipes. First is for testing client to send the
-// expected request-response pair, and the other acts as a checkpoint for the
-// FUSE server to notify the client that it can proceed.
-void FuseTest::SetUpFuseServer() {
- ASSERT_THAT(pipe(set_expected_), SyscallSucceedsWithValue(0));
- ASSERT_THAT(pipe(done_), SyscallSucceedsWithValue(0));
+// SetUpFuseServer creates 1 socketpair and fork the process. The parent thread
+// becomes testing thread and the child thread becomes the FUSE server running
+// in background. These 2 threads are connected via socketpair. sock_[0] is
+// opened in testing thread and sock_[1] is opened in the FUSE server.
+void FuseTest::SetUpFuseServer(const struct fuse_init_out* payload) {
+ ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_), SyscallSucceeds());
switch (fork()) {
case -1:
@@ -177,31 +253,194 @@ void FuseTest::SetUpFuseServer() {
case 0:
break;
default:
- ASSERT_THAT(close(set_expected_[0]), SyscallSucceedsWithValue(0));
- ASSERT_THAT(close(done_[1]), SyscallSucceedsWithValue(0));
- WaitCompleted();
+ ASSERT_THAT(close(sock_[1]), SyscallSucceeds());
+ WaitServerComplete();
return;
}
- ASSERT_THAT(close(set_expected_[1]), SyscallSucceedsWithValue(0));
- ASSERT_THAT(close(done_[0]), SyscallSucceedsWithValue(0));
-
- MarkDone(ConsumeFuseInit().ok());
-
- FuseLoop();
+ // Begin child thread, i.e. the FUSE server.
+ ASSERT_THAT(close(sock_[0]), SyscallSucceeds());
+ ServerCompleteWith(ServerConsumeFuseInit(payload).ok());
+ ServerFuseLoop();
_exit(0);
}
-// GetPayloadSize is a helper function to get the number of bytes of a
-// specific FUSE operation struct.
-size_t FuseTest::GetPayloadSize(uint32_t opcode, bool in) {
- switch (opcode) {
- case FUSE_INIT:
- return in ? sizeof(struct fuse_init_in) : sizeof(struct fuse_init_out);
+void FuseTest::ServerSendData(uint32_t data) {
+ EXPECT_THAT(RetryEINTR(write)(sock_[1], &data, sizeof(data)),
+ SyscallSucceedsWithValue(sizeof(data)));
+}
+
+// Reads FuseTestCmd sent from testing thread and routes to correct handler.
+// Since each command should be a blocking operation, a `ServerCompleteWith()`
+// is required after the switch keyword.
+void FuseTest::ServerHandleCommand() {
+ uint32_t cmd;
+ EXPECT_THAT(RetryEINTR(read)(sock_[1], &cmd, sizeof(cmd)),
+ SyscallSucceedsWithValue(sizeof(cmd)));
+
+ switch (static_cast<FuseTestCmd>(cmd)) {
+ case FuseTestCmd::kSetResponse:
+ ServerReceiveResponse();
+ break;
+ case FuseTestCmd::kSetInodeLookup:
+ ServerReceiveInodeLookup();
+ break;
+ case FuseTestCmd::kGetRequest:
+ ServerSendReceivedRequest();
+ break;
+ case FuseTestCmd::kGetTotalReceivedBytes:
+ ServerSendData(static_cast<uint32_t>(requests_.UsedBytes()));
+ break;
+ case FuseTestCmd::kGetNumUnconsumedRequests:
+ ServerSendData(static_cast<uint32_t>(requests_.RemainingBlocks()));
+ break;
+ case FuseTestCmd::kGetNumUnsentResponses:
+ ServerSendData(static_cast<uint32_t>(responses_.RemainingBlocks()));
+ break;
+ case FuseTestCmd::kSkipRequest:
+ ServerSkipReceivedRequest();
+ break;
default:
+ FAIL() << "Unknown FuseTestCmd " << cmd;
break;
}
- return 0;
+
+ ServerCompleteWith(!HasFailure());
+}
+
+// Reads the expected file mode and the path of one file. Crafts a basic
+// `fuse_entry_out` memory block and inserts into a map for future use.
+// The FUSE server will always return this response if a FUSE_LOOKUP
+// request with this specific path comes in.
+void FuseTest::ServerReceiveInodeLookup() {
+ mode_t mode;
+ uint64_t size;
+ std::vector<char> buf(FUSE_MIN_READ_BUFFER);
+
+ EXPECT_THAT(RetryEINTR(read)(sock_[1], &mode, sizeof(mode)),
+ SyscallSucceedsWithValue(sizeof(mode)));
+
+ EXPECT_THAT(RetryEINTR(read)(sock_[1], &size, sizeof(size)),
+ SyscallSucceedsWithValue(sizeof(size)));
+
+ EXPECT_THAT(RetryEINTR(read)(sock_[1], buf.data(), buf.size()),
+ SyscallSucceeds());
+
+ std::string path(buf.data());
+
+ uint32_t out_len =
+ sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out);
+ struct fuse_out_header out_header = {
+ .len = out_len,
+ .error = 0,
+ };
+ struct fuse_entry_out out_payload = DefaultEntryOut(mode, nodeid_);
+ // Since this is only used in test, nodeid_ is simply increased by 1 to
+ // comply with the unqiueness of different path.
+ ++nodeid_;
+
+ // Set the size.
+ out_payload.attr.size = size;
+
+ memcpy(buf.data(), &out_header, sizeof(out_header));
+ memcpy(buf.data() + sizeof(out_header), &out_payload, sizeof(out_payload));
+ lookups_.AddMemBlock(FUSE_LOOKUP, buf.data(), out_len);
+ lookup_map_[path] = lookups_.Next();
+}
+
+// Sends the received request pointed by current cursor and advances cursor.
+void FuseTest::ServerSendReceivedRequest() {
+ if (requests_.End()) {
+ FAIL() << "No more received request.";
+ return;
+ }
+ auto mem_block = requests_.Next();
+ EXPECT_THAT(
+ RetryEINTR(write)(sock_[1], requests_.DataAtOffset(mem_block.offset),
+ mem_block.len),
+ SyscallSucceedsWithValue(mem_block.len));
+}
+
+// Skip the request pointed by current cursor.
+void FuseTest::ServerSkipReceivedRequest() {
+ if (requests_.End()) {
+ FAIL() << "No more received request.";
+ return;
+ }
+ requests_.Next();
+}
+
+// Handles FUSE request. Reads request from /dev/fuse, checks if it has the
+// same opcode as expected, and responds with the saved fake FUSE response.
+// The FUSE request is copied to the serial buffer and can be retrieved one-
+// by-one by calling GetServerActualRequest from testing thread.
+void FuseTest::ServerProcessFuseRequest() {
+ ssize_t len;
+ std::vector<char> buf(FUSE_MIN_READ_BUFFER);
+
+ // Read FUSE request.
+ EXPECT_THAT(len = RetryEINTR(read)(dev_fd_, buf.data(), buf.size()),
+ SyscallSucceeds());
+ fuse_in_header* in_header = reinterpret_cast<fuse_in_header*>(buf.data());
+
+ // Check if this is a preset FUSE_LOOKUP path.
+ if (in_header->opcode == FUSE_LOOKUP) {
+ std::string path(buf.data() + sizeof(struct fuse_in_header));
+ auto it = lookup_map_.find(path);
+ if (it != lookup_map_.end()) {
+ // Matches a preset path. Reply with fake data and skip saving the
+ // request.
+ ServerRespondFuseSuccess(lookups_, it->second, in_header->unique);
+ return;
+ }
+ }
+
+ requests_.AddMemBlock(in_header->opcode, buf.data(), len);
+
+ if (in_header->opcode == FUSE_RELEASE || in_header->opcode == FUSE_RELEASEDIR)
+ return;
+ // Check if there is a corresponding response.
+ if (responses_.End()) {
+ GTEST_NONFATAL_FAILURE_("No more FUSE response is expected");
+ ServerRespondFuseError(in_header->unique);
+ return;
+ }
+ auto mem_block = responses_.Next();
+ if (in_header->opcode != mem_block.opcode) {
+ std::string message = absl::StrFormat("Expect opcode %d but got %d",
+ mem_block.opcode, in_header->opcode);
+ GTEST_NONFATAL_FAILURE_(message.c_str());
+ // We won't get correct response if opcode is not expected. Send error
+ // response here to avoid wrong parsing by VFS.
+ ServerRespondFuseError(in_header->unique);
+ return;
+ }
+
+ // Write FUSE response.
+ ServerRespondFuseSuccess(responses_, mem_block, in_header->unique);
+}
+
+void FuseTest::ServerRespondFuseSuccess(FuseMemBuffer& mem_buf,
+ const FuseMemBlock& block,
+ uint64_t unique) {
+ fuse_out_header* out_header =
+ reinterpret_cast<fuse_out_header*>(mem_buf.DataAtOffset(block.offset));
+
+ // Patch `unique` in fuse_out_header to avoid EINVAL caused by responding
+ // with an unknown `unique`.
+ out_header->unique = unique;
+ EXPECT_THAT(RetryEINTR(write)(dev_fd_, out_header, block.len),
+ SyscallSucceedsWithValue(block.len));
+}
+
+void FuseTest::ServerRespondFuseError(uint64_t unique) {
+ fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header),
+ .error = ENOSYS,
+ .unique = unique,
+ };
+ EXPECT_THAT(RetryEINTR(write)(dev_fd_, &out_header, sizeof(out_header)),
+ SyscallSucceedsWithValue(sizeof(out_header)));
}
} // namespace testing
diff --git a/test/fuse/linux/fuse_base.h b/test/fuse/linux/fuse_base.h
index 3a2f255a9..6ad296ca2 100644
--- a/test/fuse/linux/fuse_base.h
+++ b/test/fuse/linux/fuse_base.h
@@ -16,8 +16,12 @@
#define GVISOR_TEST_FUSE_FUSE_BASE_H_
#include <linux/fuse.h>
+#include <string.h>
+#include <sys/stat.h>
#include <sys/uio.h>
+#include <iostream>
+#include <unordered_map>
#include <vector>
#include "gtest/gtest.h"
@@ -29,68 +33,216 @@ namespace testing {
constexpr char kMountOpts[] = "rootmode=755,user_id=0,group_id=0";
-class FuseTest : public ::testing::Test {
+constexpr struct fuse_init_out kDefaultFUSEInitOutPayload = {.major = 7};
+
+// Internal commands used to communicate between testing thread and the FUSE
+// server. See test/fuse/README.md for further detail.
+enum class FuseTestCmd {
+ kSetResponse = 0,
+ kSetInodeLookup,
+ kGetRequest,
+ kGetNumUnconsumedRequests,
+ kGetNumUnsentResponses,
+ kGetTotalReceivedBytes,
+ kSkipRequest,
+};
+
+// Holds the information of a memory block in a serial buffer.
+struct FuseMemBlock {
+ uint32_t opcode;
+ size_t offset;
+ size_t len;
+};
+
+// A wrapper of a simple serial buffer that can be used with read(2) and
+// write(2). Contains a cursor to indicate accessing. This class is not thread-
+// safe and can only be used in single-thread version.
+class FuseMemBuffer {
public:
- FuseTest() {
- buf_.resize(FUSE_MIN_READ_BUFFER);
- mem_in_.resize(FUSE_MIN_READ_BUFFER);
- mem_out_.resize(FUSE_MIN_READ_BUFFER);
+ FuseMemBuffer() : cursor_(0) {
+ // To read from /dev/fuse, a buffer needs at least FUSE_MIN_READ_BUFFER
+ // bytes to avoid EINVAL. FuseMemBuffer holds memory that can accommodate
+ // a sequence of FUSE request/response, so it is initiated with double
+ // minimal requirement.
+ mem_.resize(FUSE_MIN_READ_BUFFER * 2);
}
- void SetUp() override;
- void TearDown() override;
- // CompareRequest is used by the FUSE server and should be implemented to
- // compare different FUSE operations. It compares the actual FUSE input
- // request with the expected one set by `SetExpected()`.
- virtual bool CompareRequest(void* expected_mem, size_t expected_len,
- void* real_mem, size_t real_len);
+ // Returns whether there is no memory block.
+ bool Empty() { return blocks_.empty(); }
+
+ // Returns if there is no more remaining memory blocks.
+ bool End() { return cursor_ == blocks_.size(); }
+
+ // Returns how many bytes that have been received.
+ size_t UsedBytes() {
+ return Empty() ? 0 : blocks_.back().offset + blocks_.back().len;
+ }
+
+ // Returns the available bytes remains in the serial buffer.
+ size_t AvailBytes() { return mem_.size() - UsedBytes(); }
+
+ // Appends a memory block information that starts at the tail of the serial
+ // buffer. /dev/fuse requires at least FUSE_MIN_READ_BUFFER bytes to read, or
+ // it will issue EINVAL. If it is not enough, just double the buffer length.
+ void AddMemBlock(uint32_t opcode, void* data, size_t len) {
+ if (AvailBytes() < FUSE_MIN_READ_BUFFER) {
+ mem_.resize(mem_.size() << 1);
+ }
+ size_t offset = UsedBytes();
+ memcpy(mem_.data() + offset, data, len);
+ blocks_.push_back(FuseMemBlock{opcode, offset, len});
+ }
+
+ // Returns the memory address at a specific offset. Used with read(2) or
+ // write(2).
+ char* DataAtOffset(size_t offset) { return mem_.data() + offset; }
+
+ // Returns current memory block pointed by the cursor and increase by 1.
+ FuseMemBlock Next() {
+ if (End()) {
+ std::cerr << "Buffer is already exhausted." << std::endl;
+ return FuseMemBlock{};
+ }
+ return blocks_[cursor_++];
+ }
+
+ // Returns the number of the blocks that has not been requested.
+ size_t RemainingBlocks() { return blocks_.size() - cursor_; }
+
+ private:
+ size_t cursor_;
+ std::vector<FuseMemBlock> blocks_;
+ std::vector<char> mem_;
+};
- // SetExpected is called by the testing main thread. Writes a request-
- // response pair into FUSE server's member variables via pipe.
- void SetExpected(struct iovec* iov_in, int iov_in_cnt, struct iovec* iov_out,
- int iov_out_cnt);
+// FuseTest base class is useful in FUSE integration test. Inherit this class
+// to automatically set up a fake FUSE server and use the member functions
+// to manipulate with it. Refer to test/fuse/README.md for detailed explanation.
+class FuseTest : public ::testing::Test {
+ public:
+ // nodeid_ is the ID of a fake inode. We starts from 2 since 1 is occupied by
+ // the mount point.
+ FuseTest() : nodeid_(2) {}
+ void SetUp() override;
+ void TearDown() override;
- // WaitCompleted waits for FUSE server to complete its processing. It
- // complains if the FUSE server responds failure during tests.
- void WaitCompleted();
+ // Called by the testing thread to set up a fake response for an expected
+ // opcode via socket. This can be used multiple times to define a sequence of
+ // expected FUSE reactions.
+ void SetServerResponse(uint32_t opcode, std::vector<struct iovec>& iovecs);
+
+ // Called by the testing thread to install a fake path under the mount point.
+ // e.g. a file under /mnt/dir/file and moint point is /mnt, then it will look
+ // up "dir/file" in this case.
+ //
+ // It sets a fixed response to the FUSE_LOOKUP requests issued with this
+ // path, pretending there is an inode and avoid ENOENT when testing. If mode
+ // is not given, it creates a regular file with mode 0600.
+ void SetServerInodeLookup(const std::string& path,
+ mode_t mode = S_IFREG | S_IRUSR | S_IWUSR,
+ uint64_t size = 512);
+
+ // Called by the testing thread to ask the FUSE server for its next received
+ // FUSE request. Be sure to use the corresponding struct of iovec to receive
+ // data from server.
+ void GetServerActualRequest(std::vector<struct iovec>& iovecs);
+
+ // Called by the testing thread to query the number of unconsumed requests in
+ // the requests_ serial buffer of the FUSE server. TearDown() ensures all
+ // FUSE requests received by the FUSE server were consumed by the testing
+ // thread.
+ uint32_t GetServerNumUnconsumedRequests();
+
+ // Called by the testing thread to query the number of unsent responses in
+ // the responses_ serial buffer of the FUSE server. TearDown() ensures all
+ // preset FUSE responses were sent out by the FUSE server.
+ uint32_t GetServerNumUnsentResponses();
+
+ // Called by the testing thread to ask the FUSE server for its total received
+ // bytes from /dev/fuse.
+ uint32_t GetServerTotalReceivedBytes();
+
+ // Called by the testing thread to ask the FUSE server to skip stored
+ // request data.
+ void SkipServerActualRequest();
protected:
TempPath mount_point_;
- private:
- void MountFuse();
+ // Opens /dev/fuse and inherit the file descriptor for the FUSE server.
+ void MountFuse(const char* mountOpts = kMountOpts);
+
+ // Creates a socketpair for communication and forks FUSE server.
+ void SetUpFuseServer(
+ const struct fuse_init_out* payload = &kDefaultFUSEInitOutPayload);
+
+ // Unmounts the mountpoint of the FUSE server.
void UnmountFuse();
- // ConsumeFuseInit is only used during FUSE server setup.
- PosixError ConsumeFuseInit();
+ private:
+ // Sends a FuseTestCmd and gets a uint32_t data from the FUSE server.
+ inline uint32_t GetServerData(uint32_t cmd);
+
+ // Waits for FUSE server to complete its processing. Complains if the FUSE
+ // server responds any failure during tests.
+ void WaitServerComplete();
- // ReceiveExpected is the FUSE server side's corresponding code of
- // `SetExpected()`. Save the request-response pair into its memory.
- void ReceiveExpected();
+ // The FUSE server stays here and waits next command or FUSE request until it
+ // is terminated.
+ void ServerFuseLoop();
- // MarkDone is used by the FUSE server to tell testing main if it's OK to
- // proceed next command.
- void MarkDone(bool success);
+ // Used by the FUSE server to tell testing thread if it is OK to proceed next
+ // command. Will be issued after processing each FuseTestCmd.
+ void ServerCompleteWith(bool success);
- // FuseLoop is where the FUSE server stay until it is terminated.
- void FuseLoop();
+ // Consumes the first FUSE request when mounting FUSE. Replies with a
+ // response with empty payload.
+ PosixError ServerConsumeFuseInit(const struct fuse_init_out* payload);
- // SetUpFuseServer creates 2 pipes for communication and forks FUSE server.
- void SetUpFuseServer();
+ // A command switch that dispatch different FuseTestCmd to its handler.
+ void ServerHandleCommand();
- // GetPayloadSize is a helper function to get the number of bytes of a
- // specific FUSE operation struct.
- size_t GetPayloadSize(uint32_t opcode, bool in);
+ // The FUSE server side's corresponding code of `SetServerResponse()`.
+ // Handles `kSetResponse` command. Saves the fake response into its output
+ // memory queue.
+ void ServerReceiveResponse();
+
+ // The FUSE server side's corresponding code of `SetServerInodeLookup()`.
+ // Handles `kSetInodeLookup` command. Receives an expected file mode and
+ // file path under the mount point.
+ void ServerReceiveInodeLookup();
+
+ // The FUSE server side's corresponding code of `GetServerActualRequest()`.
+ // Handles `kGetRequest` command. Sends the next received request pointed by
+ // the cursor.
+ void ServerSendReceivedRequest();
+
+ // Sends a uint32_t data via socket.
+ inline void ServerSendData(uint32_t data);
+
+ // The FUSE server side's corresponding code of `SkipServerActualRequest()`.
+ // Handles `kSkipRequest` command. Skip the request pointed by current cursor.
+ void ServerSkipReceivedRequest();
+
+ // Handles FUSE request sent to /dev/fuse by its saved responses.
+ void ServerProcessFuseRequest();
+
+ // Responds to FUSE request with a saved data.
+ void ServerRespondFuseSuccess(FuseMemBuffer& mem_buf,
+ const FuseMemBlock& block, uint64_t unique);
+
+ // Responds an error header to /dev/fuse when bad thing happens.
+ void ServerRespondFuseError(uint64_t unique);
int dev_fd_;
- int set_expected_[2];
- int done_[2];
-
- std::vector<char> buf_;
- std::vector<char> mem_in_;
- std::vector<char> mem_out_;
- ssize_t len_in_;
- ssize_t len_out_;
+ int sock_[2];
+
+ uint64_t nodeid_;
+ std::unordered_map<std::string, FuseMemBlock> lookup_map_;
+
+ FuseMemBuffer requests_;
+ FuseMemBuffer responses_;
+ FuseMemBuffer lookups_;
};
} // namespace testing
diff --git a/test/fuse/linux/fuse_fd_util.cc b/test/fuse/linux/fuse_fd_util.cc
new file mode 100644
index 000000000..30d1157bb
--- /dev/null
+++ b/test/fuse/linux/fuse_fd_util.cc
@@ -0,0 +1,61 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/fuse/linux/fuse_fd_util.h"
+
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/uio.h>
+
+#include <string>
+#include <vector>
+
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fuse_util.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<FileDescriptor> FuseFdTest::OpenPath(const std::string &path,
+ uint32_t flags, uint64_t fh) {
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out),
+ };
+ struct fuse_open_out out_payload = {
+ .fh = fh,
+ .open_flags = flags,
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_OPEN, iov_out);
+
+ auto res = Open(path.c_str(), flags);
+ if (res.ok()) {
+ SkipServerActualRequest();
+ }
+ return res;
+}
+
+Cleanup FuseFdTest::CloseFD(FileDescriptor &fd) {
+ return Cleanup([&] {
+ close(fd.release());
+ SkipServerActualRequest();
+ });
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/fuse_fd_util.h b/test/fuse/linux/fuse_fd_util.h
new file mode 100644
index 000000000..066185c94
--- /dev/null
+++ b/test/fuse/linux/fuse_fd_util.h
@@ -0,0 +1,48 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_
+#define GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_
+
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+#include <string>
+
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+class FuseFdTest : public FuseTest {
+ public:
+ // Sets the FUSE server to respond to a FUSE_OPEN with corresponding flags and
+ // fh. Then does a real file system open on the absolute path to get an fd.
+ PosixErrorOr<FileDescriptor> OpenPath(const std::string &path,
+ uint32_t flags = O_RDONLY,
+ uint64_t fh = 1);
+
+ // Returns a cleanup object that closes the fd when it is destroyed. After
+ // the close is done, tells the FUSE server to skip this FUSE_RELEASE.
+ Cleanup CloseFD(FileDescriptor &fd);
+};
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_
diff --git a/test/fuse/linux/mkdir_test.cc b/test/fuse/linux/mkdir_test.cc
new file mode 100644
index 000000000..9647cb93f
--- /dev/null
+++ b/test/fuse/linux/mkdir_test.cc
@@ -0,0 +1,88 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fuse_util.h"
+#include "test/util/temp_umask.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class MkdirTest : public FuseTest {
+ protected:
+ const std::string test_dir_ = "test_dir";
+ const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO;
+};
+
+TEST_F(MkdirTest, CreateDir) {
+ const std::string test_dir_path_ =
+ JoinPath(mount_point_.path().c_str(), test_dir_);
+ const mode_t new_umask = 0077;
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out),
+ };
+ struct fuse_entry_out out_payload = DefaultEntryOut(S_IFDIR | perms_, 5);
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_MKDIR, iov_out);
+ TempUmask mask(new_umask);
+ ASSERT_THAT(mkdir(test_dir_path_.c_str(), 0777), SyscallSucceeds());
+
+ struct fuse_in_header in_header;
+ struct fuse_mkdir_in in_payload;
+ std::vector<char> actual_dir(test_dir_.length() + 1);
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload, actual_dir);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_header.len,
+ sizeof(in_header) + sizeof(in_payload) + test_dir_.length() + 1);
+ EXPECT_EQ(in_header.opcode, FUSE_MKDIR);
+ EXPECT_EQ(in_payload.mode & 0777, perms_ & ~new_umask);
+ EXPECT_EQ(in_payload.umask, new_umask);
+ EXPECT_EQ(std::string(actual_dir.data()), test_dir_);
+}
+
+TEST_F(MkdirTest, FileTypeError) {
+ const std::string test_dir_path_ =
+ JoinPath(mount_point_.path().c_str(), test_dir_);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out),
+ };
+ struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5);
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_MKDIR, iov_out);
+ ASSERT_THAT(mkdir(test_dir_path_.c_str(), 0777), SyscallFailsWithErrno(EIO));
+ SkipServerActualRequest();
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/mknod_test.cc b/test/fuse/linux/mknod_test.cc
new file mode 100644
index 000000000..74c74d76b
--- /dev/null
+++ b/test/fuse/linux/mknod_test.cc
@@ -0,0 +1,107 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fuse_util.h"
+#include "test/util/temp_umask.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class MknodTest : public FuseTest {
+ protected:
+ const std::string test_file_ = "test_file";
+ const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO;
+};
+
+TEST_F(MknodTest, RegularFile) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_);
+ const mode_t new_umask = 0077;
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out),
+ };
+ struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5);
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_MKNOD, iov_out);
+ TempUmask mask(new_umask);
+ ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0), SyscallSucceeds());
+
+ struct fuse_in_header in_header;
+ struct fuse_mknod_in in_payload;
+ std::vector<char> actual_file(test_file_.length() + 1);
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload, actual_file);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_header.len,
+ sizeof(in_header) + sizeof(in_payload) + test_file_.length() + 1);
+ EXPECT_EQ(in_header.opcode, FUSE_MKNOD);
+ EXPECT_EQ(in_payload.mode & 0777, perms_ & ~new_umask);
+ EXPECT_EQ(in_payload.umask, new_umask);
+ EXPECT_EQ(in_payload.rdev, 0);
+ EXPECT_EQ(std::string(actual_file.data()), test_file_);
+}
+
+TEST_F(MknodTest, FileTypeError) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out),
+ };
+ // server return directory instead of regular file should cause an error.
+ struct fuse_entry_out out_payload = DefaultEntryOut(S_IFDIR | perms_, 5);
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_MKNOD, iov_out);
+ ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0),
+ SyscallFailsWithErrno(EIO));
+ SkipServerActualRequest();
+}
+
+TEST_F(MknodTest, NodeIDError) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out),
+ };
+ struct fuse_entry_out out_payload =
+ DefaultEntryOut(S_IFREG | perms_, FUSE_ROOT_ID);
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_MKNOD, iov_out);
+ ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0),
+ SyscallFailsWithErrno(EIO));
+ SkipServerActualRequest();
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/open_test.cc b/test/fuse/linux/open_test.cc
new file mode 100644
index 000000000..4b0c4a805
--- /dev/null
+++ b/test/fuse/linux/open_test.cc
@@ -0,0 +1,128 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fuse_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class OpenTest : public FuseTest {
+ // OpenTest doesn't care the release request when close a fd,
+ // so doesn't check leftover requests when tearing down.
+ void TearDown() { UnmountFuse(); }
+
+ protected:
+ const std::string test_file_ = "test_file";
+ const mode_t regular_file_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO;
+
+ struct fuse_open_out out_payload_ = {
+ .fh = 1,
+ .open_flags = O_RDWR,
+ };
+};
+
+TEST_F(OpenTest, RegularFile) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_);
+ SetServerInodeLookup(test_file_, regular_file_);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out),
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload_);
+ SetServerResponse(FUSE_OPEN, iov_out);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR));
+
+ struct fuse_in_header in_header;
+ struct fuse_open_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_OPEN);
+ EXPECT_EQ(in_payload.flags, O_RDWR);
+ EXPECT_THAT(fcntl(fd.get(), F_GETFL), SyscallSucceedsWithValue(O_RDWR));
+}
+
+TEST_F(OpenTest, SetNoOpen) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_);
+ SetServerInodeLookup(test_file_, regular_file_);
+
+ // ENOSYS indicates open is not implemented.
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out),
+ .error = -ENOSYS,
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload_);
+ SetServerResponse(FUSE_OPEN, iov_out);
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR));
+ SkipServerActualRequest();
+
+ // check open doesn't send new request.
+ uint32_t recieved_before = GetServerTotalReceivedBytes();
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR));
+ EXPECT_EQ(GetServerTotalReceivedBytes(), recieved_before);
+}
+
+TEST_F(OpenTest, OpenFail) {
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out),
+ .error = -ENOENT,
+ };
+
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload_);
+ SetServerResponse(FUSE_OPENDIR, iov_out);
+ ASSERT_THAT(open(mount_point_.path().c_str(), O_RDWR),
+ SyscallFailsWithErrno(ENOENT));
+
+ struct fuse_in_header in_header;
+ struct fuse_open_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_OPENDIR);
+ EXPECT_EQ(in_payload.flags, O_RDWR);
+}
+
+TEST_F(OpenTest, DirectoryFlagOnRegularFile) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_);
+
+ SetServerInodeLookup(test_file_, regular_file_);
+ ASSERT_THAT(open(test_file_path.c_str(), O_RDWR | O_DIRECTORY),
+ SyscallFailsWithErrno(ENOTDIR));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/read_test.cc b/test/fuse/linux/read_test.cc
new file mode 100644
index 000000000..88fc299d8
--- /dev/null
+++ b/test/fuse/linux/read_test.cc
@@ -0,0 +1,390 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fuse_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class ReadTest : public FuseTest {
+ void SetUp() override {
+ FuseTest::SetUp();
+ test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_);
+ }
+
+ // TearDown overrides the parent's function
+ // to skip checking the unconsumed release request at the end.
+ void TearDown() override { UnmountFuse(); }
+
+ protected:
+ const std::string test_file_ = "test_file";
+ const mode_t test_file_mode_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO;
+ const uint64_t test_fh_ = 1;
+ const uint32_t open_flag_ = O_RDWR;
+
+ std::string test_file_path_;
+
+ PosixErrorOr<FileDescriptor> OpenTestFile(const std::string &path,
+ uint64_t size = 512) {
+ SetServerInodeLookup(test_file_, test_file_mode_, size);
+
+ struct fuse_out_header out_header_open = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out),
+ };
+ struct fuse_open_out out_payload_open = {
+ .fh = test_fh_,
+ .open_flags = open_flag_,
+ };
+ auto iov_out_open = FuseGenerateIovecs(out_header_open, out_payload_open);
+ SetServerResponse(FUSE_OPEN, iov_out_open);
+
+ auto res = Open(path.c_str(), open_flag_);
+ if (res.ok()) {
+ SkipServerActualRequest();
+ }
+ return res;
+ }
+};
+
+class ReadTestSmallMaxRead : public ReadTest {
+ void SetUp() override {
+ MountFuse(mountOpts);
+ SetUpFuseServer();
+ test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_);
+ }
+
+ protected:
+ constexpr static char mountOpts[] =
+ "rootmode=755,user_id=0,group_id=0,max_read=4096";
+ // 4096 is hard-coded as the max_read in mount options.
+ const int size_fragment = 4096;
+};
+
+TEST_F(ReadTest, ReadWhole) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_));
+
+ // Prepare for the read.
+ const int n_read = 5;
+ std::vector<char> data(n_read);
+ RandomizeBuffer(data.data(), data.size());
+ struct fuse_out_header out_header_read = {
+ .len =
+ static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()),
+ };
+ auto iov_out_read = FuseGenerateIovecs(out_header_read, data);
+ SetServerResponse(FUSE_READ, iov_out_read);
+
+ // Read the whole "file".
+ std::vector<char> buf(n_read);
+ EXPECT_THAT(read(fd.get(), buf.data(), n_read),
+ SyscallSucceedsWithValue(n_read));
+
+ // Check the read request.
+ struct fuse_in_header in_header_read;
+ struct fuse_read_in in_payload_read;
+ auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_payload_read.fh, test_fh_);
+ EXPECT_EQ(in_header_read.len,
+ sizeof(in_header_read) + sizeof(in_payload_read));
+ EXPECT_EQ(in_header_read.opcode, FUSE_READ);
+ EXPECT_EQ(in_payload_read.offset, 0);
+ EXPECT_EQ(buf, data);
+}
+
+TEST_F(ReadTest, ReadPartial) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_));
+
+ // Prepare for the read.
+ const int n_data = 10;
+ std::vector<char> data(n_data);
+ RandomizeBuffer(data.data(), data.size());
+ // Note: due to read ahead, current read implementation will treat any
+ // response that is longer than requested as correct (i.e. not reach the EOF).
+ // Therefore, the test below should make sure the size to read does not exceed
+ // n_data.
+ struct fuse_out_header out_header_read = {
+ .len =
+ static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()),
+ };
+ auto iov_out_read = FuseGenerateIovecs(out_header_read, data);
+ struct fuse_in_header in_header_read;
+ struct fuse_read_in in_payload_read;
+ auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read);
+
+ std::vector<char> buf(n_data);
+
+ // Read 1 bytes.
+ SetServerResponse(FUSE_READ, iov_out_read);
+ EXPECT_THAT(read(fd.get(), buf.data(), 1), SyscallSucceedsWithValue(1));
+
+ // Check the 1-byte read request.
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_payload_read.fh, test_fh_);
+ EXPECT_EQ(in_header_read.len,
+ sizeof(in_header_read) + sizeof(in_payload_read));
+ EXPECT_EQ(in_header_read.opcode, FUSE_READ);
+ EXPECT_EQ(in_payload_read.offset, 0);
+
+ // Read 3 bytes.
+ SetServerResponse(FUSE_READ, iov_out_read);
+ EXPECT_THAT(read(fd.get(), buf.data(), 3), SyscallSucceedsWithValue(3));
+
+ // Check the 3-byte read request.
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_payload_read.fh, test_fh_);
+ EXPECT_EQ(in_payload_read.offset, 1);
+
+ // Read 5 bytes.
+ SetServerResponse(FUSE_READ, iov_out_read);
+ EXPECT_THAT(read(fd.get(), buf.data(), 5), SyscallSucceedsWithValue(5));
+
+ // Check the 5-byte read request.
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_payload_read.fh, test_fh_);
+ EXPECT_EQ(in_payload_read.offset, 4);
+}
+
+TEST_F(ReadTest, PRead) {
+ const int file_size = 512;
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, file_size));
+
+ // Prepare for the read.
+ const int n_read = 5;
+ std::vector<char> data(n_read);
+ RandomizeBuffer(data.data(), data.size());
+ struct fuse_out_header out_header_read = {
+ .len =
+ static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()),
+ };
+ auto iov_out_read = FuseGenerateIovecs(out_header_read, data);
+ SetServerResponse(FUSE_READ, iov_out_read);
+
+ // Read some bytes.
+ std::vector<char> buf(n_read);
+ const int offset_read = file_size >> 1;
+ EXPECT_THAT(pread(fd.get(), buf.data(), n_read, offset_read),
+ SyscallSucceedsWithValue(n_read));
+
+ // Check the read request.
+ struct fuse_in_header in_header_read;
+ struct fuse_read_in in_payload_read;
+ auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_payload_read.fh, test_fh_);
+ EXPECT_EQ(in_header_read.len,
+ sizeof(in_header_read) + sizeof(in_payload_read));
+ EXPECT_EQ(in_header_read.opcode, FUSE_READ);
+ EXPECT_EQ(in_payload_read.offset, offset_read);
+ EXPECT_EQ(buf, data);
+}
+
+TEST_F(ReadTest, ReadZero) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_));
+
+ // Issue the read.
+ std::vector<char> buf;
+ EXPECT_THAT(read(fd.get(), buf.data(), 0), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(ReadTest, ReadShort) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_));
+
+ // Prepare for the short read.
+ const int n_read = 5;
+ std::vector<char> data(n_read >> 1);
+ RandomizeBuffer(data.data(), data.size());
+ struct fuse_out_header out_header_read = {
+ .len =
+ static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()),
+ };
+ auto iov_out_read = FuseGenerateIovecs(out_header_read, data);
+ SetServerResponse(FUSE_READ, iov_out_read);
+
+ // Read the whole "file".
+ std::vector<char> buf(n_read);
+ EXPECT_THAT(read(fd.get(), buf.data(), n_read),
+ SyscallSucceedsWithValue(data.size()));
+
+ // Check the read request.
+ struct fuse_in_header in_header_read;
+ struct fuse_read_in in_payload_read;
+ auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_payload_read.fh, test_fh_);
+ EXPECT_EQ(in_header_read.len,
+ sizeof(in_header_read) + sizeof(in_payload_read));
+ EXPECT_EQ(in_header_read.opcode, FUSE_READ);
+ EXPECT_EQ(in_payload_read.offset, 0);
+ std::vector<char> short_buf(buf.begin(), buf.begin() + data.size());
+ EXPECT_EQ(short_buf, data);
+}
+
+TEST_F(ReadTest, ReadShortEOF) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_));
+
+ // Prepare for the short read.
+ struct fuse_out_header out_header_read = {
+ .len = static_cast<uint32_t>(sizeof(struct fuse_out_header)),
+ };
+ auto iov_out_read = FuseGenerateIovecs(out_header_read);
+ SetServerResponse(FUSE_READ, iov_out_read);
+
+ // Read the whole "file".
+ const int n_read = 10;
+ std::vector<char> buf(n_read);
+ EXPECT_THAT(read(fd.get(), buf.data(), n_read), SyscallSucceedsWithValue(0));
+
+ // Check the read request.
+ struct fuse_in_header in_header_read;
+ struct fuse_read_in in_payload_read;
+ auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_payload_read.fh, test_fh_);
+ EXPECT_EQ(in_header_read.len,
+ sizeof(in_header_read) + sizeof(in_payload_read));
+ EXPECT_EQ(in_header_read.opcode, FUSE_READ);
+ EXPECT_EQ(in_payload_read.offset, 0);
+}
+
+TEST_F(ReadTestSmallMaxRead, ReadSmallMaxRead) {
+ const int n_fragment = 10;
+ const int n_read = size_fragment * n_fragment;
+
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_read));
+
+ // Prepare for the read.
+ std::vector<char> data(size_fragment);
+ RandomizeBuffer(data.data(), data.size());
+ struct fuse_out_header out_header_read = {
+ .len =
+ static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()),
+ };
+ auto iov_out_read = FuseGenerateIovecs(out_header_read, data);
+
+ for (int i = 0; i < n_fragment; ++i) {
+ SetServerResponse(FUSE_READ, iov_out_read);
+ }
+
+ // Read the whole "file".
+ std::vector<char> buf(n_read);
+ EXPECT_THAT(read(fd.get(), buf.data(), n_read),
+ SyscallSucceedsWithValue(n_read));
+
+ ASSERT_EQ(GetServerNumUnsentResponses(), 0);
+ ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment);
+
+ // Check each read segment.
+ struct fuse_in_header in_header_read;
+ struct fuse_read_in in_payload_read;
+ auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read);
+
+ for (int i = 0; i < n_fragment; ++i) {
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_payload_read.fh, test_fh_);
+ EXPECT_EQ(in_header_read.len,
+ sizeof(in_header_read) + sizeof(in_payload_read));
+ EXPECT_EQ(in_header_read.opcode, FUSE_READ);
+ EXPECT_EQ(in_payload_read.offset, i * size_fragment);
+ EXPECT_EQ(in_payload_read.size, size_fragment);
+
+ auto it = buf.begin() + i * size_fragment;
+ EXPECT_EQ(std::vector<char>(it, it + size_fragment), data);
+ }
+}
+
+TEST_F(ReadTestSmallMaxRead, ReadSmallMaxReadShort) {
+ const int n_fragment = 10;
+ const int n_read = size_fragment * n_fragment;
+
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_read));
+
+ // Prepare for the read.
+ std::vector<char> data(size_fragment);
+ RandomizeBuffer(data.data(), data.size());
+ struct fuse_out_header out_header_read = {
+ .len =
+ static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()),
+ };
+ auto iov_out_read = FuseGenerateIovecs(out_header_read, data);
+
+ for (int i = 0; i < n_fragment - 1; ++i) {
+ SetServerResponse(FUSE_READ, iov_out_read);
+ }
+
+ // The last fragment is a short read.
+ std::vector<char> half_data(data.begin(), data.begin() + (data.size() >> 1));
+ struct fuse_out_header out_header_read_short = {
+ .len = static_cast<uint32_t>(sizeof(struct fuse_out_header) +
+ half_data.size()),
+ };
+ auto iov_out_read_short =
+ FuseGenerateIovecs(out_header_read_short, half_data);
+ SetServerResponse(FUSE_READ, iov_out_read_short);
+
+ // Read the whole "file".
+ std::vector<char> buf(n_read);
+ EXPECT_THAT(read(fd.get(), buf.data(), n_read),
+ SyscallSucceedsWithValue(n_read - (data.size() >> 1)));
+
+ ASSERT_EQ(GetServerNumUnsentResponses(), 0);
+ ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment);
+
+ // Check each read segment.
+ struct fuse_in_header in_header_read;
+ struct fuse_read_in in_payload_read;
+ auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read);
+
+ for (int i = 0; i < n_fragment; ++i) {
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_payload_read.fh, test_fh_);
+ EXPECT_EQ(in_header_read.len,
+ sizeof(in_header_read) + sizeof(in_payload_read));
+ EXPECT_EQ(in_header_read.opcode, FUSE_READ);
+ EXPECT_EQ(in_payload_read.offset, i * size_fragment);
+ EXPECT_EQ(in_payload_read.size, size_fragment);
+
+ auto it = buf.begin() + i * size_fragment;
+ if (i != n_fragment - 1) {
+ EXPECT_EQ(std::vector<char>(it, it + data.size()), data);
+ } else {
+ EXPECT_EQ(std::vector<char>(it, it + half_data.size()), half_data);
+ }
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/readdir_test.cc b/test/fuse/linux/readdir_test.cc
new file mode 100644
index 000000000..2afb4b062
--- /dev/null
+++ b/test/fuse/linux/readdir_test.cc
@@ -0,0 +1,193 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <dirent.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <linux/unistd.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fuse_util.h"
+#include "test/util/test_util.h"
+
+#define FUSE_NAME_OFFSET offsetof(struct fuse_dirent, name)
+#define FUSE_DIRENT_ALIGN(x) \
+ (((x) + sizeof(uint64_t) - 1) & ~(sizeof(uint64_t) - 1))
+#define FUSE_DIRENT_SIZE(d) FUSE_DIRENT_ALIGN(FUSE_NAME_OFFSET + (d)->namelen)
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class ReaddirTest : public FuseTest {
+ public:
+ void fill_fuse_dirent(char *buf, const char *name, uint64_t ino) {
+ size_t namelen = strlen(name);
+ size_t entlen = FUSE_NAME_OFFSET + namelen;
+ size_t entlen_padded = FUSE_DIRENT_ALIGN(entlen);
+ struct fuse_dirent *dirent;
+
+ dirent = reinterpret_cast<struct fuse_dirent *>(buf);
+ dirent->ino = ino;
+ dirent->namelen = namelen;
+ memcpy(dirent->name, name, namelen);
+ memset(dirent->name + namelen, 0, entlen_padded - entlen);
+ }
+
+ protected:
+ const std::string test_dir_name_ = "test_dir";
+};
+
+TEST_F(ReaddirTest, SingleEntry) {
+ const std::string test_dir_path =
+ JoinPath(mount_point_.path().c_str(), test_dir_name_);
+
+ const uint64_t ino_dir = 1024;
+ // We need to make sure the test dir is a directory that can be found.
+ mode_t expected_mode =
+ S_IFDIR | S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH;
+ struct fuse_attr dir_attr = {
+ .ino = ino_dir,
+ .size = 512,
+ .blocks = 4,
+ .mode = expected_mode,
+ .blksize = 4096,
+ };
+
+ // We need to make sure the test dir is a directory that can be found.
+ struct fuse_out_header lookup_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out),
+ };
+ struct fuse_entry_out lookup_payload = {
+ .nodeid = 1,
+ .entry_valid = true,
+ .attr_valid = true,
+ .attr = dir_attr,
+ };
+
+ struct fuse_out_header open_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out),
+ };
+ struct fuse_open_out open_payload = {
+ .fh = 1,
+ };
+ auto iov_out = FuseGenerateIovecs(lookup_header, lookup_payload);
+ SetServerResponse(FUSE_LOOKUP, iov_out);
+
+ iov_out = FuseGenerateIovecs(open_header, open_payload);
+ SetServerResponse(FUSE_OPENDIR, iov_out);
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_dir_path.c_str(), O_RDONLY));
+
+ // The open command makes two syscalls. Lookup the dir file and open.
+ // We don't need to inspect those headers in this test.
+ SkipServerActualRequest(); // LOOKUP.
+ SkipServerActualRequest(); // OPENDIR.
+
+ // Readdir test code.
+ std::string dot = ".";
+ std::string dot_dot = "..";
+ std::string test_file = "testFile";
+
+ // Figure out how many dirents to send over and allocate them appropriately.
+ // Each dirent has a dynamic name and a static metadata part. The dirent size
+ // is aligned to being a multiple of 8.
+ size_t dot_file_dirent_size =
+ FUSE_DIRENT_ALIGN(dot.length() + FUSE_NAME_OFFSET);
+ size_t dot_dot_file_dirent_size =
+ FUSE_DIRENT_ALIGN(dot_dot.length() + FUSE_NAME_OFFSET);
+ size_t test_file_dirent_size =
+ FUSE_DIRENT_ALIGN(test_file.length() + FUSE_NAME_OFFSET);
+
+ // Create an appropriately sized payload.
+ size_t readdir_payload_size =
+ test_file_dirent_size + dot_file_dirent_size + dot_dot_file_dirent_size;
+ std::vector<char> readdir_payload_vec(readdir_payload_size);
+ char *readdir_payload = readdir_payload_vec.data();
+
+ // Use fake ino for other directories.
+ fill_fuse_dirent(readdir_payload, dot.c_str(), ino_dir - 2);
+ fill_fuse_dirent(readdir_payload + dot_file_dirent_size, dot_dot.c_str(),
+ ino_dir - 1);
+ fill_fuse_dirent(
+ readdir_payload + dot_file_dirent_size + dot_dot_file_dirent_size,
+ test_file.c_str(), ino_dir);
+
+ struct fuse_out_header readdir_header = {
+ .len = uint32_t(sizeof(struct fuse_out_header) + readdir_payload_size),
+ };
+ struct fuse_out_header readdir_header_break = {
+ .len = uint32_t(sizeof(struct fuse_out_header)),
+ };
+
+ iov_out = FuseGenerateIovecs(readdir_header, readdir_payload_vec);
+ SetServerResponse(FUSE_READDIR, iov_out);
+
+ iov_out = FuseGenerateIovecs(readdir_header_break);
+ SetServerResponse(FUSE_READDIR, iov_out);
+
+ std::vector<char> buf(4090, 0);
+ int nread, off = 0, i = 0;
+ EXPECT_THAT(
+ nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()),
+ SyscallSucceeds());
+ for (; off < nread;) {
+ struct dirent64 *ent = (struct dirent64 *)(buf.data() + off);
+ off += ent->d_reclen;
+ switch (i++) {
+ case 0:
+ EXPECT_EQ(std::string(ent->d_name), dot);
+ break;
+ case 1:
+ EXPECT_EQ(std::string(ent->d_name), dot_dot);
+ break;
+ case 2:
+ EXPECT_EQ(std::string(ent->d_name), test_file);
+ break;
+ }
+ }
+
+ EXPECT_THAT(
+ nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(0));
+
+ SkipServerActualRequest(); // READDIR.
+ SkipServerActualRequest(); // READDIR with no data.
+
+ // Clean up.
+ fd.reset(-1);
+
+ struct fuse_in_header in_header;
+ struct fuse_release_in in_payload;
+
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_RELEASEDIR);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/readlink_test.cc b/test/fuse/linux/readlink_test.cc
new file mode 100644
index 000000000..2cba8fc23
--- /dev/null
+++ b/test/fuse/linux/readlink_test.cc
@@ -0,0 +1,85 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fuse_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class ReadlinkTest : public FuseTest {
+ protected:
+ const std::string test_file_ = "test_file_";
+ const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO;
+};
+
+TEST_F(ReadlinkTest, ReadSymLink) {
+ const std::string symlink_path =
+ JoinPath(mount_point_.path().c_str(), test_file_);
+ SetServerInodeLookup(test_file_, S_IFLNK | perms_);
+
+ struct fuse_out_header out_header = {
+ .len = static_cast<uint32_t>(sizeof(struct fuse_out_header)) +
+ static_cast<uint32_t>(test_file_.length()) + 1,
+ };
+ std::string link = test_file_;
+ auto iov_out = FuseGenerateIovecs(out_header, link);
+ SetServerResponse(FUSE_READLINK, iov_out);
+ const std::string actual_link =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadLink(symlink_path));
+
+ struct fuse_in_header in_header;
+ auto iov_in = FuseGenerateIovecs(in_header);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_header.len, sizeof(in_header));
+ EXPECT_EQ(in_header.opcode, FUSE_READLINK);
+ EXPECT_EQ(0, memcmp(actual_link.c_str(), link.data(), link.size()));
+
+ // next readlink should have link cached, so shouldn't have new request to
+ // server.
+ uint32_t recieved_before = GetServerTotalReceivedBytes();
+ ASSERT_NO_ERRNO(ReadLink(symlink_path));
+ EXPECT_EQ(GetServerTotalReceivedBytes(), recieved_before);
+}
+
+TEST_F(ReadlinkTest, NotSymlink) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_);
+ SetServerInodeLookup(test_file_, S_IFREG | perms_);
+
+ std::vector<char> buf(PATH_MAX + 1);
+ ASSERT_THAT(readlink(test_file_path.c_str(), buf.data(), PATH_MAX),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/release_test.cc b/test/fuse/linux/release_test.cc
new file mode 100644
index 000000000..b5adb0870
--- /dev/null
+++ b/test/fuse/linux/release_test.cc
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/mount.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fuse_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class ReleaseTest : public FuseTest {
+ protected:
+ const std::string test_file_ = "test_file";
+};
+
+TEST_F(ReleaseTest, RegularFile) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_);
+ SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out),
+ };
+ struct fuse_open_out out_payload = {
+ .fh = 1,
+ .open_flags = O_RDWR,
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_OPEN, iov_out);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path, O_RDWR));
+ SkipServerActualRequest();
+ ASSERT_THAT(close(fd.release()), SyscallSucceeds());
+
+ struct fuse_in_header in_header;
+ struct fuse_release_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_RELEASE);
+ EXPECT_EQ(in_payload.flags, O_RDWR);
+ EXPECT_EQ(in_payload.fh, 1);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/rmdir_test.cc b/test/fuse/linux/rmdir_test.cc
new file mode 100644
index 000000000..e3200e446
--- /dev/null
+++ b/test/fuse/linux/rmdir_test.cc
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <sys/uio.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fs_util.h"
+#include "test/util/fuse_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class RmDirTest : public FuseTest {
+ protected:
+ const std::string test_dir_name_ = "test_dir";
+ const std::string test_subdir_ = "test_subdir";
+ const mode_t test_dir_mode_ = S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO;
+};
+
+TEST_F(RmDirTest, NormalRmDir) {
+ const std::string test_dir_path_ =
+ JoinPath(mount_point_.path().c_str(), test_dir_name_);
+
+ SetServerInodeLookup(test_dir_name_, test_dir_mode_);
+
+ // RmDir code.
+ struct fuse_out_header rmdir_header = {
+ .len = sizeof(struct fuse_out_header),
+ };
+
+ auto iov_out = FuseGenerateIovecs(rmdir_header);
+ SetServerResponse(FUSE_RMDIR, iov_out);
+
+ ASSERT_THAT(rmdir(test_dir_path_.c_str()), SyscallSucceeds());
+
+ struct fuse_in_header in_header;
+ std::vector<char> actual_dirname(test_dir_name_.length() + 1);
+ auto iov_in = FuseGenerateIovecs(in_header, actual_dirname);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_header.len, sizeof(in_header) + test_dir_name_.length() + 1);
+ EXPECT_EQ(in_header.opcode, FUSE_RMDIR);
+ EXPECT_EQ(std::string(actual_dirname.data()), test_dir_name_);
+}
+
+TEST_F(RmDirTest, NormalRmDirSubdir) {
+ SetServerInodeLookup(test_subdir_, S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO);
+ const std::string test_dir_path_ =
+ JoinPath(mount_point_.path().c_str(), test_subdir_, test_dir_name_);
+ SetServerInodeLookup(test_dir_name_, test_dir_mode_);
+
+ // RmDir code.
+ struct fuse_out_header rmdir_header = {
+ .len = sizeof(struct fuse_out_header),
+ };
+
+ auto iov_out = FuseGenerateIovecs(rmdir_header);
+ SetServerResponse(FUSE_RMDIR, iov_out);
+
+ ASSERT_THAT(rmdir(test_dir_path_.c_str()), SyscallSucceeds());
+
+ struct fuse_in_header in_header;
+ std::vector<char> actual_dirname(test_dir_name_.length() + 1);
+ auto iov_in = FuseGenerateIovecs(in_header, actual_dirname);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_header.len, sizeof(in_header) + test_dir_name_.length() + 1);
+ EXPECT_EQ(in_header.opcode, FUSE_RMDIR);
+ EXPECT_EQ(std::string(actual_dirname.data()), test_dir_name_);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/setstat_test.cc b/test/fuse/linux/setstat_test.cc
new file mode 100644
index 000000000..68301c775
--- /dev/null
+++ b/test/fuse/linux/setstat_test.cc
@@ -0,0 +1,338 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <sys/uio.h>
+#include <unistd.h>
+#include <utime.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_fd_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/fs_util.h"
+#include "test/util/fuse_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class SetStatTest : public FuseFdTest {
+ public:
+ void SetUp() override {
+ FuseFdTest::SetUp();
+ test_dir_path_ = JoinPath(mount_point_.path(), test_dir_);
+ test_file_path_ = JoinPath(mount_point_.path(), test_file_);
+ }
+
+ protected:
+ const uint64_t fh = 23;
+ const std::string test_dir_ = "testdir";
+ const std::string test_file_ = "testfile";
+ const mode_t test_dir_mode_ = S_IFDIR | S_IRUSR | S_IWUSR | S_IXUSR;
+ const mode_t test_file_mode_ = S_IFREG | S_IRUSR | S_IWUSR | S_IXUSR;
+
+ std::string test_dir_path_;
+ std::string test_file_path_;
+};
+
+TEST_F(SetStatTest, ChmodDir) {
+ // Set up fixture.
+ SetServerInodeLookup(test_dir_, test_dir_mode_);
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
+ .error = 0,
+ };
+ mode_t set_mode = S_IRGRP | S_IWGRP | S_IXGRP;
+ struct fuse_attr_out out_payload = {
+ .attr = DefaultFuseAttr(set_mode, 2),
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_SETATTR, iov_out);
+
+ // Make syscall.
+ EXPECT_THAT(chmod(test_dir_path_.c_str(), set_mode), SyscallSucceeds());
+
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_setattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_SETATTR);
+ EXPECT_EQ(in_header.uid, 0);
+ EXPECT_EQ(in_header.gid, 0);
+ EXPECT_EQ(in_payload.valid, FATTR_MODE);
+ EXPECT_EQ(in_payload.mode, S_IFDIR | set_mode);
+}
+
+TEST_F(SetStatTest, ChownDir) {
+ // Set up fixture.
+ SetServerInodeLookup(test_dir_, test_dir_mode_);
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
+ .error = 0,
+ };
+ struct fuse_attr_out out_payload = {
+ .attr = DefaultFuseAttr(test_dir_mode_, 2),
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_SETATTR, iov_out);
+
+ // Make syscall.
+ EXPECT_THAT(chown(test_dir_path_.c_str(), 1025, 1025), SyscallSucceeds());
+
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_setattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_SETATTR);
+ EXPECT_EQ(in_header.uid, 0);
+ EXPECT_EQ(in_header.gid, 0);
+ EXPECT_EQ(in_payload.valid, FATTR_UID | FATTR_GID);
+ EXPECT_EQ(in_payload.uid, 1025);
+ EXPECT_EQ(in_payload.gid, 1025);
+}
+
+TEST_F(SetStatTest, TruncateFile) {
+ // Set up fixture.
+ SetServerInodeLookup(test_file_, test_file_mode_);
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
+ .error = 0,
+ };
+ struct fuse_attr_out out_payload = {
+ .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR, 2),
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_SETATTR, iov_out);
+
+ // Make syscall.
+ EXPECT_THAT(truncate(test_file_path_.c_str(), 321), SyscallSucceeds());
+
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_setattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_SETATTR);
+ EXPECT_EQ(in_header.uid, 0);
+ EXPECT_EQ(in_header.gid, 0);
+ EXPECT_EQ(in_payload.valid, FATTR_SIZE);
+ EXPECT_EQ(in_payload.size, 321);
+}
+
+TEST_F(SetStatTest, UtimeFile) {
+ // Set up fixture.
+ SetServerInodeLookup(test_file_, test_file_mode_);
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
+ .error = 0,
+ };
+ struct fuse_attr_out out_payload = {
+ .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR, 2),
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_SETATTR, iov_out);
+
+ // Make syscall.
+ time_t expected_atime = 1597159766, expected_mtime = 1597159765;
+ struct utimbuf times = {
+ .actime = expected_atime,
+ .modtime = expected_mtime,
+ };
+ EXPECT_THAT(utime(test_file_path_.c_str(), &times), SyscallSucceeds());
+
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_setattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_SETATTR);
+ EXPECT_EQ(in_header.uid, 0);
+ EXPECT_EQ(in_header.gid, 0);
+ EXPECT_EQ(in_payload.valid, FATTR_ATIME | FATTR_MTIME);
+ EXPECT_EQ(in_payload.atime, expected_atime);
+ EXPECT_EQ(in_payload.mtime, expected_mtime);
+}
+
+TEST_F(SetStatTest, UtimesFile) {
+ // Set up fixture.
+ SetServerInodeLookup(test_file_, test_file_mode_);
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
+ .error = 0,
+ };
+ struct fuse_attr_out out_payload = {
+ .attr = DefaultFuseAttr(test_file_mode_, 2),
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_SETATTR, iov_out);
+
+ // Make syscall.
+ struct timeval expected_times[2] = {
+ {
+ .tv_sec = 1597159766,
+ .tv_usec = 234945,
+ },
+ {
+ .tv_sec = 1597159765,
+ .tv_usec = 232341,
+ },
+ };
+ EXPECT_THAT(utimes(test_file_path_.c_str(), expected_times),
+ SyscallSucceeds());
+
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_setattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_SETATTR);
+ EXPECT_EQ(in_header.uid, 0);
+ EXPECT_EQ(in_header.gid, 0);
+ EXPECT_EQ(in_payload.valid, FATTR_ATIME | FATTR_MTIME);
+ EXPECT_EQ(in_payload.atime, expected_times[0].tv_sec);
+ EXPECT_EQ(in_payload.atimensec, expected_times[0].tv_usec * 1000);
+ EXPECT_EQ(in_payload.mtime, expected_times[1].tv_sec);
+ EXPECT_EQ(in_payload.mtimensec, expected_times[1].tv_usec * 1000);
+}
+
+TEST_F(SetStatTest, FtruncateFile) {
+ // Set up fixture.
+ SetServerInodeLookup(test_file_, test_file_mode_);
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh));
+ auto close_fd = CloseFD(fd);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
+ .error = 0,
+ };
+ struct fuse_attr_out out_payload = {
+ .attr = DefaultFuseAttr(test_file_mode_, 2),
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_SETATTR, iov_out);
+
+ // Make syscall.
+ EXPECT_THAT(ftruncate(fd.get(), 321), SyscallSucceeds());
+
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_setattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_SETATTR);
+ EXPECT_EQ(in_header.uid, 0);
+ EXPECT_EQ(in_header.gid, 0);
+ EXPECT_EQ(in_payload.valid, FATTR_SIZE | FATTR_FH);
+ EXPECT_EQ(in_payload.fh, fh);
+ EXPECT_EQ(in_payload.size, 321);
+}
+
+TEST_F(SetStatTest, FchmodFile) {
+ // Set up fixture.
+ SetServerInodeLookup(test_file_, test_file_mode_);
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh));
+ auto close_fd = CloseFD(fd);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
+ .error = 0,
+ };
+ mode_t set_mode = S_IROTH | S_IWOTH | S_IXOTH;
+ struct fuse_attr_out out_payload = {
+ .attr = DefaultFuseAttr(set_mode, 2),
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_SETATTR, iov_out);
+
+ // Make syscall.
+ EXPECT_THAT(fchmod(fd.get(), set_mode), SyscallSucceeds());
+
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_setattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_SETATTR);
+ EXPECT_EQ(in_header.uid, 0);
+ EXPECT_EQ(in_header.gid, 0);
+ EXPECT_EQ(in_payload.valid, FATTR_MODE | FATTR_FH);
+ EXPECT_EQ(in_payload.fh, fh);
+ EXPECT_EQ(in_payload.mode, S_IFREG | set_mode);
+}
+
+TEST_F(SetStatTest, FchownFile) {
+ // Set up fixture.
+ SetServerInodeLookup(test_file_, test_file_mode_);
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh));
+ auto close_fd = CloseFD(fd);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
+ .error = 0,
+ };
+ struct fuse_attr_out out_payload = {
+ .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR | S_IXUSR, 2),
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_SETATTR, iov_out);
+
+ // Make syscall.
+ EXPECT_THAT(fchown(fd.get(), 1025, 1025), SyscallSucceeds());
+
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_setattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload));
+ EXPECT_EQ(in_header.opcode, FUSE_SETATTR);
+ EXPECT_EQ(in_header.uid, 0);
+ EXPECT_EQ(in_header.gid, 0);
+ EXPECT_EQ(in_payload.valid, FATTR_UID | FATTR_GID | FATTR_FH);
+ EXPECT_EQ(in_payload.fh, fh);
+ EXPECT_EQ(in_payload.uid, 1025);
+ EXPECT_EQ(in_payload.gid, 1025);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/stat_test.cc b/test/fuse/linux/stat_test.cc
index 172e09867..6f032cac1 100644
--- a/test/fuse/linux/stat_test.cc
+++ b/test/fuse/linux/stat_test.cc
@@ -18,12 +18,16 @@
#include <sys/stat.h>
#include <sys/statfs.h>
#include <sys/types.h>
+#include <sys/uio.h>
#include <unistd.h>
#include <vector>
#include "gtest/gtest.h"
-#include "test/fuse/linux/fuse_base.h"
+#include "test/fuse/linux/fuse_fd_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/fs_util.h"
+#include "test/util/fuse_util.h"
#include "test/util/test_util.h"
namespace gvisor {
@@ -31,89 +35,45 @@ namespace testing {
namespace {
-class StatTest : public FuseTest {
+class StatTest : public FuseFdTest {
public:
- bool CompareRequest(void* expected_mem, size_t expected_len, void* real_mem,
- size_t real_len) override {
- if (expected_len != real_len) return false;
- struct fuse_in_header* real_header =
- reinterpret_cast<fuse_in_header*>(real_mem);
-
- if (real_header->opcode != FUSE_GETATTR) {
- std::cerr << "expect header opcode " << FUSE_GETATTR << " but got "
- << real_header->opcode << std::endl;
- return false;
- }
- return true;
+ void SetUp() override {
+ FuseFdTest::SetUp();
+ test_file_path_ = JoinPath(mount_point_.path(), test_file_);
}
+ protected:
bool StatsAreEqual(struct stat expected, struct stat actual) {
- // device number will be dynamically allocated by kernel, we cannot know
- // in advance
+ // Device number will be dynamically allocated by kernel, we cannot know in
+ // advance.
actual.st_dev = expected.st_dev;
return memcmp(&expected, &actual, sizeof(struct stat)) == 0;
}
+
+ const std::string test_file_ = "testfile";
+ const mode_t expected_mode = S_IFREG | S_IRUSR | S_IWUSR;
+ const uint64_t fh = 23;
+
+ std::string test_file_path_;
};
TEST_F(StatTest, StatNormal) {
- struct iovec iov_in[2];
- struct iovec iov_out[2];
-
- struct fuse_in_header in_header = {
- .len = sizeof(struct fuse_in_header) + sizeof(struct fuse_getattr_in),
- .opcode = FUSE_GETATTR,
- .unique = 4,
- .nodeid = 1,
- .uid = 0,
- .gid = 0,
- .pid = 4,
- .padding = 0,
- };
- struct fuse_getattr_in in_payload = {0};
- iov_in[0].iov_len = sizeof(in_header);
- iov_in[0].iov_base = &in_header;
- iov_in[1].iov_len = sizeof(in_payload);
- iov_in[1].iov_base = &in_payload;
-
- mode_t expected_mode = S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH;
- struct timespec atime = {.tv_sec = 1595436289, .tv_nsec = 134150844};
- struct timespec mtime = {.tv_sec = 1595436290, .tv_nsec = 134150845};
- struct timespec ctime = {.tv_sec = 1595436291, .tv_nsec = 134150846};
+ // Set up fixture.
+ struct fuse_attr attr = DefaultFuseAttr(expected_mode, 1);
struct fuse_out_header out_header = {
.len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
- .error = 0,
- .unique = 4,
- };
- struct fuse_attr attr = {
- .ino = 1,
- .size = 512,
- .blocks = 4,
- .atime = static_cast<uint64_t>(atime.tv_sec),
- .mtime = static_cast<uint64_t>(mtime.tv_sec),
- .ctime = static_cast<uint64_t>(ctime.tv_sec),
- .atimensec = static_cast<uint32_t>(atime.tv_nsec),
- .mtimensec = static_cast<uint32_t>(mtime.tv_nsec),
- .ctimensec = static_cast<uint32_t>(ctime.tv_nsec),
- .mode = expected_mode,
- .nlink = 2,
- .uid = 1234,
- .gid = 4321,
- .rdev = 12,
- .blksize = 4096,
};
struct fuse_attr_out out_payload = {
.attr = attr,
};
- iov_out[0].iov_len = sizeof(out_header);
- iov_out[0].iov_base = &out_header;
- iov_out[1].iov_len = sizeof(out_payload);
- iov_out[1].iov_base = &out_payload;
-
- SetExpected(iov_in, 2, iov_out, 2);
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_GETATTR, iov_out);
+ // Make syscall.
struct stat stat_buf;
EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf), SyscallSucceeds());
+ // Check filesystem operation result.
struct stat expected_stat = {
.st_ino = attr.ino,
.st_nlink = attr.nlink,
@@ -124,43 +84,133 @@ TEST_F(StatTest, StatNormal) {
.st_size = static_cast<off_t>(attr.size),
.st_blksize = attr.blksize,
.st_blocks = static_cast<blkcnt_t>(attr.blocks),
- .st_atim = atime,
- .st_mtim = mtime,
- .st_ctim = ctime,
+ .st_atim = (struct timespec){.tv_sec = static_cast<int>(attr.atime),
+ .tv_nsec = attr.atimensec},
+ .st_mtim = (struct timespec){.tv_sec = static_cast<int>(attr.mtime),
+ .tv_nsec = attr.mtimensec},
+ .st_ctim = (struct timespec){.tv_sec = static_cast<int>(attr.ctime),
+ .tv_nsec = attr.ctimensec},
};
EXPECT_TRUE(StatsAreEqual(stat_buf, expected_stat));
- WaitCompleted();
-}
-TEST_F(StatTest, StatNotFound) {
- struct iovec iov_in[2];
- struct iovec iov_out[2];
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_getattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
- struct fuse_in_header in_header = {
- .len = sizeof(struct fuse_in_header) + sizeof(struct fuse_getattr_in),
- .opcode = FUSE_GETATTR,
- .unique = 4,
- };
- struct fuse_getattr_in in_payload = {0};
- iov_in[0].iov_len = sizeof(in_header);
- iov_in[0].iov_base = &in_header;
- iov_in[1].iov_len = sizeof(in_payload);
- iov_in[1].iov_base = &in_payload;
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.opcode, FUSE_GETATTR);
+ EXPECT_EQ(in_payload.getattr_flags, 0);
+ EXPECT_EQ(in_payload.fh, 0);
+}
+TEST_F(StatTest, StatNotFound) {
+ // Set up fixture.
struct fuse_out_header out_header = {
.len = sizeof(struct fuse_out_header),
.error = -ENOENT,
- .unique = 4,
};
- iov_out[0].iov_len = sizeof(out_header);
- iov_out[0].iov_base = &out_header;
-
- SetExpected(iov_in, 2, iov_out, 1);
+ auto iov_out = FuseGenerateIovecs(out_header);
+ SetServerResponse(FUSE_GETATTR, iov_out);
+ // Make syscall.
struct stat stat_buf;
EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf),
SyscallFailsWithErrno(ENOENT));
- WaitCompleted();
+
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_getattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.opcode, FUSE_GETATTR);
+ EXPECT_EQ(in_payload.getattr_flags, 0);
+ EXPECT_EQ(in_payload.fh, 0);
+}
+
+TEST_F(StatTest, FstatNormal) {
+ // Set up fixture.
+ SetServerInodeLookup(test_file_);
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDONLY, fh));
+ auto close_fd = CloseFD(fd);
+
+ struct fuse_attr attr = DefaultFuseAttr(expected_mode, 2);
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
+ };
+ struct fuse_attr_out out_payload = {
+ .attr = attr,
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_GETATTR, iov_out);
+
+ // Make syscall.
+ struct stat stat_buf;
+ EXPECT_THAT(fstat(fd.get(), &stat_buf), SyscallSucceeds());
+
+ // Check filesystem operation result.
+ struct stat expected_stat = {
+ .st_ino = attr.ino,
+ .st_nlink = attr.nlink,
+ .st_mode = expected_mode,
+ .st_uid = attr.uid,
+ .st_gid = attr.gid,
+ .st_rdev = attr.rdev,
+ .st_size = static_cast<off_t>(attr.size),
+ .st_blksize = attr.blksize,
+ .st_blocks = static_cast<blkcnt_t>(attr.blocks),
+ .st_atim = (struct timespec){.tv_sec = static_cast<int>(attr.atime),
+ .tv_nsec = attr.atimensec},
+ .st_mtim = (struct timespec){.tv_sec = static_cast<int>(attr.mtime),
+ .tv_nsec = attr.mtimensec},
+ .st_ctim = (struct timespec){.tv_sec = static_cast<int>(attr.ctime),
+ .tv_nsec = attr.ctimensec},
+ };
+ EXPECT_TRUE(StatsAreEqual(stat_buf, expected_stat));
+
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_getattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.opcode, FUSE_GETATTR);
+ EXPECT_EQ(in_payload.getattr_flags, 0);
+ EXPECT_EQ(in_payload.fh, 0);
+}
+
+TEST_F(StatTest, StatByFileHandle) {
+ // Set up fixture.
+ SetServerInodeLookup(test_file_, expected_mode, 0);
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDONLY, fh));
+ auto close_fd = CloseFD(fd);
+
+ struct fuse_attr attr = DefaultFuseAttr(expected_mode, 2, 0);
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
+ };
+ struct fuse_attr_out out_payload = {
+ .attr = attr,
+ };
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_GETATTR, iov_out);
+
+ // Make syscall.
+ std::vector<char> buf(1);
+ // Since this is an empty file, it won't issue FUSE_READ. But a FUSE_GETATTR
+ // will be issued before read completes.
+ EXPECT_THAT(read(fd.get(), buf.data(), buf.size()), SyscallSucceeds());
+
+ // Check FUSE request.
+ struct fuse_in_header in_header;
+ struct fuse_getattr_in in_payload;
+ auto iov_in = FuseGenerateIovecs(in_header, in_payload);
+
+ GetServerActualRequest(iov_in);
+ EXPECT_EQ(in_header.opcode, FUSE_GETATTR);
+ EXPECT_EQ(in_payload.getattr_flags, FUSE_GETATTR_FH);
+ EXPECT_EQ(in_payload.fh, fh);
}
} // namespace
diff --git a/test/fuse/linux/symlink_test.cc b/test/fuse/linux/symlink_test.cc
new file mode 100644
index 000000000..2c3a52987
--- /dev/null
+++ b/test/fuse/linux/symlink_test.cc
@@ -0,0 +1,88 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fuse_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class SymlinkTest : public FuseTest {
+ protected:
+ const std::string target_file_ = "target_file_";
+ const std::string symlink_ = "symlink_";
+ const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO;
+};
+
+TEST_F(SymlinkTest, CreateSymLink) {
+ const std::string symlink_path =
+ JoinPath(mount_point_.path().c_str(), symlink_);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out),
+ };
+ struct fuse_entry_out out_payload = DefaultEntryOut(S_IFLNK | perms_, 5);
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_SYMLINK, iov_out);
+ ASSERT_THAT(symlink(target_file_.c_str(), symlink_path.c_str()),
+ SyscallSucceeds());
+
+ struct fuse_in_header in_header;
+ std::vector<char> actual_target_file(target_file_.length() + 1);
+ std::vector<char> actual_symlink(symlink_.length() + 1);
+ auto iov_in =
+ FuseGenerateIovecs(in_header, actual_symlink, actual_target_file);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_header.len,
+ sizeof(in_header) + symlink_.length() + target_file_.length() + 2);
+ EXPECT_EQ(in_header.opcode, FUSE_SYMLINK);
+ EXPECT_EQ(std::string(actual_target_file.data()), target_file_);
+ EXPECT_EQ(std::string(actual_symlink.data()), symlink_);
+}
+
+TEST_F(SymlinkTest, FileTypeError) {
+ const std::string symlink_path =
+ JoinPath(mount_point_.path().c_str(), symlink_);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out),
+ };
+ struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5);
+ auto iov_out = FuseGenerateIovecs(out_header, out_payload);
+ SetServerResponse(FUSE_SYMLINK, iov_out);
+ ASSERT_THAT(symlink(target_file_.c_str(), symlink_path.c_str()),
+ SyscallFailsWithErrno(EIO));
+ SkipServerActualRequest();
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/unlink_test.cc b/test/fuse/linux/unlink_test.cc
new file mode 100644
index 000000000..13efbf7c7
--- /dev/null
+++ b/test/fuse/linux/unlink_test.cc
@@ -0,0 +1,107 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/mount.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fuse_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class UnlinkTest : public FuseTest {
+ protected:
+ const std::string test_file_ = "test_file";
+ const std::string test_subdir_ = "test_subdir";
+};
+
+TEST_F(UnlinkTest, RegularFile) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_);
+ SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header),
+ };
+ auto iov_out = FuseGenerateIovecs(out_header);
+ SetServerResponse(FUSE_UNLINK, iov_out);
+
+ ASSERT_THAT(unlink(test_file_path.c_str()), SyscallSucceeds());
+ struct fuse_in_header in_header;
+ std::vector<char> unlinked_file(test_file_.length() + 1);
+ auto iov_in = FuseGenerateIovecs(in_header, unlinked_file);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_header.len, sizeof(in_header) + test_file_.length() + 1);
+ EXPECT_EQ(in_header.opcode, FUSE_UNLINK);
+ EXPECT_EQ(std::string(unlinked_file.data()), test_file_);
+}
+
+TEST_F(UnlinkTest, RegularFileSubDir) {
+ SetServerInodeLookup(test_subdir_, S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO);
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_subdir_, test_file_);
+ SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header),
+ };
+ auto iov_out = FuseGenerateIovecs(out_header);
+ SetServerResponse(FUSE_UNLINK, iov_out);
+
+ ASSERT_THAT(unlink(test_file_path.c_str()), SyscallSucceeds());
+ struct fuse_in_header in_header;
+ std::vector<char> unlinked_file(test_file_.length() + 1);
+ auto iov_in = FuseGenerateIovecs(in_header, unlinked_file);
+ GetServerActualRequest(iov_in);
+
+ EXPECT_EQ(in_header.len, sizeof(in_header) + test_file_.length() + 1);
+ EXPECT_EQ(in_header.opcode, FUSE_UNLINK);
+ EXPECT_EQ(std::string(unlinked_file.data()), test_file_);
+}
+
+TEST_F(UnlinkTest, NoFile) {
+ const std::string test_file_path =
+ JoinPath(mount_point_.path().c_str(), test_file_);
+ SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO);
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header),
+ .error = -ENOENT,
+ };
+ auto iov_out = FuseGenerateIovecs(out_header);
+ SetServerResponse(FUSE_UNLINK, iov_out);
+
+ ASSERT_THAT(unlink(test_file_path.c_str()), SyscallFailsWithErrno(ENOENT));
+ SkipServerActualRequest();
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/write_test.cc b/test/fuse/linux/write_test.cc
new file mode 100644
index 000000000..1a62beb96
--- /dev/null
+++ b/test/fuse/linux/write_test.cc
@@ -0,0 +1,303 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/fuse_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class WriteTest : public FuseTest {
+ void SetUp() override {
+ FuseTest::SetUp();
+ test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_);
+ }
+
+ // TearDown overrides the parent's function
+ // to skip checking the unconsumed release request at the end.
+ void TearDown() override { UnmountFuse(); }
+
+ protected:
+ const std::string test_file_ = "test_file";
+ const mode_t test_file_mode_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO;
+ const uint64_t test_fh_ = 1;
+ const uint32_t open_flag_ = O_RDWR;
+
+ std::string test_file_path_;
+
+ PosixErrorOr<FileDescriptor> OpenTestFile(const std::string &path,
+ uint64_t size = 512) {
+ SetServerInodeLookup(test_file_, test_file_mode_, size);
+
+ struct fuse_out_header out_header_open = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out),
+ };
+ struct fuse_open_out out_payload_open = {
+ .fh = test_fh_,
+ .open_flags = open_flag_,
+ };
+ auto iov_out_open = FuseGenerateIovecs(out_header_open, out_payload_open);
+ SetServerResponse(FUSE_OPEN, iov_out_open);
+
+ auto res = Open(path.c_str(), open_flag_);
+ if (res.ok()) {
+ SkipServerActualRequest();
+ }
+ return res;
+ }
+};
+
+class WriteTestSmallMaxWrite : public WriteTest {
+ void SetUp() override {
+ MountFuse();
+ SetUpFuseServer(&fuse_init_payload);
+ test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_);
+ }
+
+ protected:
+ const static uint32_t max_write_ = 4096;
+ constexpr static struct fuse_init_out fuse_init_payload = {
+ .major = 7,
+ .max_write = max_write_,
+ };
+
+ const uint32_t size_fragment = max_write_;
+};
+
+TEST_F(WriteTest, WriteNormal) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_));
+
+ // Prepare for the write.
+ const int n_write = 10;
+ struct fuse_out_header out_header_write = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out),
+ };
+ struct fuse_write_out out_payload_write = {
+ .size = n_write,
+ };
+ auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write);
+ SetServerResponse(FUSE_WRITE, iov_out_write);
+
+ // Issue the write.
+ std::vector<char> buf(n_write);
+ RandomizeBuffer(buf.data(), buf.size());
+ EXPECT_THAT(write(fd.get(), buf.data(), n_write),
+ SyscallSucceedsWithValue(n_write));
+
+ // Check the write request.
+ struct fuse_in_header in_header_write;
+ struct fuse_write_in in_payload_write;
+ std::vector<char> payload_buf(n_write);
+ auto iov_in_write =
+ FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf);
+ GetServerActualRequest(iov_in_write);
+
+ EXPECT_EQ(in_payload_write.fh, test_fh_);
+ EXPECT_EQ(in_header_write.len,
+ sizeof(in_header_write) + sizeof(in_payload_write));
+ EXPECT_EQ(in_header_write.opcode, FUSE_WRITE);
+ EXPECT_EQ(in_payload_write.offset, 0);
+ EXPECT_EQ(in_payload_write.size, n_write);
+ EXPECT_EQ(buf, payload_buf);
+}
+
+TEST_F(WriteTest, WriteShort) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_));
+
+ // Prepare for the write.
+ const int n_write = 10, n_written = 5;
+ struct fuse_out_header out_header_write = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out),
+ };
+ struct fuse_write_out out_payload_write = {
+ .size = n_written,
+ };
+ auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write);
+ SetServerResponse(FUSE_WRITE, iov_out_write);
+
+ // Issue the write.
+ std::vector<char> buf(n_write);
+ RandomizeBuffer(buf.data(), buf.size());
+ EXPECT_THAT(write(fd.get(), buf.data(), n_write),
+ SyscallSucceedsWithValue(n_written));
+
+ // Check the write request.
+ struct fuse_in_header in_header_write;
+ struct fuse_write_in in_payload_write;
+ std::vector<char> payload_buf(n_write);
+ auto iov_in_write =
+ FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf);
+ GetServerActualRequest(iov_in_write);
+
+ EXPECT_EQ(in_payload_write.fh, test_fh_);
+ EXPECT_EQ(in_header_write.len,
+ sizeof(in_header_write) + sizeof(in_payload_write));
+ EXPECT_EQ(in_header_write.opcode, FUSE_WRITE);
+ EXPECT_EQ(in_payload_write.offset, 0);
+ EXPECT_EQ(in_payload_write.size, n_write);
+ EXPECT_EQ(buf, payload_buf);
+}
+
+TEST_F(WriteTest, WriteShortZero) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_));
+
+ // Prepare for the write.
+ const int n_write = 10;
+ struct fuse_out_header out_header_write = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out),
+ };
+ struct fuse_write_out out_payload_write = {
+ .size = 0,
+ };
+ auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write);
+ SetServerResponse(FUSE_WRITE, iov_out_write);
+
+ // Issue the write.
+ std::vector<char> buf(n_write);
+ RandomizeBuffer(buf.data(), buf.size());
+ EXPECT_THAT(write(fd.get(), buf.data(), n_write), SyscallFailsWithErrno(EIO));
+
+ // Check the write request.
+ struct fuse_in_header in_header_write;
+ struct fuse_write_in in_payload_write;
+ std::vector<char> payload_buf(n_write);
+ auto iov_in_write =
+ FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf);
+ GetServerActualRequest(iov_in_write);
+
+ EXPECT_EQ(in_payload_write.fh, test_fh_);
+ EXPECT_EQ(in_header_write.len,
+ sizeof(in_header_write) + sizeof(in_payload_write));
+ EXPECT_EQ(in_header_write.opcode, FUSE_WRITE);
+ EXPECT_EQ(in_payload_write.offset, 0);
+ EXPECT_EQ(in_payload_write.size, n_write);
+ EXPECT_EQ(buf, payload_buf);
+}
+
+TEST_F(WriteTest, WriteZero) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_));
+
+ // Issue the write.
+ std::vector<char> buf(0);
+ EXPECT_THAT(write(fd.get(), buf.data(), 0), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(WriteTest, PWrite) {
+ const int file_size = 512;
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, file_size));
+
+ // Prepare for the write.
+ const int n_write = 10;
+ struct fuse_out_header out_header_write = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out),
+ };
+ struct fuse_write_out out_payload_write = {
+ .size = n_write,
+ };
+ auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write);
+ SetServerResponse(FUSE_WRITE, iov_out_write);
+
+ // Issue the write.
+ std::vector<char> buf(n_write);
+ RandomizeBuffer(buf.data(), buf.size());
+ const int offset_write = file_size >> 1;
+ EXPECT_THAT(pwrite(fd.get(), buf.data(), n_write, offset_write),
+ SyscallSucceedsWithValue(n_write));
+
+ // Check the write request.
+ struct fuse_in_header in_header_write;
+ struct fuse_write_in in_payload_write;
+ std::vector<char> payload_buf(n_write);
+ auto iov_in_write =
+ FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf);
+ GetServerActualRequest(iov_in_write);
+
+ EXPECT_EQ(in_payload_write.fh, test_fh_);
+ EXPECT_EQ(in_header_write.len,
+ sizeof(in_header_write) + sizeof(in_payload_write));
+ EXPECT_EQ(in_header_write.opcode, FUSE_WRITE);
+ EXPECT_EQ(in_payload_write.offset, offset_write);
+ EXPECT_EQ(in_payload_write.size, n_write);
+ EXPECT_EQ(buf, payload_buf);
+}
+
+TEST_F(WriteTestSmallMaxWrite, WriteSmallMaxWrie) {
+ const int n_fragment = 10;
+ const int n_write = size_fragment * n_fragment;
+
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_write));
+
+ // Prepare for the write.
+ struct fuse_out_header out_header_write = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out),
+ };
+ struct fuse_write_out out_payload_write = {
+ .size = size_fragment,
+ };
+ auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write);
+
+ for (int i = 0; i < n_fragment; ++i) {
+ SetServerResponse(FUSE_WRITE, iov_out_write);
+ }
+
+ // Issue the write.
+ std::vector<char> buf(n_write);
+ RandomizeBuffer(buf.data(), buf.size());
+ EXPECT_THAT(write(fd.get(), buf.data(), n_write),
+ SyscallSucceedsWithValue(n_write));
+
+ ASSERT_EQ(GetServerNumUnsentResponses(), 0);
+ ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment);
+
+ // Check the write request.
+ struct fuse_in_header in_header_write;
+ struct fuse_write_in in_payload_write;
+ std::vector<char> payload_buf(size_fragment);
+ auto iov_in_write =
+ FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf);
+
+ for (int i = 0; i < n_fragment; ++i) {
+ GetServerActualRequest(iov_in_write);
+
+ EXPECT_EQ(in_payload_write.fh, test_fh_);
+ EXPECT_EQ(in_header_write.len,
+ sizeof(in_header_write) + sizeof(in_payload_write));
+ EXPECT_EQ(in_header_write.opcode, FUSE_WRITE);
+ EXPECT_EQ(in_payload_write.offset, i * size_fragment);
+ EXPECT_EQ(in_payload_write.size, size_fragment);
+
+ auto it = buf.begin() + i * size_fragment;
+ EXPECT_EQ(std::vector<char>(it, it + size_fragment), payload_buf);
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/image/image_test.go b/test/image/image_test.go
index ac6186688..968e62f63 100644
--- a/test/image/image_test.go
+++ b/test/image/image_test.go
@@ -63,8 +63,8 @@ func TestHelloWorld(t *testing.T) {
}
}
-func runHTTPRequest(port int) error {
- url := fmt.Sprintf("http://localhost:%d/not-found", port)
+func runHTTPRequest(ip string, port int) error {
+ url := fmt.Sprintf("http://%s:%d/not-found", ip, port)
resp, err := http.Get(url)
if err != nil {
return fmt.Errorf("error reaching http server: %v", err)
@@ -73,7 +73,7 @@ func runHTTPRequest(port int) error {
return fmt.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want)
}
- url = fmt.Sprintf("http://localhost:%d/latin10k.txt", port)
+ url = fmt.Sprintf("http://%s:%d/latin10k.txt", ip, port)
resp, err = http.Get(url)
if err != nil {
return fmt.Errorf("Error reaching http server: %v", err)
@@ -95,13 +95,13 @@ func runHTTPRequest(port int) error {
return nil
}
-func testHTTPServer(t *testing.T, port int) {
+func testHTTPServer(t *testing.T, ip string, port int) {
const requests = 10
ch := make(chan error, requests)
for i := 0; i < requests; i++ {
go func() {
start := time.Now()
- err := runHTTPRequest(port)
+ err := runHTTPRequest(ip, port)
log.Printf("Response time %v: %v", time.Since(start).String(), err)
ch <- err
}()
@@ -110,7 +110,7 @@ func testHTTPServer(t *testing.T, port int) {
for i := 0; i < requests; i++ {
err := <-ch
if err != nil {
- t.Errorf("testHTTPServer(%d) failed: %v", port, err)
+ t.Errorf("testHTTPServer(%s, %d) failed: %v", ip, port, err)
}
}
}
@@ -121,27 +121,28 @@ func TestHttpd(t *testing.T) {
defer d.CleanUp(ctx)
// Start the container.
+ port := 80
opts := dockerutil.RunOpts{
Image: "basic/httpd",
- Ports: []int{80},
+ Ports: []int{port},
}
d.CopyFiles(&opts, "/usr/local/apache2/htdocs", "test/image/latin10k.txt")
if err := d.Spawn(ctx, opts); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- // Find where port 80 is mapped to.
- port, err := d.FindPort(ctx, 80)
+ // Find container IP address.
+ ip, err := d.FindIP(ctx, false)
if err != nil {
- t.Fatalf("FindPort(80) failed: %v", err)
+ t.Fatalf("docker.FindIP failed: %v", err)
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
+ if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil {
t.Errorf("WaitForHTTP() timeout: %v", err)
}
- testHTTPServer(t, port)
+ testHTTPServer(t, ip.String(), port)
}
func TestNginx(t *testing.T) {
@@ -150,27 +151,28 @@ func TestNginx(t *testing.T) {
defer d.CleanUp(ctx)
// Start the container.
+ port := 80
opts := dockerutil.RunOpts{
Image: "basic/nginx",
- Ports: []int{80},
+ Ports: []int{port},
}
d.CopyFiles(&opts, "/usr/share/nginx/html", "test/image/latin10k.txt")
if err := d.Spawn(ctx, opts); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- // Find where port 80 is mapped to.
- port, err := d.FindPort(ctx, 80)
+ // Find container IP address.
+ ip, err := d.FindIP(ctx, false)
if err != nil {
- t.Fatalf("FindPort(80) failed: %v", err)
+ t.Fatalf("docker.FindIP failed: %v", err)
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
+ if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil {
t.Errorf("WaitForHTTP() timeout: %v", err)
}
- testHTTPServer(t, port)
+ testHTTPServer(t, ip.String(), port)
}
func TestMysql(t *testing.T) {
@@ -218,26 +220,27 @@ func TestTomcat(t *testing.T) {
defer d.CleanUp(ctx)
// Start the server.
+ port := 8080
if err := d.Spawn(ctx, dockerutil.RunOpts{
Image: "basic/tomcat",
- Ports: []int{8080},
+ Ports: []int{port},
}); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- // Find where port 8080 is mapped to.
- port, err := d.FindPort(ctx, 8080)
+ // Find container IP address.
+ ip, err := d.FindIP(ctx, false)
if err != nil {
- t.Fatalf("FindPort(8080) failed: %v", err)
+ t.Fatalf("docker.FindIP failed: %v", err)
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
+ if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil {
t.Fatalf("WaitForHTTP() timeout: %v", err)
}
// Ensure that content is being served.
- url := fmt.Sprintf("http://localhost:%d", port)
+ url := fmt.Sprintf("http://%s:%d", ip.String(), port)
resp, err := http.Get(url)
if err != nil {
t.Errorf("Error reaching http server: %v", err)
@@ -253,28 +256,29 @@ func TestRuby(t *testing.T) {
defer d.CleanUp(ctx)
// Execute the ruby workload.
+ port := 8080
opts := dockerutil.RunOpts{
Image: "basic/ruby",
- Ports: []int{8080},
+ Ports: []int{port},
}
d.CopyFiles(&opts, "/src", "test/image/ruby.rb", "test/image/ruby.sh")
if err := d.Spawn(ctx, opts, "/src/ruby.sh"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- // Find where port 8080 is mapped to.
- port, err := d.FindPort(ctx, 8080)
+ // Find container IP address.
+ ip, err := d.FindIP(ctx, false)
if err != nil {
- t.Fatalf("FindPort(8080) failed: %v", err)
+ t.Fatalf("docker.FindIP failed: %v", err)
}
// Wait until it's up and running, 'gem install' can take some time.
- if err := testutil.WaitForHTTP(port, time.Minute); err != nil {
+ if err := testutil.WaitForHTTP(ip.String(), port, time.Minute); err != nil {
t.Fatalf("WaitForHTTP() timeout: %v", err)
}
// Ensure that content is being served.
- url := fmt.Sprintf("http://localhost:%d", port)
+ url := fmt.Sprintf("http://%s:%d", ip.String(), port)
resp, err := http.Get(url)
if err != nil {
t.Errorf("error reaching http server: %v", err)
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index e2beb30d5..834f7615f 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -72,11 +72,6 @@ func iptablesTest(t *testing.T, test TestCase, ipv6 bool) {
d.CleanUp(context.Background())
}()
- // TODO(gvisor.dev/issue/170): Skipping IPv6 gVisor tests.
- if ipv6 && dockerutil.Runtime() != "runc" {
- t.Skip("gVisor ip6tables not yet implemented")
- }
-
// Create and start the container.
opts := dockerutil.RunOpts{
Image: "iptables",
@@ -314,11 +309,11 @@ func TestInputInvertDestination(t *testing.T) {
singleTest(t, FilterInputInvertDestination{})
}
-func TestOutputDestination(t *testing.T) {
+func TestFilterOutputDestination(t *testing.T) {
singleTest(t, FilterOutputDestination{})
}
-func TestOutputInvertDestination(t *testing.T) {
+func TestFilterOutputInvertDestination(t *testing.T) {
singleTest(t, FilterOutputInvertDestination{})
}
diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD
index 3ce63c2c6..ccf1c735f 100644
--- a/test/packetimpact/dut/BUILD
+++ b/test/packetimpact/dut/BUILD
@@ -16,3 +16,13 @@ cc_binary(
"//test/packetimpact/proto:posix_server_cc_proto",
],
)
+
+cc_binary(
+ name = "posix_server_dynamic",
+ srcs = ["posix_server.cc"],
+ deps = [
+ grpcpp,
+ "//test/packetimpact/proto:posix_server_cc_grpc_proto",
+ "//test/packetimpact/proto:posix_server_cc_proto",
+ ],
+)
diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc
index de5b4be93..4de8540f6 100644
--- a/test/packetimpact/dut/posix_server.cc
+++ b/test/packetimpact/dut/posix_server.cc
@@ -21,6 +21,7 @@
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
+#include <time.h>
#include <unistd.h>
#include <iostream>
@@ -307,9 +308,9 @@ class PosixImpl final : public posix_server::Posix::Service {
break;
}
case ::posix_server::SockOptVal::kTimeval: {
- timeval tv = {.tv_sec = static_cast<__time_t>(
+ timeval tv = {.tv_sec = static_cast<time_t>(
request->optval().timeval().seconds()),
- .tv_usec = static_cast<__suseconds_t>(
+ .tv_usec = static_cast<suseconds_t>(
request->optval().timeval().microseconds())};
response->set_ret(setsockopt(request->sockfd(), request->level(),
request->optname(), &tv, sizeof(tv)));
@@ -336,7 +337,7 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
- ::grpc::Status Shutdown(grpc_impl::ServerContext *context,
+ ::grpc::Status Shutdown(grpc::ServerContext *context,
const ::posix_server::ShutdownRequest *request,
::posix_server::ShutdownResponse *response) override {
if (shutdown(request->fd(), request->how()) < 0) {
diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD
index ff2be9b30..605dd4972 100644
--- a/test/packetimpact/runner/BUILD
+++ b/test/packetimpact/runner/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "bzl_library", "go_test")
+load("//tools:defs.bzl", "bzl_library", "go_library", "go_test")
package(
default_visibility = ["//test/packetimpact:__subpackages__"],
@@ -7,21 +7,31 @@ package(
go_test(
name = "packetimpact_test",
- srcs = ["packetimpact_test.go"],
+ srcs = [
+ "packetimpact_test.go",
+ ],
tags = [
# Not intended to be run directly.
"local",
"manual",
],
- deps = [
- "//pkg/test/dockerutil",
- "//test/packetimpact/netdevs",
- "@com_github_docker_docker//api/types/mount:go_default_library",
- ],
+ deps = [":runner"],
)
bzl_library(
name = "defs_bzl",
srcs = ["defs.bzl"],
- visibility = ["//visibility:private"],
+ visibility = ["//test/packetimpact:__subpackages__"],
+)
+
+go_library(
+ name = "runner",
+ testonly = True,
+ srcs = ["dut.go"],
+ visibility = ["//test/packetimpact:__subpackages__"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/packetimpact/netdevs",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
+ ],
)
diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl
index d72c63fe6..f56d3c42e 100644
--- a/test/packetimpact/runner/defs.bzl
+++ b/test/packetimpact/runner/defs.bzl
@@ -23,8 +23,9 @@ def _packetimpact_test_impl(ctx):
transitive_files = []
if hasattr(ctx.attr._test_runner, "data_runfiles"):
transitive_files.append(ctx.attr._test_runner.data_runfiles.files)
+ files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server
runfiles = ctx.runfiles(
- files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary,
+ files = files,
transitive_files = depset(transitive = transitive_files),
collect_default = True,
collect_data = True,
@@ -38,7 +39,7 @@ _packetimpact_test = rule(
cfg = "target",
default = ":packetimpact_test",
),
- "_posix_server_binary": attr.label(
+ "_posix_server": attr.label(
cfg = "target",
default = "//test/packetimpact/dut:posix_server",
),
diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go
new file mode 100644
index 000000000..59bb68eb1
--- /dev/null
+++ b/test/packetimpact/runner/dut.go
@@ -0,0 +1,442 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package runner starts docker containers and networking for a packetimpact test.
+package runner
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "math/rand"
+ "net"
+ "os"
+ "os/exec"
+ "path"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/docker/docker/api/types/mount"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/packetimpact/netdevs"
+)
+
+// stringList implements flag.Value.
+type stringList []string
+
+// String implements flag.Value.String.
+func (l *stringList) String() string {
+ return strings.Join(*l, ",")
+}
+
+// Set implements flag.Value.Set.
+func (l *stringList) Set(value string) error {
+ *l = append(*l, value)
+ return nil
+}
+
+var (
+ native = false
+ testbenchBinary = ""
+ tshark = false
+ extraTestArgs = stringList{}
+ expectFailure = false
+
+ // DutAddr is the IP addres for DUT.
+ DutAddr = net.IPv4(0, 0, 0, 10)
+ testbenchAddr = net.IPv4(0, 0, 0, 20)
+)
+
+// RegisterFlags defines flags and associates them with the package-level
+// exported variables above. It should be called by tests in their init
+// functions.
+func RegisterFlags(fs *flag.FlagSet) {
+ fs.BoolVar(&native, "native", false, "whether the test should be run natively")
+ fs.StringVar(&testbenchBinary, "testbench_binary", "", "path to the testbench binary")
+ fs.BoolVar(&tshark, "tshark", false, "use more verbose tshark in logs instead of tcpdump")
+ fs.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench")
+ fs.BoolVar(&expectFailure, "expect_failure", false, "expect that the test will fail when run")
+}
+
+// CtrlPort is the port that posix_server listens on.
+const CtrlPort = "40000"
+
+// logger implements testutil.Logger.
+//
+// Labels logs based on their source and formats multi-line logs.
+type logger string
+
+// Name implements testutil.Logger.Name.
+func (l logger) Name() string {
+ return string(l)
+}
+
+// Logf implements testutil.Logger.Logf.
+func (l logger) Logf(format string, args ...interface{}) {
+ lines := strings.Split(fmt.Sprintf(format, args...), "\n")
+ log.Printf("%s: %s", l, lines[0])
+ for _, line := range lines[1:] {
+ log.Printf("%*s %s", len(l), "", line)
+ }
+}
+
+// TestWithDUT runs a packetimpact test with the given information.
+func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Container) DUT, containerAddr net.IP) {
+ if testbenchBinary == "" {
+ t.Fatal("--testbench_binary is missing")
+ }
+ dockerutil.EnsureSupportedDockerVersion()
+
+ // Create the networks needed for the test. One control network is needed for
+ // the gRPC control packets and one test network on which to transmit the test
+ // packets.
+ ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet"))
+ testNet := dockerutil.NewNetwork(ctx, logger("testNet"))
+ for _, dn := range []*dockerutil.Network{ctrlNet, testNet} {
+ for {
+ if err := createDockerNetwork(ctx, dn); err != nil {
+ t.Log("creating docker network:", err)
+ const wait = 100 * time.Millisecond
+ t.Logf("sleeping %s and will try creating docker network again", wait)
+ // This can fail if another docker network claimed the same IP so we'll
+ // just try again.
+ time.Sleep(wait)
+ continue
+ }
+ break
+ }
+ dn := dn
+ t.Cleanup(func() {
+ if err := dn.Cleanup(ctx); err != nil {
+ t.Errorf("unable to cleanup container %s: %s", dn.Name, err)
+ }
+ })
+ // Sanity check.
+ if inspect, err := dn.Inspect(ctx); err != nil {
+ t.Fatalf("failed to inspect network %s: %v", dn.Name, err)
+ } else if inspect.Name != dn.Name {
+ t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name)
+ }
+ }
+
+ tmpDir, err := ioutil.TempDir("", "container-output")
+ if err != nil {
+ t.Fatal("creating temp dir:", err)
+ }
+ t.Cleanup(func() {
+ if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil {
+ t.Errorf("unable to copy container output files: %s", err)
+ }
+ if err := os.RemoveAll(tmpDir); err != nil {
+ t.Errorf("failed to remove tmpDir %s: %s", tmpDir, err)
+ }
+ })
+
+ const testOutputDir = "/tmp/testoutput"
+
+ // Create the Docker container for the DUT.
+ var dut *dockerutil.Container
+ if native {
+ dut = dockerutil.MakeNativeContainer(ctx, logger("dut"))
+ } else {
+ dut = dockerutil.MakeContainer(ctx, logger("dut"))
+ }
+ t.Cleanup(func() {
+ dut.CleanUp(ctx)
+ })
+
+ runOpts := dockerutil.RunOpts{
+ Image: "packetimpact",
+ CapAdd: []string{"NET_ADMIN"},
+ Mounts: []mount.Mount{{
+ Type: mount.TypeBind,
+ Source: tmpDir,
+ Target: testOutputDir,
+ ReadOnly: false,
+ }},
+ }
+
+ device := mkDevice(dut)
+ remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev := device.Prepare(ctx, t, runOpts, ctrlNet, testNet, containerAddr)
+
+ // Create the Docker container for the testbench.
+ testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench"))
+
+ tbb := path.Base(testbenchBinary)
+ containerTestbenchBinary := filepath.Join("/packetimpact", tbb)
+ testbench.CopyFiles(&runOpts, "/packetimpact", filepath.Join("test/packetimpact/tests", tbb))
+
+ // snifferNetDev is a network device on the test orchestrator that we will
+ // run sniffer (tcpdump or tshark) on and inject traffic to, not to be
+ // confused with the device on the DUT.
+ const snifferNetDev = "eth2"
+ // Run tcpdump in the test bench unbuffered, without DNS resolution, just on
+ // the interface with the test packets.
+ snifferArgs := []string{
+ "tcpdump",
+ "-S", "-vvv", "-U", "-n",
+ "-i", snifferNetDev,
+ "-w", testOutputDir + "/dump.pcap",
+ }
+ snifferRegex := "tcpdump: listening.*\n"
+ if tshark {
+ // Run tshark in the test bench unbuffered, without DNS resolution, just on
+ // the interface with the test packets.
+ snifferArgs = []string{
+ "tshark", "-V", "-l", "-n", "-i", snifferNetDev,
+ "-o", "tcp.check_checksum:TRUE",
+ "-o", "udp.check_checksum:TRUE",
+ }
+ snifferRegex = "Capturing on.*\n"
+ }
+
+ if err := StartContainer(
+ ctx,
+ runOpts,
+ testbench,
+ testbenchAddr,
+ []*dockerutil.Network{ctrlNet, testNet},
+ snifferArgs...,
+ ); err != nil {
+ t.Fatalf("failed to start docker container for testbench sniffer: %s", err)
+ }
+ // Kill so that it will flush output.
+ t.Cleanup(func() {
+ time.Sleep(1 * time.Second)
+ testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0])
+ })
+
+ if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil {
+ t.Fatalf("sniffer on %s never listened: %s", dut.Name, err)
+ }
+
+ // When the Linux kernel receives a SYN-ACK for a SYN it didn't send, it
+ // will respond with an RST. In most packetimpact tests, the SYN is sent
+ // by the raw socket and the kernel knows nothing about the connection, this
+ // behavior will break lots of TCP related packetimpact tests. To prevent
+ // this, we can install the following iptables rules. The raw socket that
+ // packetimpact tests use will still be able to see everything.
+ for _, bin := range []string{"iptables", "ip6tables"} {
+ if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", snifferNetDev, "-p", "tcp", "-j", "DROP"); err != nil {
+ t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs)
+ }
+ }
+
+ // FIXME(b/156449515): Some piece of the system has a race. The old
+ // bash script version had a sleep, so we have one too. The race should
+ // be fixed and this sleep removed.
+ time.Sleep(time.Second)
+
+ // Start a packetimpact test on the test bench. The packetimpact test sends
+ // and receives packets and also sends POSIX socket commands to the
+ // posix_server to be executed on the DUT.
+ testArgs := []string{containerTestbenchBinary}
+ testArgs = append(testArgs, extraTestArgs...)
+ testArgs = append(testArgs,
+ "--posix_server_ip", AddressInSubnet(DutAddr, *ctrlNet.Subnet).String(),
+ "--posix_server_port", CtrlPort,
+ "--remote_ipv4", AddressInSubnet(DutAddr, *testNet.Subnet).String(),
+ "--local_ipv4", AddressInSubnet(testbenchAddr, *testNet.Subnet).String(),
+ "--remote_ipv6", remoteIPv6.String(),
+ "--remote_mac", remoteMAC.String(),
+ "--remote_interface_id", fmt.Sprintf("%d", dutDeviceID),
+ "--local_device", snifferNetDev,
+ "--remote_device", dutTestNetDev,
+ fmt.Sprintf("--native=%t", native),
+ )
+ testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...)
+ if (err != nil) != expectFailure {
+ var dutLogs string
+ if logs, err := device.Logs(ctx); err != nil {
+ dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err)
+ } else {
+ dutLogs = logs
+ }
+
+ t.Errorf(`test error: %v, expect failure: %t
+
+%s
+
+====== Begin of Testbench Logs ======
+
+%s
+
+====== End of Testbench Logs ======`,
+ err, expectFailure, dutLogs, testbenchLogs)
+ }
+}
+
+// DUT describes how to setup/teardown the dut for packetimpact tests.
+type DUT interface {
+ // Prepare prepares the dut, starts posix_server and returns the IPv6, MAC
+ // address, the interface ID, and the interface name for the testNet on DUT.
+ Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network, containerAddr net.IP) (net.IP, net.HardwareAddr, uint32, string)
+ // Logs retrieves the logs from the dut.
+ Logs(ctx context.Context) (string, error)
+}
+
+// DockerDUT describes a docker based DUT.
+type DockerDUT struct {
+ c *dockerutil.Container
+}
+
+// NewDockerDUT creates a docker based DUT.
+func NewDockerDUT(c *dockerutil.Container) DUT {
+ return &DockerDUT{
+ c: c,
+ }
+}
+
+// Prepare implements DUT.Prepare.
+func (dut *DockerDUT) Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network, containerAddr net.IP) (net.IP, net.HardwareAddr, uint32, string) {
+ const containerPosixServerBinary = "/packetimpact/posix_server"
+ dut.c.CopyFiles(&runOpts, "/packetimpact", "test/packetimpact/dut/posix_server")
+
+ if err := StartContainer(
+ ctx,
+ runOpts,
+ dut.c,
+ containerAddr,
+ []*dockerutil.Network{ctrlNet, testNet},
+ containerPosixServerBinary,
+ "--ip=0.0.0.0",
+ "--port="+CtrlPort,
+ ); err != nil {
+ t.Fatalf("failed to start docker container for DUT: %s", err)
+ }
+
+ if _, err := dut.c.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil {
+ t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.c.Name, err)
+ }
+
+ dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut.c, AddressInSubnet(containerAddr, *testNet.Subnet))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ remoteMAC := dutDeviceInfo.MAC
+ remoteIPv6 := dutDeviceInfo.IPv6Addr
+ // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if
+ // needed.
+ if remoteIPv6 == nil {
+ if _, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil {
+ t.Fatalf("unable to ip addr add on container %s: %s", dut.c.Name, err)
+ }
+ // Now try again, to make sure that it worked.
+ _, dutDeviceInfo, err = deviceByIP(ctx, dut.c, AddressInSubnet(containerAddr, *testNet.Subnet))
+ if err != nil {
+ t.Fatal(err)
+ }
+ remoteIPv6 = dutDeviceInfo.IPv6Addr
+ if remoteIPv6 == nil {
+ t.Fatalf("unable to set IPv6 address on container %s", dut.c.Name)
+ }
+ }
+ const testNetDev = "eth2"
+
+ return remoteIPv6, dutDeviceInfo.MAC, dutDeviceInfo.ID, testNetDev
+}
+
+// Logs implements DUT.Logs.
+func (dut *DockerDUT) Logs(ctx context.Context) (string, error) {
+ logs, err := dut.c.Logs(ctx)
+ if err != nil {
+ return "", err
+ }
+ return fmt.Sprintf(`====== Begin of DUT Logs ======
+
+%s
+
+====== End of DUT Logs ======`, logs), nil
+}
+
+// AddNetworks connects docker network with the container and assigns the specific IP.
+func AddNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error {
+ for _, dn := range networks {
+ ip := AddressInSubnet(addr, *dn.Subnet)
+ // Connect to the network with the specified IP address.
+ if err := dn.Connect(ctx, d, ip.String(), ""); err != nil {
+ return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err)
+ }
+ }
+ return nil
+}
+
+// AddressInSubnet combines the subnet provided with the address and returns a
+// new address. The return address bits come from the subnet where the mask is 1
+// and from the ip address where the mask is 0.
+func AddressInSubnet(addr net.IP, subnet net.IPNet) net.IP {
+ var octets []byte
+ for i := 0; i < 4; i++ {
+ octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i])))
+ }
+ return net.IP(octets)
+}
+
+// deviceByIP finds a deviceInfo and device name from an IP address.
+func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) {
+ out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show")
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w\n%s", d.Name, err, out)
+ }
+ devs, err := netdevs.ParseDevices(out)
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w\n%s", d.Name, err, out)
+ }
+ testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs)
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err)
+ }
+ return testDevice, deviceInfo, nil
+}
+
+// createDockerNetwork makes a randomly-named network that will start with the
+// namePrefix. The network will be a random /24 subnet.
+func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error {
+ randSource := rand.NewSource(time.Now().UnixNano())
+ r1 := rand.New(randSource)
+ // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24.
+ ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0)
+ n.Subnet = &net.IPNet{
+ IP: ip,
+ Mask: ip.DefaultMask(),
+ }
+ return n.Create(ctx)
+}
+
+// StartContainer will create a container instance from runOpts, connect it
+// with the specified docker networks and start executing the specified cmd.
+func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerutil.Container, containerAddr net.IP, ns []*dockerutil.Network, cmd ...string) error {
+ conf, hostconf, netconf := c.ConfigsFrom(runOpts, cmd...)
+ _ = netconf
+ hostconf.AutoRemove = true
+ hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"}
+
+ if err := c.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ return fmt.Errorf("unable to create container %s: %w", c.Name, err)
+ }
+
+ if err := AddNetworks(ctx, c, containerAddr, ns); err != nil {
+ return fmt.Errorf("unable to connect the container with the networks: %w", err)
+ }
+
+ if err := c.Start(ctx); err != nil {
+ return fmt.Errorf("unable to start container %s: %w", c.Name, err)
+ }
+ return nil
+}
diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go
index cb9bfd5b7..c598bfc29 100644
--- a/test/packetimpact/runner/packetimpact_test.go
+++ b/test/packetimpact/runner/packetimpact_test.go
@@ -18,366 +18,15 @@ package packetimpact_test
import (
"context"
"flag"
- "fmt"
- "io/ioutil"
- "log"
- "math/rand"
- "net"
- "os"
- "os/exec"
- "path"
- "strings"
"testing"
- "time"
- "github.com/docker/docker/api/types/mount"
- "gvisor.dev/gvisor/pkg/test/dockerutil"
- "gvisor.dev/gvisor/test/packetimpact/netdevs"
+ "gvisor.dev/gvisor/test/packetimpact/runner"
)
-// stringList implements flag.Value.
-type stringList []string
-
-// String implements flag.Value.String.
-func (l *stringList) String() string {
- return strings.Join(*l, ",")
-}
-
-// Set implements flag.Value.Set.
-func (l *stringList) Set(value string) error {
- *l = append(*l, value)
- return nil
-}
-
-var (
- native = flag.Bool("native", false, "whether the test should be run natively")
- testbenchBinary = flag.String("testbench_binary", "", "path to the testbench binary")
- tshark = flag.Bool("tshark", false, "use more verbose tshark in logs instead of tcpdump")
- extraTestArgs = stringList{}
- expectFailure = flag.Bool("expect_failure", false, "expect that the test will fail when run")
-
- dutAddr = net.IPv4(0, 0, 0, 10)
- testbenchAddr = net.IPv4(0, 0, 0, 20)
-)
-
-const ctrlPort = "40000"
-
-// logger implements testutil.Logger.
-//
-// Labels logs based on their source and formats multi-line logs.
-type logger string
-
-// Name implements testutil.Logger.Name.
-func (l logger) Name() string {
- return string(l)
-}
-
-// Logf implements testutil.Logger.Logf.
-func (l logger) Logf(format string, args ...interface{}) {
- lines := strings.Split(fmt.Sprintf(format, args...), "\n")
- log.Printf("%s: %s", l, lines[0])
- for _, line := range lines[1:] {
- log.Printf("%*s %s", len(l), "", line)
- }
+func init() {
+ runner.RegisterFlags(flag.CommandLine)
}
func TestOne(t *testing.T) {
- flag.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench")
- flag.Parse()
- if *testbenchBinary == "" {
- t.Fatal("--testbench_binary is missing")
- }
- dockerutil.EnsureSupportedDockerVersion()
- ctx := context.Background()
-
- // Create the networks needed for the test. One control network is needed for
- // the gRPC control packets and one test network on which to transmit the test
- // packets.
- ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet"))
- testNet := dockerutil.NewNetwork(ctx, logger("testNet"))
- for _, dn := range []*dockerutil.Network{ctrlNet, testNet} {
- for {
- if err := createDockerNetwork(ctx, dn); err != nil {
- t.Log("creating docker network:", err)
- const wait = 100 * time.Millisecond
- t.Logf("sleeping %s and will try creating docker network again", wait)
- // This can fail if another docker network claimed the same IP so we'll
- // just try again.
- time.Sleep(wait)
- continue
- }
- break
- }
- defer func(dn *dockerutil.Network) {
- if err := dn.Cleanup(ctx); err != nil {
- t.Errorf("unable to cleanup container %s: %s", dn.Name, err)
- }
- }(dn)
- // Sanity check.
- inspect, err := dn.Inspect(ctx)
- if err != nil {
- t.Fatalf("failed to inspect network %s: %v", dn.Name, err)
- } else if inspect.Name != dn.Name {
- t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name)
- }
-
- }
-
- tmpDir, err := ioutil.TempDir("", "container-output")
- if err != nil {
- t.Fatal("creating temp dir:", err)
- }
- defer os.RemoveAll(tmpDir)
-
- const testOutputDir = "/tmp/testoutput"
-
- // Create the Docker container for the DUT.
- var dut *dockerutil.Container
- if *native {
- dut = dockerutil.MakeNativeContainer(ctx, logger("dut"))
- } else {
- dut = dockerutil.MakeContainer(ctx, logger("dut"))
- }
-
- runOpts := dockerutil.RunOpts{
- Image: "packetimpact",
- CapAdd: []string{"NET_ADMIN"},
- Mounts: []mount.Mount{mount.Mount{
- Type: mount.TypeBind,
- Source: tmpDir,
- Target: testOutputDir,
- ReadOnly: false,
- }},
- }
-
- const containerPosixServerBinary = "/packetimpact/posix_server"
- dut.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/dut/posix_server")
-
- conf, hostconf, _ := dut.ConfigsFrom(runOpts, containerPosixServerBinary, "--ip=0.0.0.0", "--port="+ctrlPort)
- hostconf.AutoRemove = true
- hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"}
-
- if err := dut.CreateFrom(ctx, conf, hostconf, nil); err != nil {
- t.Fatalf("unable to create container %s: %v", dut.Name, err)
- }
-
- defer dut.CleanUp(ctx)
-
- // Add ctrlNet as eth1 and testNet as eth2.
- const testNetDev = "eth2"
- if err := addNetworks(ctx, dut, dutAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil {
- t.Fatal(err)
- }
-
- if err := dut.Start(ctx); err != nil {
- t.Fatalf("unable to start container %s: %s", dut.Name, err)
- }
-
- if _, err := dut.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil {
- t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.Name, err)
- }
-
- dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet))
- if err != nil {
- t.Fatal(err)
- }
-
- remoteMAC := dutDeviceInfo.MAC
- remoteIPv6 := dutDeviceInfo.IPv6Addr
- // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if
- // needed.
- if remoteIPv6 == nil {
- if _, err := dut.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil {
- t.Fatalf("unable to ip addr add on container %s: %s", dut.Name, err)
- }
- // Now try again, to make sure that it worked.
- _, dutDeviceInfo, err = deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet))
- if err != nil {
- t.Fatal(err)
- }
- remoteIPv6 = dutDeviceInfo.IPv6Addr
- if remoteIPv6 == nil {
- t.Fatal("unable to set IPv6 address on container", dut.Name)
- }
- }
-
- // Create the Docker container for the testbench.
- testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench"))
-
- tbb := path.Base(*testbenchBinary)
- containerTestbenchBinary := "/packetimpact/" + tbb
- runOpts = dockerutil.RunOpts{
- Image: "packetimpact",
- CapAdd: []string{"NET_ADMIN"},
- Mounts: []mount.Mount{mount.Mount{
- Type: mount.TypeBind,
- Source: tmpDir,
- Target: testOutputDir,
- ReadOnly: false,
- }},
- }
- testbench.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/tests/"+tbb)
-
- // Run tcpdump in the test bench unbuffered, without DNS resolution, just on
- // the interface with the test packets.
- snifferArgs := []string{
- "tcpdump",
- "-S", "-vvv", "-U", "-n",
- "-i", testNetDev,
- "-w", testOutputDir + "/dump.pcap",
- }
- snifferRegex := "tcpdump: listening.*\n"
- if *tshark {
- // Run tshark in the test bench unbuffered, without DNS resolution, just on
- // the interface with the test packets.
- snifferArgs = []string{
- "tshark", "-V", "-l", "-n", "-i", testNetDev,
- "-o", "tcp.check_checksum:TRUE",
- "-o", "udp.check_checksum:TRUE",
- }
- snifferRegex = "Capturing on.*\n"
- }
-
- defer func() {
- if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil {
- t.Error("unable to copy container output files:", err)
- }
- }()
-
- conf, hostconf, _ = testbench.ConfigsFrom(runOpts, snifferArgs...)
- hostconf.AutoRemove = true
- hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"}
-
- if err := testbench.CreateFrom(ctx, conf, hostconf, nil); err != nil {
- t.Fatalf("unable to create container %s: %s", testbench.Name, err)
- }
- defer testbench.CleanUp(ctx)
-
- // Add ctrlNet as eth1 and testNet as eth2.
- if err := addNetworks(ctx, testbench, testbenchAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil {
- t.Fatal(err)
- }
-
- if err := testbench.Start(ctx); err != nil {
- t.Fatalf("unable to start container %s: %s", testbench.Name, err)
- }
-
- // Kill so that it will flush output.
- defer func() {
- time.Sleep(1 * time.Second)
- testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0])
- }()
-
- if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil {
- t.Fatalf("sniffer on %s never listened: %s", dut.Name, err)
- }
-
- // Because the Linux kernel receives the SYN-ACK but didn't send the SYN it
- // will issue an RST. To prevent this IPtables can be used to filter out all
- // incoming packets. The raw socket that packetimpact tests use will still see
- // everything.
- for _, bin := range []string{"iptables", "ip6tables"} {
- if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", testNetDev, "-p", "tcp", "-j", "DROP"); err != nil {
- t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs)
- }
- }
-
- // FIXME(b/156449515): Some piece of the system has a race. The old
- // bash script version had a sleep, so we have one too. The race should
- // be fixed and this sleep removed.
- time.Sleep(time.Second)
-
- // Start a packetimpact test on the test bench. The packetimpact test sends
- // and receives packets and also sends POSIX socket commands to the
- // posix_server to be executed on the DUT.
- testArgs := []string{containerTestbenchBinary}
- testArgs = append(testArgs, extraTestArgs...)
- testArgs = append(testArgs,
- "--posix_server_ip", addressInSubnet(dutAddr, *ctrlNet.Subnet).String(),
- "--posix_server_port", ctrlPort,
- "--remote_ipv4", addressInSubnet(dutAddr, *testNet.Subnet).String(),
- "--local_ipv4", addressInSubnet(testbenchAddr, *testNet.Subnet).String(),
- "--remote_ipv6", remoteIPv6.String(),
- "--remote_mac", remoteMAC.String(),
- "--remote_interface_id", fmt.Sprintf("%d", dutDeviceInfo.ID),
- "--device", testNetDev,
- fmt.Sprintf("--native=%t", *native),
- )
- testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...)
- if (err != nil) != *expectFailure {
- var dutLogs string
- if logs, err := dut.Logs(ctx); err != nil {
- dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err)
- } else {
- dutLogs = logs
- }
-
- t.Errorf(`test error: %v, expect failure: %t
-
-====== Begin of DUT Logs ======
-
-%s
-
-====== End of DUT Logs ======
-
-====== Begin of Testbench Logs ======
-
-%s
-
-====== End of Testbench Logs ======`,
- err, *expectFailure, dutLogs, testbenchLogs)
- }
-}
-
-func addNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error {
- for _, dn := range networks {
- ip := addressInSubnet(addr, *dn.Subnet)
- // Connect to the network with the specified IP address.
- if err := dn.Connect(ctx, d, ip.String(), ""); err != nil {
- return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err)
- }
- }
- return nil
-}
-
-// addressInSubnet combines the subnet provided with the address and returns a
-// new address. The return address bits come from the subnet where the mask is 1
-// and from the ip address where the mask is 0.
-func addressInSubnet(addr net.IP, subnet net.IPNet) net.IP {
- var octets []byte
- for i := 0; i < 4; i++ {
- octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i])))
- }
- return net.IP(octets)
-}
-
-// createDockerNetwork makes a randomly-named network that will start with the
-// namePrefix. The network will be a random /24 subnet.
-func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error {
- randSource := rand.NewSource(time.Now().UnixNano())
- r1 := rand.New(randSource)
- // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24.
- ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0)
- n.Subnet = &net.IPNet{
- IP: ip,
- Mask: ip.DefaultMask(),
- }
- return n.Create(ctx)
-}
-
-// deviceByIP finds a deviceInfo and device name from an IP address.
-func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) {
- out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show")
- if err != nil {
- return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w\n%s", d.Name, err, out)
- }
- devs, err := netdevs.ParseDevices(out)
- if err != nil {
- return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w\n%s", d.Name, err, out)
- }
- testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs)
- if err != nil {
- return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err)
- }
- return testDevice, deviceInfo, nil
+ runner.TestWithDUT(context.Background(), t, runner.NewDockerDUT, runner.DutAddr)
}
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 3af5f83fd..a90046f69 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -615,7 +615,7 @@ func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Du
if errs == nil {
return nil, fmt.Errorf("got no frames matching %v during %s", layers, timeout)
}
- return nil, fmt.Errorf("got no frames matching %v during %s: got %w", layers, timeout, errs)
+ return nil, fmt.Errorf("got frames %w want %v during %s", errs, layers, timeout)
}
if conn.match(layers, gotLayers) {
for i, s := range conn.layerStates {
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
index 57e822725..193bb2dc8 100644
--- a/test/packetimpact/testbench/rawsockets.go
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -139,7 +139,7 @@ type Injector struct {
func NewInjector(t *testing.T) (Injector, error) {
t.Helper()
- ifInfo, err := net.InterfaceByName(Device)
+ ifInfo, err := net.InterfaceByName(LocalDevice)
if err != nil {
return Injector{}, err
}
diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go
index e3629e1f3..0073a1361 100644
--- a/test/packetimpact/testbench/testbench.go
+++ b/test/packetimpact/testbench/testbench.go
@@ -29,8 +29,11 @@ import (
var (
// Native indicates that the test is being run natively.
Native = false
- // Device is the local device on the test network.
- Device = ""
+ // LocalDevice is the device that testbench uses to inject traffic.
+ LocalDevice = ""
+ // RemoteDevice is the device name on the DUT, individual tests can
+ // use the name to construct tests.
+ RemoteDevice = ""
// LocalIPv4 is the local IPv4 address on the test network.
LocalIPv4 = ""
@@ -80,7 +83,8 @@ func RegisterFlags(fs *flag.FlagSet) {
fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets")
fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets")
fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets")
- fs.StringVar(&Device, "device", Device, "local device for test packets")
+ fs.StringVar(&LocalDevice, "local_device", LocalDevice, "local device to inject traffic")
+ fs.StringVar(&RemoteDevice, "remote_device", RemoteDevice, "remote device on the DUT")
fs.BoolVar(&Native, "native", Native, "whether the test is running natively")
fs.Uint64Var(&RemoteInterfaceID, "remote_interface_id", RemoteInterfaceID, "remote interface ID for test packets")
}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 7a7152fa5..94731c64b 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -166,8 +166,8 @@ packetimpact_go_test(
)
packetimpact_go_test(
- name = "tcp_close_wait_ack",
- srcs = ["tcp_close_wait_ack_test.go"],
+ name = "tcp_unacc_seq_ack",
+ srcs = ["tcp_unacc_seq_ack_test.go"],
deps = [
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
@@ -260,6 +260,28 @@ packetimpact_go_test(
)
packetimpact_go_test(
+ name = "tcp_timewait_reset",
+ srcs = ["tcp_timewait_reset_test.go"],
+ # TODO(b/168523247): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_queue_send_in_syn_sent",
+ srcs = ["tcp_queue_send_in_syn_sent_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
name = "icmpv6_param_problem",
srcs = ["icmpv6_param_problem_test.go"],
# TODO(b/153485026): Fix netstack then remove the line below.
@@ -318,3 +340,13 @@ packetimpact_go_test(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+packetimpact_go_test(
+ name = "tcp_rcv_buf_space",
+ srcs = ["tcp_rcv_buf_space_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go
deleted file mode 100644
index e6a96f214..000000000
--- a/test/packetimpact/tests/tcp_close_wait_ack_test.go
+++ /dev/null
@@ -1,109 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_close_wait_ack_test
-
-import (
- "flag"
- "fmt"
- "testing"
- "time"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/test/packetimpact/testbench"
-)
-
-func init() {
- testbench.RegisterFlags(flag.CommandLine)
-}
-
-func TestCloseWaitAck(t *testing.T) {
- for _, tt := range []struct {
- description string
- makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP
- seqNumOffset seqnum.Size
- expectAck bool
- }{
- {"OTW", generateOTWSeqSegment, 0, false},
- {"OTW", generateOTWSeqSegment, 1, true},
- {"OTW", generateOTWSeqSegment, 2, true},
- {"ACK", generateUnaccACKSegment, 0, false},
- {"ACK", generateUnaccACKSegment, 1, true},
- {"ACK", generateUnaccACKSegment, 2, true},
- } {
- t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
- dut := testbench.NewDUT(t)
- defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(t, listenFd)
- conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close(t)
-
- conn.Connect(t)
- acceptFd, _ := dut.Accept(t, listenFd)
-
- // Send a FIN to DUT to intiate the active close
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
- gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
- if err != nil {
- t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err)
- }
- windowSize := seqnum.Size(*gotTCP.WindowSize)
-
- // Send a segment with OTW Seq / unacc ACK and expect an ACK back
- conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")})
- gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
- if tt.expectAck && err != nil {
- t.Fatalf("expected an ack but got none: %s", err)
- }
- if !tt.expectAck && gotAck != nil {
- t.Fatalf("expected no ack but got one: %s", gotAck)
- }
-
- // Now let's verify DUT is indeed in CLOSE_WAIT
- dut.Close(t, acceptFd)
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
- t.Fatalf("expected DUT to send a FIN: %s", err)
- }
- // Ack the FIN from DUT
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
- // Send some extra data to DUT
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")})
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
- t.Fatalf("expected DUT to send an RST: %s", err)
- }
- })
- }
-}
-
-// generateOTWSeqSegment generates an segment with
-// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only
-// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the
-// receiver.
-func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
- lastAcceptable := conn.LocalSeqNum(t).Add(windowSize)
- otwSeq := uint32(lastAcceptable.Add(seqNumOffset))
- return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)}
-}
-
-// generateUnaccACKSegment generates an segment with
-// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable
-// when seqNumOffset is 0, otherwise an ACK is expected from the receiver.
-func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
- lastAcceptable := conn.RemoteSeqNum(t)
- unaccAck := uint32(lastAcceptable.Add(seqNumOffset))
- return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)}
-}
diff --git a/test/packetimpact/tests/tcp_linger_test.go b/test/packetimpact/tests/tcp_linger_test.go
index 913e49e06..b9a0409aa 100644
--- a/test/packetimpact/tests/tcp_linger_test.go
+++ b/test/packetimpact/tests/tcp_linger_test.go
@@ -251,3 +251,20 @@ func TestTCPLingerShutdownSendNonZeroTimeout(t *testing.T) {
})
}
}
+
+func TestTCPLingerNonEstablished(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ newFD := dut.Socket(t, unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP)
+ dut.SetSockLingerOption(t, newFD, lingerDuration, true)
+
+ // As the socket is in the initial state, Close() should not linger
+ // and return immediately.
+ start := time.Now()
+ dut.CloseWithErrno(context.Background(), t, newFD)
+ diff := time.Since(start)
+
+ if diff > lingerDuration {
+ t.Errorf("expected close to return within %s, but returned after %s", lingerDuration, diff)
+ }
+ dut.TearDown()
+}
diff --git a/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go
new file mode 100644
index 000000000..0ec8fd748
--- /dev/null
+++ b/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go
@@ -0,0 +1,133 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_queue_send_in_syn_sent_test
+
+import (
+ "context"
+ "errors"
+ "flag"
+ "net"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestQueueSendInSynSent tests send behavior when the TCP state
+// is SYN-SENT.
+// It tests for 2 variants when in SYN_SENT state and:
+// (1) DUT blocks on send and complete handshake
+// (2) DUT blocks on send and receive a TCP RST.
+func TestQueueSendInSynSent(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ reset bool
+ }{
+ {description: "Complete handshake", reset: false},
+ {description: "Send RST", reset: true},
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+
+ socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+ dut.SetNonBlocking(t, socket, true)
+ if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) {
+ t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err)
+ }
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil {
+ t.Fatalf("expected a SYN from DUT, but got none: %s", err)
+ }
+ if _, err := dut.SendWithErrno(context.Background(), t, socket, sampleData, 0); err != syscall.Errno(unix.EWOULDBLOCK) {
+ t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err)
+ }
+
+ // Test blocking write.
+ dut.SetNonBlocking(t, socket, false)
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+ wg.Add(1)
+ var block sync.WaitGroup
+ block.Add(1)
+ go func() {
+ defer wg.Done()
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
+ defer cancel()
+
+ block.Done()
+ // Issue SEND call in SYN-SENT, this should be queued for
+ // process until the connection is established.
+ n, err := dut.SendWithErrno(ctx, t, socket, sampleData, 0)
+ if tt.reset {
+ if err != syscall.Errno(unix.ECONNREFUSED) {
+ t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err)
+ }
+ if n != -1 {
+ t.Errorf("expected return value %d, got %d", -1, n)
+ }
+ return
+ }
+ if n != int32(len(sampleData)) {
+ t.Errorf("failed to send on DUT: %s", err)
+ }
+ }()
+
+ // Wait for the goroutine to be scheduled and before it
+ // blocks on endpoint send.
+ block.Wait()
+ // The following sleep is used to prevent the connection
+ // from being established before we are blocked on send.
+ time.Sleep(100 * time.Millisecond)
+
+ if tt.reset {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
+ return
+ }
+
+ // Bring the connection to Established.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)})
+
+ // Expect the data from the DUT's enqueued send request.
+ //
+ // On Linux, this can be piggybacked with the ACK completing the
+ // handshake. On gVisor, getting such a piggyback is a bit more
+ // complicated because the actual data enqueuing occurs in the
+ // callers of endpoint Write.
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+
+ // Send sample payload and expect an ACK to ensure connection is still ESTABLISHED.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected an ACK from DUT, but got none: %s", err)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_rcv_buf_space_test.go b/test/packetimpact/tests/tcp_rcv_buf_space_test.go
new file mode 100644
index 000000000..cfbba1e8e
--- /dev/null
+++ b/test/packetimpact/tests/tcp_rcv_buf_space_test.go
@@ -0,0 +1,80 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_rcv_buf_space_test
+
+import (
+ "context"
+ "flag"
+ "syscall"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestReduceRecvBuf tests that a packet within window is still dropped
+// if the available buffer space drops below the size of the incoming
+// segment.
+func TestReduceRecvBuf(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ // Set a small receive buffer for the test.
+ const rcvBufSz = 4096
+ dut.SetSockOptInt(t, acceptFd, unix.SOL_SOCKET, unix.SO_RCVBUF, rcvBufSz)
+
+ // Retrieve the actual buffer.
+ bufSz := dut.GetSockOptInt(t, acceptFd, unix.SOL_SOCKET, unix.SO_RCVBUF)
+
+ // Generate a payload of 1 more than the actual buffer size used by the
+ // DUT.
+ sampleData := testbench.GenerateRandomPayload(t, int(bufSz)+1)
+ // Send and receive sample data to the dut.
+ const pktSize = 1400
+ for payload := sampleData; len(payload) != 0; {
+ payloadBytes := pktSize
+ if l := len(payload); l < payloadBytes {
+ payloadBytes = l
+ }
+
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, []testbench.Layer{&testbench.Payload{Bytes: payload[:payloadBytes]}}...)
+ payload = payload[payloadBytes:]
+ }
+
+ // First read should read < len(sampleData)
+ if ret, _, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(len(sampleData)), 0); ret == -1 || int(ret) == len(sampleData) {
+ t.Fatalf("dut.RecvWithErrno(ctx, t, %d, %d, 0) = %d,_, %s", acceptFd, int32(len(sampleData)), ret, err)
+ }
+
+ // Second read should return EAGAIN as the last segment should have been
+ // dropped due to it exceeding the receive buffer space available in the
+ // socket.
+ if ret, got, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(len(sampleData)), syscall.MSG_DONTWAIT); got != nil || ret != -1 || err != syscall.EAGAIN {
+ t.Fatalf("expected no packets but got: %s", got)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_timewait_reset_test.go b/test/packetimpact/tests/tcp_timewait_reset_test.go
new file mode 100644
index 000000000..2f76a6531
--- /dev/null
+++ b/test/packetimpact/tests/tcp_timewait_reset_test.go
@@ -0,0 +1,68 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_timewait_reset_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTimeWaitReset tests handling of RST when in TIME_WAIT state.
+func TestTimeWaitReset(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
+ defer dut.Close(t, listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
+
+ // Trigger active close.
+ dut.Close(t, acceptFD)
+
+ _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected a FIN: %s", err)
+ }
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ // Send a FIN, DUT should transition to TIME_WAIT from FIN_WAIT2.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected an ACK for our FIN: %s", err)
+ }
+
+ // Send a RST, the DUT should transition to CLOSED from TIME_WAIT.
+ // This is the default Linux behavior, it can be changed to ignore RSTs via
+ // sysctl net.ipv4.tcp_rfc1337.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)})
+
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ // The DUT should reply with RST to our ACK as the state should have
+ // transitioned to CLOSED.
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ t.Fatalf("expected a RST: %s", err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go
new file mode 100644
index 000000000..d078bbf15
--- /dev/null
+++ b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go
@@ -0,0 +1,234 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_unacc_seq_ack_test
+
+import (
+ "flag"
+ "fmt"
+ "syscall"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestEstablishedUnaccSeqAck(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP
+ seqNumOffset seqnum.Size
+ expectAck bool
+ restoreSeq bool
+ }{
+ {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: true, restoreSeq: true},
+ {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true, restoreSeq: true},
+ {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true, restoreSeq: true},
+ {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: true, restoreSeq: false},
+ {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: false, restoreSeq: true},
+ {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: false, restoreSeq: true},
+ } {
+ t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
+ defer dut.Close(t, listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ dut.Accept(t, listenFD)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected ack %s", err)
+ }
+ windowSize := seqnum.Size(*gotTCP.WindowSize)
+
+ origSeq := *conn.LocalSeqNum(t)
+ // Send a segment with OTW Seq / unacc ACK.
+ conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), samplePayload)
+ if tt.restoreSeq {
+ // Restore the local sequence number to ensure that the incoming
+ // ACK matches the TCP layer state.
+ *conn.LocalSeqNum(t) = origSeq
+ }
+ gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if tt.expectAck && err != nil {
+ t.Fatalf("expected an ack but got none: %s", err)
+ }
+ if err == nil && !tt.expectAck && gotAck != nil {
+ t.Fatalf("expected no ack but got one: %s", gotAck)
+ }
+ })
+ }
+}
+
+func TestPassiveCloseUnaccSeqAck(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP
+ seqNumOffset seqnum.Size
+ expectAck bool
+ }{
+ {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: false},
+ {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true},
+ {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true},
+ {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: false},
+ {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: true},
+ {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: true},
+ } {
+ t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
+ defer dut.Close(t, listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
+
+ // Send a FIN to DUT to intiate the passive close.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err)
+ }
+ windowSize := seqnum.Size(*gotTCP.WindowSize)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ // Send a segment with OTW Seq / unacc ACK.
+ conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), samplePayload)
+ gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if tt.expectAck && err != nil {
+ t.Errorf("expected an ack but got none: %s", err)
+ }
+ if err == nil && !tt.expectAck && gotAck != nil {
+ t.Errorf("expected no ack but got one: %s", gotAck)
+ }
+
+ // Now let's verify DUT is indeed in CLOSE_WAIT
+ dut.Close(t, acceptFD)
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
+ t.Fatalf("expected DUT to send a FIN: %s", err)
+ }
+ // Ack the FIN from DUT
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ // Send some extra data to DUT
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, samplePayload)
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ t.Fatalf("expected DUT to send an RST: %s", err)
+ }
+ })
+ }
+}
+
+func TestActiveCloseUnaccpSeqAck(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP
+ seqNumOffset seqnum.Size
+ restoreSeq bool
+ }{
+ {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, restoreSeq: true},
+ {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, restoreSeq: true},
+ {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, restoreSeq: true},
+ {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, restoreSeq: false},
+ {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, restoreSeq: true},
+ {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, restoreSeq: true},
+ } {
+ t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
+ defer dut.Close(t, listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
+
+ // Trigger active close.
+ dut.Shutdown(t, acceptFD, syscall.SHUT_WR)
+
+ // Get to FIN_WAIT2
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected a FIN: %s", err)
+ }
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ sendUnaccSeqAck := func(state string) {
+ t.Helper()
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ origSeq := *conn.LocalSeqNum(t)
+ // Send a segment with OTW Seq / unacc ACK.
+ conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, seqnum.Size(*gotTCP.WindowSize)), samplePayload)
+ if tt.restoreSeq {
+ // Restore the local sequence number to ensure that the
+ // incoming ACK matches the TCP layer state.
+ *conn.LocalSeqNum(t) = origSeq
+ }
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Errorf("expected an ack in %s state, but got none: %s", state, err)
+ }
+ }
+
+ sendUnaccSeqAck("FIN_WAIT2")
+
+ // Send a FIN to DUT to get to TIME_WAIT
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected an ACK for our fin and DUT should enter TIME_WAIT: %s", err)
+ }
+
+ sendUnaccSeqAck("TIME_WAIT")
+ })
+ }
+}
+
+// generateOTWSeqSegment generates an segment with
+// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only
+// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the
+// receiver.
+func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
+ lastAcceptable := conn.LocalSeqNum(t).Add(windowSize)
+ otwSeq := uint32(lastAcceptable.Add(seqNumOffset))
+ return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)}
+}
+
+// generateUnaccACKSegment generates an segment with
+// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable
+// when seqNumOffset is 0, otherwise an ACK is expected from the receiver.
+func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
+ lastAcceptable := conn.RemoteSeqNum(t)
+ unaccAck := uint32(lastAcceptable.Add(seqNumOffset))
+ return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)}
+}
diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD
index b4e907826..dd1d2438c 100644
--- a/test/perf/linux/BUILD
+++ b/test/perf/linux/BUILD
@@ -354,3 +354,19 @@ cc_binary(
"//test/util:test_util",
],
)
+
+cc_binary(
+ name = "open_read_close_benchmark",
+ testonly = 1,
+ srcs = [
+ "open_read_close_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ ],
+)
diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc
index d8e81fa8c..9030eb356 100644
--- a/test/perf/linux/getdents_benchmark.cc
+++ b/test/perf/linux/getdents_benchmark.cc
@@ -105,7 +105,7 @@ void BM_GetdentsSameFD(benchmark::State& state) {
state.SetItemsProcessed(state.iterations());
}
-BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime();
+BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 12)->UseRealTime();
// Creates a directory containing `files` files, and reads all the directory
// entries from the directory using a new FD each time.
diff --git a/test/perf/linux/open_read_close_benchmark.cc b/test/perf/linux/open_read_close_benchmark.cc
new file mode 100644
index 000000000..8b023a3d8
--- /dev/null
+++ b/test/perf/linux/open_read_close_benchmark.cc
@@ -0,0 +1,61 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_OpenReadClose(benchmark::State& state) {
+ const int size = state.range(0);
+ std::vector<TempPath> cache;
+ for (int i = 0; i < size; i++) {
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), "some content", 0644));
+ cache.emplace_back(std::move(path));
+ }
+
+ char buf[1];
+ unsigned int seed = 1;
+ for (auto _ : state) {
+ const int chosen = rand_r(&seed) % size;
+ int fd = open(cache[chosen].path().c_str(), O_RDONLY);
+ TEST_CHECK(fd != -1);
+ TEST_CHECK(read(fd, buf, 1) == 1);
+ close(fd);
+ }
+}
+
+// Gofer dentry cache is 1000 by default. Go over it to force files to be closed
+// for real.
+BENCHMARK(BM_OpenReadClose)->Range(1000, 16384)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/root/root.go b/test/root/root.go
index 0f1d29faf..441fa5e2e 100644
--- a/test/root/root.go
+++ b/test/root/root.go
@@ -17,5 +17,5 @@
// docker, containerd, and crictl installed. To run these tests from the
// project root directory:
//
-// ./scripts/root_tests.sh
+// make root-tests
package root
diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD
index 066338ee3..22b526f59 100644
--- a/test/runtimes/BUILD
+++ b/test/runtimes/BUILD
@@ -5,7 +5,7 @@ package(licenses = ["notice"])
runtime_test(
name = "go1.12",
- exclude_file = "exclude_go1.12.csv",
+ exclude_file = "exclude/go1.12.csv",
lang = "go",
shard_count = 8,
)
@@ -13,28 +13,28 @@ runtime_test(
runtime_test(
name = "java11",
batch = 100,
- exclude_file = "exclude_java11.csv",
+ exclude_file = "exclude/java11.csv",
lang = "java",
shard_count = 16,
)
runtime_test(
name = "nodejs12.4.0",
- exclude_file = "exclude_nodejs12.4.0.csv",
+ exclude_file = "exclude/nodejs12.4.0.csv",
lang = "nodejs",
shard_count = 8,
)
runtime_test(
name = "php7.3.6",
- exclude_file = "exclude_php7.3.6.csv",
+ exclude_file = "exclude/php7.3.6.csv",
lang = "php",
shard_count = 8,
)
runtime_test(
name = "python3.7.3",
- exclude_file = "exclude_python3.7.3.csv",
+ exclude_file = "exclude/python3.7.3.csv",
lang = "python",
shard_count = 8,
)
diff --git a/test/runtimes/README.md b/test/runtimes/README.md
new file mode 100644
index 000000000..9dda1a728
--- /dev/null
+++ b/test/runtimes/README.md
@@ -0,0 +1,62 @@
+# gVisor Runtime Tests
+
+App Engine uses gvisor to sandbox application containers. The runtime tests aim
+to test `runsc` compatibility with these
+[standard runtimes](https://cloud.google.com/appengine/docs/standard/runtimes).
+The test itself runs the language-defined tests inside the sandboxed standard
+runtime container.
+
+Note: [Ruby runtime](https://cloud.google.com/appengine/docs/standard/ruby) is
+currently in beta mode and so we do not run tests for it yet.
+
+### Testing Locally
+
+To run runtime tests individually from a given runtime, use the following table.
+
+Language | Version | Download Image | Run Test(s)
+-------- | ------- | ------------------------------------------- | -----------
+Go | 1.12 | `make -C images load-runtimes_go1.12` | If the test name ends with `.go`, it is an on-disk test: <br> `docker run --runtime=runsc -it gvisor.dev/images/runtimes/go1.12 ( cd /usr/local/go/test ; go run run.go -v -- <TEST_NAME>... )` <br> Otherwise it is a tool test: <br> `docker run --runtime=runsc -it gvisor.dev/images/runtimes/go1.12 go tool dist test -v -no-rebuild ^TEST1$\|^TEST2$...`
+Java | 11 | `make -C images load-runtimes_java11` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/java11 jtreg -agentvm -dir:/root/test/jdk -noreport -timeoutFactor:20 -verbose:summary <TEST_NAME>...`
+NodeJS | 12.4.0 | `make -C images load-runtimes_nodejs12.4.0` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/nodejs12.4.0 python tools/test.py --timeout=180 <TEST_NAME>...`
+Php | 7.3.6 | `make -C images load-runtimes_php7.3.6` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/php7.3.6 make test "TESTS=<TEST_NAME>..."`
+Python | 3.7.3 | `make -C images load-runtimes_python3.7.3` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/python3.7.3 ./python -m test <TEST_NAME>...`
+
+To run an entire runtime test locally, use the following table.
+
+Note: java runtime test take 1+ hours with 16 cores.
+
+Language | Version | Running the test suite
+-------- | ------- | ----------------------------------------
+Go | 1.12 | `make go1.12-runtime-tests{_vfs2}`
+Java | 11 | `make java11-runtime-tests{_vfs2}`
+NodeJS | 12.4.0 | `make nodejs12.4.0-runtime-tests{_vfs2}`
+Php | 7.3.6 | `make php7.3.6-runtime-tests{_vfs2}`
+Python | 3.7.3 | `make python3.7.3-runtime-tests{_vfs2}`
+
+#### Clean Up
+
+Sometimes when runtime tests fail or when the testing container itself crashes
+unexpectedly, the containers are not removed or sometimes do not even exit. This
+can cause some docker commands like `docker system prune` to hang forever.
+
+Here are some helpful commands (should be executed in order):
+
+```bash
+docker ps -a # Lists all docker processes; useful when investigating hanging containers.
+docker kill $(docker ps -a -q) # Kills all running containers.
+docker rm $(docker ps -a -q) # Removes all exited containers.
+docker system prune # Remove unused data.
+```
+
+### Testing Infrastructure
+
+There are 3 components to this tests infrastructure:
+
+- [`runner`](runner) - This is the test entrypoint. This is the binary is
+ invoked by `bazel test`. The runner spawns the target runtime container
+ using `runsc` and then copies over the `proctor` binary into the container.
+- [`proctor`](proctor) - This binary acts as our agent inside the container
+ which communicates with the runner and actually executes tests.
+- [`exclude`](exclude) - Holds a CSV file for each language runtime containing
+ the full path of tests that should be excluded from running along with a
+ reason for exclusion.
diff --git a/test/runtimes/exclude_go1.12.csv b/test/runtimes/exclude/go1.12.csv
index 81e02cf64..81e02cf64 100644
--- a/test/runtimes/exclude_go1.12.csv
+++ b/test/runtimes/exclude/go1.12.csv
diff --git a/test/runtimes/exclude_java11.csv b/test/runtimes/exclude/java11.csv
index 997a29cad..f779df8d5 100644
--- a/test/runtimes/exclude_java11.csv
+++ b/test/runtimes/exclude/java11.csv
@@ -1,9 +1,11 @@
test name,bug id,comment
com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker
+com/sun/jdi/InvokeHangTest.java,https://bugs.openjdk.java.net/browse/JDK-8218463,
com/sun/jdi/NashornPopFrameTest.java,,
com/sun/jdi/ProcessAttachTest.java,,
com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker
com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,,
+com/sun/management/ThreadMXBean/ThreadCpuTimeArray.java,,Test assumes high CPU clock precision
com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,,
com/sun/tools/attach/AttachSelf.java,,
com/sun/tools/attach/BasicTests.java,,
diff --git a/test/runtimes/exclude_nodejs12.4.0.csv b/test/runtimes/exclude/nodejs12.4.0.csv
index 5866ee56d..749fb9482 100644
--- a/test/runtimes/exclude_nodejs12.4.0.csv
+++ b/test/runtimes/exclude/nodejs12.4.0.csv
@@ -1,4 +1,5 @@
test name,bug id,comment
+async-hooks/test-statwatcher.js,https://github.com/nodejs/node/issues/21425,Check for fix inclusion in nodejs releases after 2020-03-29
benchmark/test-benchmark-fs.js,,
benchmark/test-benchmark-napi.js,,
doctool/test-make-doc.js,b/68848110,Expected to fail.
@@ -11,8 +12,9 @@ parallel/test-dgram-socket-buffer-size.js,b/68847921,
parallel/test-dns-channel-timeout.js,b/161893056,
parallel/test-fs-access.js,,
parallel/test-fs-watchfile.js,,Flaky - File already exists error
-parallel/test-fs-write-stream.js,,Flaky
-parallel/test-fs-write-stream-throw-type-error.js,b/110226209,
+parallel/test-fs-write-stream.js,b/166819807,Flaky
+parallel/test-fs-write-stream-double-close,b/166819807,Flaky
+parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky
parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2
parallel/test-os.js,b/63997097,
parallel/test-net-server-listen-options.js,,Flaky - EADDRINUSE
diff --git a/test/runtimes/exclude_php7.3.6.csv b/test/runtimes/exclude/php7.3.6.csv
index 815a137c5..a73f3bcfb 100644
--- a/test/runtimes/exclude_php7.3.6.csv
+++ b/test/runtimes/exclude/php7.3.6.csv
@@ -32,6 +32,7 @@ ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,b/16289534
ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223,
ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,,
ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,,
+ext/standard/tests/streams/proc_open_bug60120.phpt,,Flaky until php-src 3852a35fdbcb
ext/standard/tests/streams/proc_open_bug69900.phpt,,Flaky
ext/standard/tests/streams/stream_socket_sendto.phpt,,
ext/standard/tests/strings/007.phpt,,
diff --git a/test/runtimes/exclude_python3.7.3.csv b/test/runtimes/exclude/python3.7.3.csv
index 8760f8951..8760f8951 100644
--- a/test/runtimes/exclude_python3.7.3.csv
+++ b/test/runtimes/exclude/python3.7.3.csv
diff --git a/test/runtimes/proctor/BUILD b/test/runtimes/proctor/BUILD
index d1935cbe8..fdc6d3173 100644
--- a/test/runtimes/proctor/BUILD
+++ b/test/runtimes/proctor/BUILD
@@ -1,29 +1,11 @@
-load("//tools:defs.bzl", "go_binary", "go_test")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
go_binary(
name = "proctor",
- srcs = [
- "go.go",
- "java.go",
- "nodejs.go",
- "php.go",
- "proctor.go",
- "python.go",
- ],
+ srcs = ["main.go"],
pure = True,
visibility = ["//test/runtimes:__pkg__"],
-)
-
-go_test(
- name = "proctor_test",
- size = "small",
- srcs = ["proctor_test.go"],
- library = ":proctor",
- nogo = False, # FIXME(gvisor.dev/issue/3374): Not working with all build systems.
- pure = True,
- deps = [
- "//pkg/test/testutil",
- ],
+ deps = ["//test/runtimes/proctor/lib"],
)
diff --git a/test/runtimes/proctor/lib/BUILD b/test/runtimes/proctor/lib/BUILD
new file mode 100644
index 000000000..0c8367dfe
--- /dev/null
+++ b/test/runtimes/proctor/lib/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "lib",
+ srcs = [
+ "go.go",
+ "java.go",
+ "lib.go",
+ "nodejs.go",
+ "php.go",
+ "python.go",
+ ],
+ visibility = ["//test/runtimes/proctor:__pkg__"],
+)
+
+go_test(
+ name = "lib_test",
+ size = "small",
+ srcs = ["lib_test.go"],
+ library = ":lib",
+ deps = ["//pkg/test/testutil"],
+)
diff --git a/test/runtimes/proctor/go.go b/test/runtimes/proctor/lib/go.go
index d0ae844e6..5c48fb60b 100644
--- a/test/runtimes/proctor/go.go
+++ b/test/runtimes/proctor/lib/go.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package main
+package lib
import (
"fmt"
@@ -59,7 +59,7 @@ func (goRunner) ListTests() ([]string, error) {
}
// Go tests on disk.
- diskSlice, err := search(goTestDir, goTestRegEx)
+ diskSlice, err := Search(goTestDir, goTestRegEx)
if err != nil {
return nil, err
}
diff --git a/test/runtimes/proctor/java.go b/test/runtimes/proctor/lib/java.go
index d456fa681..3105011ff 100644
--- a/test/runtimes/proctor/java.go
+++ b/test/runtimes/proctor/lib/java.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package main
+package lib
import (
"fmt"
diff --git a/test/runtimes/proctor/proctor.go b/test/runtimes/proctor/lib/lib.go
index 9e0642424..f2ba82498 100644
--- a/test/runtimes/proctor/proctor.go
+++ b/test/runtimes/proctor/lib/lib.go
@@ -12,20 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor runs the test for a particular runtime. It is meant to be
-// included in Docker images for all runtime tests.
-package main
+// Package lib contains proctor functions.
+package lib
import (
- "flag"
"fmt"
- "log"
"os"
"os/exec"
"os/signal"
"path/filepath"
"regexp"
- "strings"
"syscall"
)
@@ -42,66 +38,8 @@ type TestRunner interface {
TestCmds(tests []string) []*exec.Cmd
}
-var (
- runtime = flag.String("runtime", "", "name of runtime")
- list = flag.Bool("list", false, "list all available tests")
- testNames = flag.String("tests", "", "run a subset of the available tests")
- pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children")
-)
-
-func main() {
- flag.Parse()
-
- if *pause {
- pauseAndReap()
- panic("pauseAndReap should never return")
- }
-
- if *runtime == "" {
- log.Fatalf("runtime flag must be provided")
- }
-
- tr, err := testRunnerForRuntime(*runtime)
- if err != nil {
- log.Fatalf("%v", err)
- }
-
- // List tests.
- if *list {
- tests, err := tr.ListTests()
- if err != nil {
- log.Fatalf("failed to list tests: %v", err)
- }
- for _, test := range tests {
- fmt.Println(test)
- }
- return
- }
-
- var tests []string
- if *testNames == "" {
- // Run every test.
- tests, err = tr.ListTests()
- if err != nil {
- log.Fatalf("failed to get all tests: %v", err)
- }
- } else {
- // Run subset of test.
- tests = strings.Split(*testNames, ",")
- }
-
- // Run tests.
- cmds := tr.TestCmds(tests)
- for _, cmd := range cmds {
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- log.Fatalf("FAIL: %v", err)
- }
- }
-}
-
-// testRunnerForRuntime returns a new TestRunner for the given runtime.
-func testRunnerForRuntime(runtime string) (TestRunner, error) {
+// TestRunnerForRuntime returns a new TestRunner for the given runtime.
+func TestRunnerForRuntime(runtime string) (TestRunner, error) {
switch runtime {
case "go":
return goRunner{}, nil
@@ -117,8 +55,8 @@ func testRunnerForRuntime(runtime string) (TestRunner, error) {
return nil, fmt.Errorf("invalid runtime %q", runtime)
}
-// pauseAndReap is like init. It runs forever and reaps any children.
-func pauseAndReap() {
+// PauseAndReap is like init. It runs forever and reaps any children.
+func PauseAndReap() {
// Get notified of any new children.
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGCHLD)
@@ -138,9 +76,9 @@ func pauseAndReap() {
}
}
-// search is a helper function to find tests in the given directory that match
+// Search is a helper function to find tests in the given directory that match
// the regex.
-func search(root string, testFilter *regexp.Regexp) ([]string, error) {
+func Search(root string, testFilter *regexp.Regexp) ([]string, error) {
var testSlice []string
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
diff --git a/test/runtimes/proctor/proctor_test.go b/test/runtimes/proctor/lib/lib_test.go
index 6ef2de085..1193d2e28 100644
--- a/test/runtimes/proctor/proctor_test.go
+++ b/test/runtimes/proctor/lib/lib_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package main
+package lib
import (
"io/ioutil"
@@ -47,7 +47,7 @@ func TestSearchEmptyDir(t *testing.T) {
var want []string
testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`)
- got, err := search(td, testFilter)
+ got, err := Search(td, testFilter)
if err != nil {
t.Errorf("search error: %v", err)
}
@@ -116,7 +116,7 @@ func TestSearch(t *testing.T) {
}
testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`)
- got, err := search(td, testFilter)
+ got, err := Search(td, testFilter)
if err != nil {
t.Errorf("search error: %v", err)
}
diff --git a/test/runtimes/proctor/nodejs.go b/test/runtimes/proctor/lib/nodejs.go
index dead5af4f..320597aa5 100644
--- a/test/runtimes/proctor/nodejs.go
+++ b/test/runtimes/proctor/lib/nodejs.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package main
+package lib
import (
"os/exec"
@@ -32,7 +32,7 @@ var _ TestRunner = nodejsRunner{}
// ListTests implements TestRunner.ListTests.
func (nodejsRunner) ListTests() ([]string, error) {
- testSlice, err := search(nodejsTestDir, nodejsTestRegEx)
+ testSlice, err := Search(nodejsTestDir, nodejsTestRegEx)
if err != nil {
return nil, err
}
diff --git a/test/runtimes/proctor/php.go b/test/runtimes/proctor/lib/php.go
index 6a83d64e3..b67a60a97 100644
--- a/test/runtimes/proctor/php.go
+++ b/test/runtimes/proctor/lib/php.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package main
+package lib
import (
"os/exec"
@@ -29,7 +29,7 @@ var _ TestRunner = phpRunner{}
// ListTests implements TestRunner.ListTests.
func (phpRunner) ListTests() ([]string, error) {
- testSlice, err := search(".", phpTestRegEx)
+ testSlice, err := Search(".", phpTestRegEx)
if err != nil {
return nil, err
}
diff --git a/test/runtimes/proctor/python.go b/test/runtimes/proctor/lib/python.go
index 7c598801b..429bfd850 100644
--- a/test/runtimes/proctor/python.go
+++ b/test/runtimes/proctor/lib/python.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package main
+package lib
import (
"fmt"
diff --git a/test/runtimes/proctor/main.go b/test/runtimes/proctor/main.go
new file mode 100644
index 000000000..e5607ac92
--- /dev/null
+++ b/test/runtimes/proctor/main.go
@@ -0,0 +1,85 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary proctor runs the test for a particular runtime. It is meant to be
+// included in Docker images for all runtime tests.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "os"
+ "strings"
+
+ "gvisor.dev/gvisor/test/runtimes/proctor/lib"
+)
+
+var (
+ runtime = flag.String("runtime", "", "name of runtime")
+ list = flag.Bool("list", false, "list all available tests")
+ testNames = flag.String("tests", "", "run a subset of the available tests")
+ pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children")
+)
+
+func main() {
+ flag.Parse()
+
+ if *pause {
+ lib.PauseAndReap()
+ panic("pauseAndReap should never return")
+ }
+
+ if *runtime == "" {
+ log.Fatalf("runtime flag must be provided")
+ }
+
+ tr, err := lib.TestRunnerForRuntime(*runtime)
+ if err != nil {
+ log.Fatalf("%v", err)
+ }
+
+ // List tests.
+ if *list {
+ tests, err := tr.ListTests()
+ if err != nil {
+ log.Fatalf("failed to list tests: %v", err)
+ }
+ for _, test := range tests {
+ fmt.Println(test)
+ }
+ return
+ }
+
+ var tests []string
+ if *testNames == "" {
+ // Run every test.
+ tests, err = tr.ListTests()
+ if err != nil {
+ log.Fatalf("failed to get all tests: %v", err)
+ }
+ } else {
+ // Run subset of test.
+ tests = strings.Split(*testNames, ",")
+ }
+
+ // Run tests.
+ cmds := tr.TestCmds(tests)
+ for _, cmd := range cmds {
+ cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
+ if err := cmd.Run(); err != nil {
+ log.Fatalf("FAIL: %v", err)
+ }
+ }
+}
diff --git a/test/runtimes/runner/BUILD b/test/runtimes/runner/BUILD
index dc0d5d5b4..70cc01594 100644
--- a/test/runtimes/runner/BUILD
+++ b/test/runtimes/runner/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_binary", "go_test")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
@@ -7,16 +7,5 @@ go_binary(
testonly = 1,
srcs = ["main.go"],
visibility = ["//test/runtimes:__pkg__"],
- deps = [
- "//pkg/log",
- "//pkg/test/dockerutil",
- "//pkg/test/testutil",
- ],
-)
-
-go_test(
- name = "exclude_test",
- size = "small",
- srcs = ["exclude_test.go"],
- library = ":runner",
+ deps = ["//test/runtimes/runner/lib"],
)
diff --git a/test/runtimes/runner/lib/BUILD b/test/runtimes/runner/lib/BUILD
new file mode 100644
index 000000000..d308f41b0
--- /dev/null
+++ b/test/runtimes/runner/lib/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "lib",
+ testonly = 1,
+ srcs = ["lib.go"],
+ visibility = ["//test/runtimes/runner:__pkg__"],
+ deps = [
+ "//pkg/log",
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ ],
+)
+
+go_test(
+ name = "lib_test",
+ size = "small",
+ srcs = ["exclude_test.go"],
+ library = ":lib",
+)
diff --git a/test/runtimes/runner/exclude_test.go b/test/runtimes/runner/lib/exclude_test.go
index 67c2170c8..f996e895b 100644
--- a/test/runtimes/runner/exclude_test.go
+++ b/test/runtimes/runner/lib/exclude_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package main
+package lib
import (
"flag"
@@ -20,6 +20,8 @@ import (
"testing"
)
+var excludeFile = flag.String("exclude_file", "", "file to test (standard format)")
+
func TestMain(m *testing.M) {
flag.Parse()
os.Exit(m.Run())
@@ -27,7 +29,7 @@ func TestMain(m *testing.M) {
// Test that the exclude file parses without error.
func TestExcludelist(t *testing.T) {
- ex, err := getExcludes()
+ ex, err := getExcludes(*excludeFile)
if err != nil {
t.Fatalf("error parsing exclude file: %v", err)
}
diff --git a/test/runtimes/runner/lib/lib.go b/test/runtimes/runner/lib/lib.go
new file mode 100644
index 000000000..78285cb0e
--- /dev/null
+++ b/test/runtimes/runner/lib/lib.go
@@ -0,0 +1,185 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package lib provides utilities for runner.
+package lib
+
+import (
+ "context"
+ "encoding/csv"
+ "fmt"
+ "io"
+ "os"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// RunTests is a helper that is called by main. It exists so that we can run
+// defered functions before exiting. It returns an exit code that should be
+// passed to os.Exit.
+func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Duration) int {
+ // Get tests to exclude..
+ excludes, err := getExcludes(excludeFile)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error())
+ return 1
+ }
+
+ // Construct the shared docker instance.
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, testutil.DefaultLogger(lang))
+ defer d.CleanUp(ctx)
+
+ if err := testutil.TouchShardStatusFile(); err != nil {
+ fmt.Fprintf(os.Stderr, "error touching status shard file: %v\n", err)
+ return 1
+ }
+
+ // Get a slice of tests to run. This will also start a single Docker
+ // container that will be used to run each test. The final test will
+ // stop the Docker container.
+ tests, err := getTests(ctx, d, lang, image, batchSize, timeout, excludes)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%s\n", err.Error())
+ return 1
+ }
+
+ m := testing.MainStart(testDeps{}, tests, nil, nil)
+ return m.Run()
+}
+
+// getTests executes all tests as table tests.
+func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) {
+ // Start the container.
+ opts := dockerutil.RunOpts{
+ Image: fmt.Sprintf("runtimes/%s", image),
+ }
+ d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor")
+ if err := d.Spawn(ctx, opts, "/proctor/proctor", "--pause"); err != nil {
+ return nil, fmt.Errorf("docker run failed: %v", err)
+ }
+
+ // Get a list of all tests in the image.
+ list, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", lang, "--list")
+ if err != nil {
+ return nil, fmt.Errorf("docker exec failed: %v", err)
+ }
+
+ // Calculate a subset of tests to run corresponding to the current
+ // shard.
+ tests := strings.Fields(list)
+ sort.Strings(tests)
+ indices, err := testutil.TestIndicesForShard(len(tests))
+ if err != nil {
+ return nil, fmt.Errorf("TestsForShard() failed: %v", err)
+ }
+
+ var itests []testing.InternalTest
+ for i := 0; i < len(indices); i += batchSize {
+ var tcs []string
+ end := i + batchSize
+ if end > len(indices) {
+ end = len(indices)
+ }
+ for _, tc := range indices[i:end] {
+ // Add test if not excluded.
+ if _, ok := excludes[tests[tc]]; ok {
+ log.Infof("Skipping test case %s\n", tests[tc])
+ continue
+ }
+ tcs = append(tcs, tests[tc])
+ }
+ itests = append(itests, testing.InternalTest{
+ Name: strings.Join(tcs, ", "),
+ F: func(t *testing.T) {
+ var (
+ now = time.Now()
+ done = make(chan struct{})
+ output string
+ err error
+ )
+
+ go func() {
+ fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n"))
+ output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", lang, "--tests", strings.Join(tcs, ","))
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ if err == nil {
+ fmt.Printf("PASS: (%v)\n\n", time.Since(now))
+ return
+ }
+ t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output)
+ case <-time.After(timeout):
+ t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output)
+ }
+ },
+ })
+ }
+
+ return itests, nil
+}
+
+// getBlacklist reads the exclude file and returns a set of test names to
+// exclude.
+func getExcludes(excludeFile string) (map[string]struct{}, error) {
+ excludes := make(map[string]struct{})
+ if excludeFile == "" {
+ return excludes, nil
+ }
+ f, err := os.Open(excludeFile)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ r := csv.NewReader(f)
+
+ // First line is header. Skip it.
+ if _, err := r.Read(); err != nil {
+ return nil, err
+ }
+
+ for {
+ record, err := r.Read()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return nil, err
+ }
+ excludes[record[0]] = struct{}{}
+ }
+ return excludes, nil
+}
+
+// testDeps implements testing.testDeps (an unexported interface), and is
+// required to use testing.MainStart.
+type testDeps struct{}
+
+func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil }
+func (f testDeps) StartCPUProfile(io.Writer) error { return nil }
+func (f testDeps) StopCPUProfile() {}
+func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil }
+func (f testDeps) ImportPath() string { return "" }
+func (f testDeps) StartTestLog(io.Writer) {}
+func (f testDeps) StopTestLog() error { return nil }
diff --git a/test/runtimes/runner/main.go b/test/runtimes/runner/main.go
index 948e7cf9c..ec79a22c2 100644
--- a/test/runtimes/runner/main.go
+++ b/test/runtimes/runner/main.go
@@ -16,20 +16,12 @@
package main
import (
- "context"
- "encoding/csv"
"flag"
"fmt"
- "io"
"os"
- "sort"
- "strings"
- "testing"
"time"
- "gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/test/dockerutil"
- "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/test/runtimes/runner/lib"
)
var (
@@ -37,169 +29,14 @@ var (
image = flag.String("image", "", "docker image with runtime tests")
excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment")
batchSize = flag.Int("batch", 50, "number of test cases run in one command")
+ timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout")
)
-// Wait time for each test to run.
-const timeout = 90 * time.Minute
-
func main() {
flag.Parse()
if *lang == "" || *image == "" {
fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n")
os.Exit(1)
}
- os.Exit(runTests())
-}
-
-// runTests is a helper that is called by main. It exists so that we can run
-// defered functions before exiting. It returns an exit code that should be
-// passed to os.Exit.
-func runTests() int {
- // Get tests to exclude..
- excludes, err := getExcludes()
- if err != nil {
- fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error())
- return 1
- }
-
- // Construct the shared docker instance.
- ctx := context.Background()
- d := dockerutil.MakeContainer(ctx, testutil.DefaultLogger(*lang))
- defer d.CleanUp(ctx)
-
- if err := testutil.TouchShardStatusFile(); err != nil {
- fmt.Fprintf(os.Stderr, "error touching status shard file: %v\n", err)
- return 1
- }
-
- // Get a slice of tests to run. This will also start a single Docker
- // container that will be used to run each test. The final test will
- // stop the Docker container.
- tests, err := getTests(ctx, d, excludes)
- if err != nil {
- fmt.Fprintf(os.Stderr, "%s\n", err.Error())
- return 1
- }
-
- m := testing.MainStart(testDeps{}, tests, nil, nil)
- return m.Run()
-}
-
-// getTests executes all tests as table tests.
-func getTests(ctx context.Context, d *dockerutil.Container, excludes map[string]struct{}) ([]testing.InternalTest, error) {
- // Start the container.
- opts := dockerutil.RunOpts{
- Image: fmt.Sprintf("runtimes/%s", *image),
- }
- d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor")
- if err := d.Spawn(ctx, opts, "/proctor/proctor", "--pause"); err != nil {
- return nil, fmt.Errorf("docker run failed: %v", err)
- }
-
- // Get a list of all tests in the image.
- list, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--list")
- if err != nil {
- return nil, fmt.Errorf("docker exec failed: %v", err)
- }
-
- // Calculate a subset of tests to run corresponding to the current
- // shard.
- tests := strings.Fields(list)
- sort.Strings(tests)
- indices, err := testutil.TestIndicesForShard(len(tests))
- if err != nil {
- return nil, fmt.Errorf("TestsForShard() failed: %v", err)
- }
-
- var itests []testing.InternalTest
- for i := 0; i < len(indices); i += *batchSize {
- var tcs []string
- end := i + *batchSize
- if end > len(indices) {
- end = len(indices)
- }
- for _, tc := range indices[i:end] {
- // Add test if not excluded.
- if _, ok := excludes[tests[tc]]; ok {
- log.Infof("Skipping test case %s\n", tests[tc])
- continue
- }
- tcs = append(tcs, tests[tc])
- }
- itests = append(itests, testing.InternalTest{
- Name: strings.Join(tcs, ", "),
- F: func(t *testing.T) {
- var (
- now = time.Now()
- done = make(chan struct{})
- output string
- err error
- )
-
- go func() {
- fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n"))
- output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--tests", strings.Join(tcs, ","))
- close(done)
- }()
-
- select {
- case <-done:
- if err == nil {
- fmt.Printf("PASS: (%v)\n\n", time.Since(now))
- return
- }
- t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output)
- case <-time.After(timeout):
- t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output)
- }
- },
- })
- }
-
- return itests, nil
+ os.Exit(lib.RunTests(*lang, *image, *excludeFile, *batchSize, *timeout))
}
-
-// getBlacklist reads the exclude file and returns a set of test names to
-// exclude.
-func getExcludes() (map[string]struct{}, error) {
- excludes := make(map[string]struct{})
- if *excludeFile == "" {
- return excludes, nil
- }
- f, err := os.Open(*excludeFile)
- if err != nil {
- return nil, err
- }
- defer f.Close()
-
- r := csv.NewReader(f)
-
- // First line is header. Skip it.
- if _, err := r.Read(); err != nil {
- return nil, err
- }
-
- for {
- record, err := r.Read()
- if err == io.EOF {
- break
- }
- if err != nil {
- return nil, err
- }
- excludes[record[0]] = struct{}{}
- }
- return excludes, nil
-}
-
-// testDeps implements testing.testDeps (an unexported interface), and is
-// required to use testing.MainStart.
-type testDeps struct{}
-
-func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil }
-func (f testDeps) StartCPUProfile(io.Writer) error { return nil }
-func (f testDeps) StopCPUProfile() {}
-func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil }
-func (f testDeps) ImportPath() string { return "" }
-func (f testDeps) StartTestLog(io.Writer) {}
-func (f testDeps) StopTestLog() error { return nil }
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index f949bc0e3..96a775456 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -238,7 +238,7 @@ syscall_test(
syscall_test(
size = "medium",
- add_overlay = False, # TODO(gvisor.dev/issue/317): enable when fixed.
+ add_overlay = True,
test = "//test/syscalls/linux:inotify_test",
)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index de753fc4e..d9dbe2267 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1667,12 +1667,14 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/container:node_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
gtest,
"//test/util:memory_util",
"//test/util:posix_error",
+ "//test/util:proc_util",
"//test/util:temp_path",
"//test/util:test_util",
"//test/util:thread_util",
@@ -2412,6 +2414,7 @@ cc_library(
":socket_test_util",
"@com_google_absl//absl/memory",
gtest,
+ "//test/util:posix_error",
"//test/util:test_util",
],
alwayslink = 1,
@@ -2430,6 +2433,7 @@ cc_library(
":socket_netlink_route_util",
":socket_test_util",
"//test/util:capability_util",
+ "//test/util:cleanup",
gtest,
],
alwayslink = 1,
diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc
index cabc2b751..edd23e063 100644
--- a/test/syscalls/linux/fallocate.cc
+++ b/test/syscalls/linux/fallocate.cc
@@ -179,6 +179,12 @@ TEST_F(AllocateTest, FallocateOtherFDs) {
auto sock0 = FileDescriptor(socks[0]);
auto sock1 = FileDescriptor(socks[1]);
EXPECT_THAT(fallocate(sock0.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV));
+
+ int pipefds[2];
+ ASSERT_THAT(pipe(pipefds), SyscallSucceeds());
+ EXPECT_THAT(fallocate(pipefds[1], 0, 0, 10), SyscallFailsWithErrno(ESPIPE));
+ close(pipefds[0]);
+ close(pipefds[1]);
}
} // namespace
diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc
index a5c421118..e4392a450 100644
--- a/test/syscalls/linux/inotify.cc
+++ b/test/syscalls/linux/inotify.cc
@@ -465,7 +465,9 @@ TEST(Inotify, ConcurrentFileDeletionAndWatchRemoval) {
for (int i = 0; i < 100; ++i) {
FileDescriptor file_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_CREAT, S_IRUSR | S_IWUSR));
- file_fd.reset(); // Close before unlinking (although save is disabled).
+ // Close before unlinking (although S/R is disabled). Some filesystems
+ // cannot restore an open fd on an unlinked file.
+ file_fd.reset();
EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds());
}
};
@@ -1256,10 +1258,7 @@ TEST(Inotify, MknodGeneratesCreateEvent) {
InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
const TempPath file1(root.path() + "/file1");
- const int rc = mknod(file1.path().c_str(), S_IFREG, 0);
- // mknod(2) is only supported on tmpfs in the sandbox.
- SKIP_IF(IsRunningOnGvisor() && rc != 0);
- ASSERT_THAT(rc, SyscallSucceeds());
+ ASSERT_THAT(mknod(file1.path().c_str(), S_IFREG, 0), SyscallSucceeds());
const std::vector<Event> events =
ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
@@ -1289,6 +1288,10 @@ TEST(Inotify, SymlinkGeneratesCreateEvent) {
}
TEST(Inotify, LinkGeneratesAttribAndCreateEvents) {
+ // Inotify does not work properly with hard links in gofer and overlay fs.
+ SKIP_IF(IsRunningOnGvisor() &&
+ !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir())));
+
const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const TempPath file1 =
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
@@ -1301,11 +1304,8 @@ TEST(Inotify, LinkGeneratesAttribAndCreateEvents) {
const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE(
InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
- const int rc = link(file1.path().c_str(), link1.path().c_str());
- // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
- SKIP_IF(IsRunningOnGvisor() && rc != 0 &&
- (errno == EPERM || errno == ENOENT));
- ASSERT_THAT(rc, SyscallSucceeds());
+ ASSERT_THAT(link(file1.path().c_str(), link1.path().c_str()),
+ SyscallSucceeds());
const std::vector<Event> events =
ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
@@ -1334,68 +1334,70 @@ TEST(Inotify, UtimesGeneratesAttribEvent) {
}
TEST(Inotify, HardlinksReuseSameWatch) {
+ // Inotify does not work properly with hard links in gofer and overlay fs.
+ SKIP_IF(IsRunningOnGvisor() &&
+ !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir())));
+
const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
- TempPath file1 =
+ TempPath file =
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
- TempPath link1(root.path() + "/link1");
- const int rc = link(file1.path().c_str(), link1.path().c_str());
- // link(2) is only supported on tmpfs in the sandbox.
- SKIP_IF(IsRunningOnGvisor() && rc != 0 &&
- (errno == EPERM || errno == ENOENT));
- ASSERT_THAT(rc, SyscallSucceeds());
+
+ TempPath file2(root.path() + "/file2");
+ ASSERT_THAT(link(file.path().c_str(), file2.path().c_str()),
+ SyscallSucceeds());
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
const int root_wd = ASSERT_NO_ERRNO_AND_VALUE(
InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
- const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE(
- InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
- const int link1_wd = ASSERT_NO_ERRNO_AND_VALUE(
- InotifyAddWatch(fd.get(), link1.path(), IN_ALL_EVENTS));
+ const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file.path(), IN_ALL_EVENTS));
+ const int file2_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file2.path(), IN_ALL_EVENTS));
// The watch descriptors for watches on different links to the same file
// should be identical.
- EXPECT_NE(root_wd, file1_wd);
- EXPECT_EQ(file1_wd, link1_wd);
+ EXPECT_NE(root_wd, file_wd);
+ EXPECT_EQ(file_wd, file2_wd);
- FileDescriptor file1_fd =
- ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY));
+ FileDescriptor file_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY));
std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
ASSERT_THAT(events,
- AreUnordered({Event(IN_OPEN, root_wd, Basename(file1.path())),
- Event(IN_OPEN, file1_wd)}));
+ AreUnordered({Event(IN_OPEN, root_wd, Basename(file.path())),
+ Event(IN_OPEN, file_wd)}));
// For the next step, we want to ensure all fds to the file are closed. Do
// that now and drain the resulting events.
- file1_fd.reset();
+ file_fd.reset();
events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
ASSERT_THAT(
events,
- AreUnordered({Event(IN_CLOSE_WRITE, root_wd, Basename(file1.path())),
- Event(IN_CLOSE_WRITE, file1_wd)}));
+ AreUnordered({Event(IN_CLOSE_WRITE, root_wd, Basename(file.path())),
+ Event(IN_CLOSE_WRITE, file_wd)}));
// Try removing the link and let's see what events show up. Note that after
// this, we still have a link to the file so the watch shouldn't be
// automatically removed.
- const std::string link1_path = link1.reset();
+ const std::string file2_path = file2.reset();
events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
ASSERT_THAT(events,
- AreUnordered({Event(IN_ATTRIB, link1_wd),
- Event(IN_DELETE, root_wd, Basename(link1_path))}));
+ AreUnordered({Event(IN_ATTRIB, file2_wd),
+ Event(IN_DELETE, root_wd, Basename(file2_path))}));
// Now remove the other link. Since this is the last link to the file, the
// watch should be automatically removed.
- const std::string file1_path = file1.reset();
+ const std::string file_path = file.reset();
events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
ASSERT_THAT(
events,
- AreUnordered({Event(IN_ATTRIB, file1_wd), Event(IN_DELETE_SELF, file1_wd),
- Event(IN_IGNORED, file1_wd),
- Event(IN_DELETE, root_wd, Basename(file1_path))}));
+ AreUnordered({Event(IN_ATTRIB, file_wd), Event(IN_DELETE_SELF, file_wd),
+ Event(IN_IGNORED, file_wd),
+ Event(IN_DELETE, root_wd, Basename(file_path))}));
}
// Calling mkdir within "parent/child" should generate an event for child, but
@@ -1806,17 +1808,17 @@ TEST(Inotify, SpliceOnInotifyFD) {
// Watches on a parent should not be triggered by actions on a hard link to one
// of its children that has a different parent.
TEST(Inotify, LinkOnOtherParent) {
+ // Inotify does not work properly with hard links in gofer and overlay fs.
+ SKIP_IF(IsRunningOnGvisor() &&
+ !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir())));
+
const TempPath dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const TempPath dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const TempPath file =
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path()));
std::string link_path = NewTempAbsPathInDir(dir2.path());
- const int rc = link(file.path().c_str(), link_path.c_str());
- // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
- SKIP_IF(IsRunningOnGvisor() && rc != 0 &&
- (errno == EPERM || errno == ENOENT));
- ASSERT_THAT(rc, SyscallSucceeds());
+ ASSERT_THAT(link(file.path().c_str(), link_path.c_str()), SyscallSucceeds());
const FileDescriptor inotify_fd =
ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
@@ -1825,13 +1827,18 @@ TEST(Inotify, LinkOnOtherParent) {
// Perform various actions on the link outside of dir1, which should trigger
// no inotify events.
- const FileDescriptor fd =
+ FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(link_path.c_str(), O_RDWR));
int val = 0;
ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds());
ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds());
ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds());
+
+ // Close before unlinking; some filesystems cannot restore an open fd on an
+ // unlinked file.
+ fd.reset();
ASSERT_THAT(unlink(link_path.c_str()), SyscallSucceeds());
+
const std::vector<Event> events =
ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
EXPECT_THAT(events, Are({}));
@@ -2055,21 +2062,21 @@ TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) {
// We need to disable S/R because there are filesystems where we cannot re-open
// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) {
- const DisableSave ds;
+ // Inotify does not work properly with hard links in gofer and overlay fs.
+ SKIP_IF(IsRunningOnGvisor() &&
+ !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir())));
// TODO(gvisor.dev/issue/1624): This test fails on VFS1.
SKIP_IF(IsRunningWithVFS1());
+ const DisableSave ds;
+
const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const TempPath file =
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
std::string path1 = file.path();
std::string path2 = NewTempAbsPathInDir(dir.path());
+ ASSERT_THAT(link(path1.c_str(), path2.c_str()), SyscallSucceeds());
- const int rc = link(path1.c_str(), path2.c_str());
- // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
- SKIP_IF(IsRunningOnGvisor() && rc != 0 &&
- (errno == EPERM || errno == ENOENT));
- ASSERT_THAT(rc, SyscallSucceeds());
const FileDescriptor fd1 =
ASSERT_NO_ERRNO_AND_VALUE(Open(path1.c_str(), O_RDWR));
const FileDescriptor fd2 =
@@ -2101,6 +2108,15 @@ TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) {
// We need to disable S/R because there are filesystems where we cannot re-open
// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) {
+ // TODO(gvisor.dev/issue/1624): Fails on VFS1.
+ SKIP_IF(IsRunningWithVFS1());
+
+ // NOTE(gvisor.dev/issue/3654): In the gofer filesystem, we do not allow
+ // setting attributes through an fd if the file at the open path has been
+ // deleted.
+ SKIP_IF(IsRunningOnGvisor() &&
+ !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir())));
+
const DisableSave ds;
const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
@@ -2110,18 +2126,6 @@ TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(file.path().c_str(), O_RDWR));
- // NOTE(b/157163751): Create another link before unlinking. This is needed for
- // the gofer filesystem in gVisor, where open fds will not work once the link
- // count hits zero. In VFS2, we end up skipping the gofer test anyway, because
- // hard links are not supported for gofer fs.
- if (IsRunningOnGvisor()) {
- std::string link_path = NewTempAbsPath();
- const int rc = link(file.path().c_str(), link_path.c_str());
- // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
- SKIP_IF(rc != 0 && (errno == EPERM || errno == ENOENT));
- ASSERT_THAT(rc, SyscallSucceeds());
- }
-
const FileDescriptor inotify_fd =
ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
diff --git a/test/syscalls/linux/ip6tables.cc b/test/syscalls/linux/ip6tables.cc
index 78e1fa09d..de0a1c114 100644
--- a/test/syscalls/linux/ip6tables.cc
+++ b/test/syscalls/linux/ip6tables.cc
@@ -82,13 +82,37 @@ TEST(IP6TablesBasic, GetEntriesErrorPrecedence) {
SyscallFailsWithErrno(EINVAL));
}
+TEST(IP6TablesBasic, GetRevision) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int sock;
+ ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW),
+ SyscallSucceeds());
+
+ struct xt_get_revision rev = {
+ .name = "REDIRECT",
+ .revision = 0,
+ };
+ socklen_t rev_len = sizeof(rev);
+
+ // Revision 0 exists.
+ EXPECT_THAT(
+ getsockopt(sock, SOL_IPV6, IP6T_SO_GET_REVISION_TARGET, &rev, &rev_len),
+ SyscallSucceeds());
+ EXPECT_EQ(rev.revision, 0);
+
+ // Revisions > 0 don't exist.
+ rev.revision = 1;
+ EXPECT_THAT(
+ getsockopt(sock, SOL_IPV6, IP6T_SO_GET_REVISION_TARGET, &rev, &rev_len),
+ SyscallFailsWithErrno(EPROTONOSUPPORT));
+}
+
// This tests the initial state of a machine with empty ip6tables via
// getsockopt(IP6T_SO_GET_INFO). We don't have a guarantee that the iptables are
// empty when running in native, but we can test that gVisor has the same
// initial state that a newly-booted Linux machine would have.
TEST(IP6TablesTest, InitialInfo) {
- // TODO(gvisor.dev/issue/3549): Enable for ip6tables.
- SKIP_IF(true);
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
FileDescriptor sock =
@@ -132,8 +156,6 @@ TEST(IP6TablesTest, InitialInfo) {
// are empty when running in native, but we can test that gVisor has the same
// initial state that a newly-booted Linux machine would have.
TEST(IP6TablesTest, InitialEntries) {
- // TODO(gvisor.dev/issue/3549): Enable for ip6tables.
- SKIP_IF(true);
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
FileDescriptor sock =
diff --git a/test/syscalls/linux/iptables.cc b/test/syscalls/linux/iptables.cc
index 83b6a164a..7ee10bbde 100644
--- a/test/syscalls/linux/iptables.cc
+++ b/test/syscalls/linux/iptables.cc
@@ -117,6 +117,32 @@ TEST(IPTablesBasic, OriginalDstErrors) {
SyscallFailsWithErrno(ENOTCONN));
}
+TEST(IPTablesBasic, GetRevision) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int sock;
+ ASSERT_THAT(sock = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP),
+ SyscallSucceeds());
+
+ struct xt_get_revision rev = {
+ .name = "REDIRECT",
+ .revision = 0,
+ };
+ socklen_t rev_len = sizeof(rev);
+
+ // Revision 0 exists.
+ EXPECT_THAT(
+ getsockopt(sock, SOL_IP, IPT_SO_GET_REVISION_TARGET, &rev, &rev_len),
+ SyscallSucceeds());
+ EXPECT_EQ(rev.revision, 0);
+
+ // Revisions > 0 don't exist.
+ rev.revision = 1;
+ EXPECT_THAT(
+ getsockopt(sock, SOL_IP, IPT_SO_GET_REVISION_TARGET, &rev, &rev_len),
+ SyscallFailsWithErrno(EPROTONOSUPPORT));
+}
+
// Fixture for iptables tests.
class IPTablesTest : public ::testing::Test {
protected:
diff --git a/test/syscalls/linux/kcov.cc b/test/syscalls/linux/kcov.cc
index f3c30444e..6afcb4e75 100644
--- a/test/syscalls/linux/kcov.cc
+++ b/test/syscalls/linux/kcov.cc
@@ -36,12 +36,13 @@ TEST(KcovTest, Kcov) {
constexpr int kSize = 4096;
constexpr int KCOV_INIT_TRACE = 0x80086301;
constexpr int KCOV_ENABLE = 0x6364;
+ constexpr int KCOV_DISABLE = 0x6365;
int fd;
ASSERT_THAT(fd = open("/sys/kernel/debug/kcov", O_RDWR),
AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT)));
- // Kcov not enabled.
+ // Kcov not available.
SKIP_IF(errno == ENOENT);
ASSERT_THAT(ioctl(fd, KCOV_INIT_TRACE, kSize), SyscallSucceeds());
@@ -62,6 +63,8 @@ TEST(KcovTest, Kcov) {
// Verify that PCs are in the standard kernel range.
EXPECT_GT(area[i], 0xffffffff7fffffffL);
}
+
+ ASSERT_THAT(ioctl(fd, KCOV_DISABLE, 0), SyscallSucceeds());
}
} // namespace
diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc
index 4036a9275..27758203d 100644
--- a/test/syscalls/linux/mkdir.cc
+++ b/test/syscalls/linux/mkdir.cc
@@ -82,6 +82,13 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) {
SyscallFailsWithErrno(EACCES));
}
+TEST_F(MkdirTest, MkdirAtEmptyPath) {
+ ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds());
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dirname_, O_RDONLY | O_DIRECTORY, 0666));
+ EXPECT_THAT(mkdirat(fd.get(), "", 0777), SyscallFailsWithErrno(ENOENT));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc
index 2ba8c11b8..ae65d366b 100644
--- a/test/syscalls/linux/mknod.cc
+++ b/test/syscalls/linux/mknod.cc
@@ -105,11 +105,13 @@ TEST(MknodTest, UnimplementedTypesReturnError) {
}
TEST(MknodTest, Socket) {
+ SKIP_IF(IsRunningOnGvisor() && IsRunningWithVFS1());
+
ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds());
- SKIP_IF(IsRunningOnGvisor() && IsRunningWithVFS1());
+ auto filename = NewTempRelPath();
- ASSERT_THAT(mknod("./file0", S_IFSOCK | S_IRUSR | S_IWUSR, 0),
+ ASSERT_THAT(mknod(filename.c_str(), S_IFSOCK | S_IRUSR | S_IWUSR, 0),
SyscallSucceeds());
int sk;
@@ -117,9 +119,10 @@ TEST(MknodTest, Socket) {
FileDescriptor fd(sk);
struct sockaddr_un addr = {.sun_family = AF_UNIX};
- absl::SNPrintF(addr.sun_path, sizeof(addr.sun_path), "./file0");
+ absl::SNPrintF(addr.sun_path, sizeof(addr.sun_path), "%s", filename.c_str());
ASSERT_THAT(connect(sk, (struct sockaddr *)&addr, sizeof(addr)),
SyscallFailsWithErrno(ECONNREFUSED));
+ ASSERT_THAT(unlink(filename.c_str()), SyscallSucceeds());
}
TEST(MknodTest, Fifo) {
@@ -203,6 +206,14 @@ TEST(MknodTest, FifoTruncNoOp) {
EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL));
}
+TEST(MknodTest, MknodAtEmptyPath) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY, 0666));
+ EXPECT_THAT(mknodat(fd.get(), "", S_IFREG | 0777, 0),
+ SyscallFailsWithErrno(ENOENT));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc
index 46b6f38db..3aab25b23 100644
--- a/test/syscalls/linux/mount.cc
+++ b/test/syscalls/linux/mount.cc
@@ -147,8 +147,15 @@ TEST(MountTest, UmountDetach) {
// Unmount the tmpfs.
mount.Release()();
- const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
- EXPECT_EQ(before.st_ino, after2.st_ino);
+ // Only check for inode number equality if the directory is not in overlayfs.
+ // If xino option is not enabled and if all overlayfs layers do not belong to
+ // the same filesystem then "the value of st_ino for directory objects may not
+ // be persistent and could change even while the overlay filesystem is
+ // mounted." -- Documentation/filesystems/overlayfs.txt
+ if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) {
+ const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
+ EXPECT_EQ(before.st_ino, after2.st_ino);
+ }
// Can still read file after unmounting.
std::vector<char> buf(sizeof(kContents));
@@ -213,8 +220,15 @@ TEST(MountTest, MountTmpfs) {
}
// Now that dir is unmounted again, we should have the old inode back.
- const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
- EXPECT_EQ(before.st_ino, after.st_ino);
+ // Only check for inode number equality if the directory is not in overlayfs.
+ // If xino option is not enabled and if all overlayfs layers do not belong to
+ // the same filesystem then "the value of st_ino for directory objects may not
+ // be persistent and could change even while the overlay filesystem is
+ // mounted." -- Documentation/filesystems/overlayfs.txt
+ if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) {
+ const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
+ EXPECT_EQ(before.st_ino, after.st_ino);
+ }
}
TEST(MountTest, MountTmpfsMagicValIgnored) {
diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc
index 51eacf3f2..78c36f98f 100644
--- a/test/syscalls/linux/open_create.cc
+++ b/test/syscalls/linux/open_create.cc
@@ -88,21 +88,21 @@ TEST(CreateTest, CreateExclusively) {
SyscallFailsWithErrno(EEXIST));
}
-TEST(CreateTeast, CreatWithOTrunc) {
+TEST(CreateTest, CreatWithOTrunc) {
std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd");
ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds());
ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666),
SyscallFailsWithErrno(EISDIR));
}
-TEST(CreateTeast, CreatDirWithOTruncAndReadOnly) {
+TEST(CreateTest, CreatDirWithOTruncAndReadOnly) {
std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd");
ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds());
ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666),
SyscallFailsWithErrno(EISDIR));
}
-TEST(CreateTeast, CreatFileWithOTruncAndReadOnly) {
+TEST(CreateTest, CreatFileWithOTruncAndReadOnly) {
std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile");
int dirfd;
ASSERT_THAT(dirfd = open(dirpath.c_str(), O_RDWR | O_CREAT, 0666),
@@ -149,6 +149,116 @@ TEST(CreateTest, OpenCreateROThenRW) {
EXPECT_THAT(WriteFd(fd2.get(), &c, 1), SyscallSucceedsWithValue(1));
}
+TEST(CreateTest, ChmodReadToWriteBetweenOpens_NoRandomSave) {
+ // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
+ // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be
+ // cleared for the same reason.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0400));
+
+ const FileDescriptor rfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ // Cannot restore after making permissions more restrictive.
+ const DisableSave ds;
+ ASSERT_THAT(fchmod(rfd.get(), 0200), SyscallSucceeds());
+
+ EXPECT_THAT(open(file.path().c_str(), O_RDONLY),
+ SyscallFailsWithErrno(EACCES));
+
+ const FileDescriptor wfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY));
+
+ char c = 'x';
+ EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1));
+ c = 0;
+ EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(c, 'x');
+}
+
+TEST(CreateTest, ChmodWriteToReadBetweenOpens_NoRandomSave) {
+ // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
+ // override file read/write permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0200));
+
+ const FileDescriptor wfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY));
+
+ // Cannot restore after making permissions more restrictive.
+ const DisableSave ds;
+ ASSERT_THAT(fchmod(wfd.get(), 0400), SyscallSucceeds());
+
+ EXPECT_THAT(open(file.path().c_str(), O_WRONLY),
+ SyscallFailsWithErrno(EACCES));
+
+ const FileDescriptor rfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ char c = 'x';
+ EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1));
+ c = 0;
+ EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(c, 'x');
+}
+
+TEST(CreateTest, CreateWithReadFlagNotAllowedByMode_NoRandomSave) {
+ // The only time we can open a file with flags forbidden by its permissions
+ // is when we are creating the file. We cannot re-open with the same flags,
+ // so we cannot restore an fd obtained from such an operation.
+ const DisableSave ds;
+
+ // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
+ // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be
+ // cleared for the same reason.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ // Create and open a file with read flag but without read permissions.
+ const std::string path = NewTempAbsPath();
+ const FileDescriptor rfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CREAT | O_RDONLY, 0222));
+
+ EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES));
+ const FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_WRONLY));
+
+ char c = 'x';
+ EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1));
+ c = 0;
+ EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(c, 'x');
+}
+
+TEST(CreateTest, CreateWithWriteFlagNotAllowedByMode_NoRandomSave) {
+ // The only time we can open a file with flags forbidden by its permissions
+ // is when we are creating the file. We cannot re-open with the same flags,
+ // so we cannot restore an fd obtained from such an operation.
+ const DisableSave ds;
+
+ // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
+ // override file read/write permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ // Create and open a file with write flag but without write permissions.
+ const std::string path = NewTempAbsPath();
+ const FileDescriptor wfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CREAT | O_WRONLY, 0444));
+
+ EXPECT_THAT(open(path.c_str(), O_WRONLY), SyscallFailsWithErrno(EACCES));
+ const FileDescriptor rfd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDONLY));
+
+ char c = 'x';
+ EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1));
+ c = 0;
+ EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(c, 'x');
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index f3c1d6bc9..b558e3a01 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -643,6 +643,27 @@ TEST_P(RawPacketTest, GetSocketDetachFilter) {
SyscallFailsWithErrno(ENOPROTOOPT));
}
+TEST_P(RawPacketTest, SetAndGetSocketLinger) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int level = SOL_SOCKET;
+ int type = SO_LINGER;
+
+ struct linger sl;
+ sl.l_onoff = 1;
+ sl.l_linger = 5;
+ ASSERT_THAT(setsockopt(s_, level, type, &sl, sizeof(sl)),
+ SyscallSucceedsWithValue(0));
+
+ struct linger got_linger = {};
+ socklen_t length = sizeof(sl);
+ ASSERT_THAT(getsockopt(s_, level, type, &got_linger, &length),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_EQ(length, sizeof(got_linger));
+ EXPECT_EQ(0, memcmp(&sl, &got_linger, length));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest,
::testing::Values(ETH_P_IP, ETH_P_ALL));
diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc
index 04c5161f5..f675dc430 100644
--- a/test/syscalls/linux/prctl.cc
+++ b/test/syscalls/linux/prctl.cc
@@ -153,7 +153,7 @@ TEST(PrctlTest, PDeathSig) {
// Enable tracing, then raise SIGSTOP and expect our parent to suppress
// it.
TEST_CHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) >= 0);
- raise(SIGSTOP);
+ TEST_CHECK(raise(SIGSTOP) == 0);
// Sleep until killed by our parent death signal. sleep(3) is
// async-signal-safe, absl::SleepFor isn't.
while (true) {
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index b73189e55..e8fcc4439 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -47,6 +47,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/container/node_hash_set.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
@@ -63,6 +64,7 @@
#include "test/util/fs_util.h"
#include "test/util/memory_util.h"
#include "test/util/posix_error.h"
+#include "test/util/proc_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
@@ -672,6 +674,23 @@ TEST(ProcSelfMaps, Mprotect) {
3 * kPageSize, PROT_READ)));
}
+TEST(ProcSelfMaps, SharedAnon) {
+ const Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ, MAP_SHARED | MAP_ANONYMOUS));
+
+ const auto proc_self_maps =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps"));
+ for (const auto& line : absl::StrSplit(proc_self_maps, '\n')) {
+ const auto entry = ASSERT_NO_ERRNO_AND_VALUE(ParseProcMapsLine(line));
+ if (entry.start <= m.addr() && m.addr() < entry.end) {
+ // cf. proc(5), "/proc/[pid]/map_files/"
+ EXPECT_EQ(entry.filename, "/dev/zero (deleted)");
+ return;
+ }
+ }
+ FAIL() << "no maps entry containing mapping at " << m.ptr();
+}
+
TEST(ProcSelfFd, OpenFd) {
int pipe_fds[2];
ASSERT_THAT(pipe2(pipe_fds, O_CLOEXEC), SyscallSucceeds());
@@ -694,6 +713,30 @@ TEST(ProcSelfFd, OpenFd) {
ASSERT_THAT(close(pipe_fds[1]), SyscallSucceeds());
}
+static void CheckFdDirGetdentsDuplicates(const std::string& path) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path.c_str(), O_RDONLY | O_DIRECTORY));
+ // Open a FD whose value is supposed to be much larger than
+ // the number of FDs opened by current process.
+ auto newfd = fcntl(fd.get(), F_DUPFD, 1024);
+ EXPECT_GE(newfd, 1024);
+ auto fd_closer = Cleanup([newfd]() { close(newfd); });
+ auto fd_files = ASSERT_NO_ERRNO_AND_VALUE(ListDir(path.c_str(), false));
+ absl::node_hash_set<std::string> fd_files_dedup(fd_files.begin(),
+ fd_files.end());
+ EXPECT_EQ(fd_files.size(), fd_files_dedup.size());
+}
+
+// This is a regression test for gvisor.dev/issues/3894
+TEST(ProcSelfFd, GetdentsDuplicates) {
+ CheckFdDirGetdentsDuplicates("/proc/self/fd");
+}
+
+// This is a regression test for gvisor.dev/issues/3894
+TEST(ProcSelfFdInfo, GetdentsDuplicates) {
+ CheckFdDirGetdentsDuplicates("/proc/self/fdinfo");
+}
+
TEST(ProcSelfFdInfo, CorrectFds) {
// Make sure there is at least one open file.
auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
@@ -737,8 +780,12 @@ TEST(ProcSelfFdInfo, Flags) {
}
TEST(ProcSelfExe, Absolute) {
- auto exe = ASSERT_NO_ERRNO_AND_VALUE(
- ReadLink(absl::StrCat("/proc/", getpid(), "/exe")));
+ auto exe = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/exe"));
+ EXPECT_EQ(exe[0], '/');
+}
+
+TEST(ProcSelfCwd, Absolute) {
+ auto exe = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/cwd"));
EXPECT_EQ(exe[0], '/');
}
@@ -773,17 +820,12 @@ TEST(ProcCpuinfo, DeniesWriteNonRoot) {
constexpr int kNobody = 65534;
EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds());
EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES));
- // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in
- // kernfs.
- if (!IsRunningOnGvisor() || IsRunningWithVFS1()) {
- EXPECT_THAT(truncate("/proc/cpuinfo", 123),
- SyscallFailsWithErrno(EACCES));
- }
+ EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EACCES));
});
}
// With root privileges, it is possible to open /proc/cpuinfo with write mode,
-// but all write operations will return EIO.
+// but all write operations should fail.
TEST(ProcCpuinfo, DeniesWriteRoot) {
// VFS1 does not behave differently for root/non-root.
SKIP_IF(IsRunningWithVFS1());
@@ -792,16 +834,10 @@ TEST(ProcCpuinfo, DeniesWriteRoot) {
int fd;
EXPECT_THAT(fd = open("/proc/cpuinfo", O_WRONLY), SyscallSucceeds());
if (fd > 0) {
- EXPECT_THAT(write(fd, "x", 1), SyscallFailsWithErrno(EIO));
- EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFailsWithErrno(EIO));
- }
- // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in
- // kernfs.
- if (!IsRunningOnGvisor() || IsRunningWithVFS1()) {
- if (fd > 0) {
- EXPECT_THAT(ftruncate(fd, 123), SyscallFailsWithErrno(EIO));
- }
- EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EIO));
+ // Truncate is not tested--it may succeed on some kernels without doing
+ // anything.
+ EXPECT_THAT(write(fd, "x", 1), SyscallFails());
+ EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFails());
}
}
@@ -1441,6 +1477,16 @@ TEST(ProcPidExe, Subprocess) {
EXPECT_EQ(actual, expected_absolute_path);
}
+// /proc/PID/cwd points to the correct directory.
+TEST(ProcPidCwd, Subprocess) {
+ auto want = ASSERT_NO_ERRNO_AND_VALUE(GetCWD());
+
+ char got[PATH_MAX + 1] = {};
+ ASSERT_THAT(ReadlinkWhileRunning("cwd", got, sizeof(got)),
+ SyscallSucceedsWithValue(Gt(0)));
+ EXPECT_EQ(got, want);
+}
+
// Test whether /proc/PID/ files can be read for a running process.
TEST(ProcPidFile, SubprocessRunning) {
char buf[1];
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index 4fab097f4..23677e296 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -39,6 +39,7 @@ namespace testing {
namespace {
constexpr const char kProcNet[] = "/proc/net";
+constexpr const char kIpForward[] = "/proc/sys/net/ipv4/ip_forward";
TEST(ProcNetSymlinkTarget, FileMode) {
struct stat s;
@@ -515,6 +516,46 @@ TEST(ProcSysNetIpv4Recovery, CanReadAndWrite) {
SyscallSucceedsWithValue(sizeof(kMessage)));
EXPECT_EQ(strcmp(buf, "100\n"), 0);
}
+
+TEST(ProcSysNetIpv4IpForward, Exists) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDONLY));
+}
+
+TEST(ProcSysNetIpv4IpForward, DefaultValueEqZero) {
+ // Test is only valid in sandbox. Not hermetic in native tests
+ // running on a arbitrary machine.
+ SKIP_IF(!IsRunningOnGvisor());
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDONLY));
+
+ char buf = 101;
+ EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_EQ(buf, '0') << "unexpected ip_forward: " << buf;
+}
+
+TEST(ProcSysNetIpv4IpForward, CanReadAndWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE))));
+
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDWR));
+
+ char buf;
+ EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_TRUE(buf == '0' || buf == '1') << "unexpected ip_forward: " << buf;
+
+ // constexpr char to_write = '1';
+ char to_write = (buf == '1') ? '0' : '1';
+ EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0),
+ SyscallSucceedsWithValue(sizeof(to_write)));
+
+ buf = 0;
+ EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ EXPECT_EQ(buf, to_write);
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
index 2e4ab6ca8..0b174e2be 100644
--- a/test/syscalls/linux/pty.cc
+++ b/test/syscalls/linux/pty.cc
@@ -71,7 +71,7 @@ constexpr absl::Duration kTimeout = absl::Seconds(20);
// The maximum line size in bytes returned per read from a pty file.
constexpr int kMaxLineSize = 4096;
-constexpr char kMainPath[] = "/dev/ptmx";
+constexpr char kMasterPath[] = "/dev/ptmx";
// glibc defines its own, different, version of struct termios. We care about
// what the kernel does, not glibc.
@@ -388,22 +388,22 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count,
TEST(PtyTrunc, Truncate) {
// Opening PTYs with O_TRUNC shouldn't cause an error, but calls to
// (f)truncate should.
- FileDescriptor main =
- ASSERT_NO_ERRNO_AND_VALUE(Open(kMainPath, O_RDWR | O_TRUNC));
- int n = ASSERT_NO_ERRNO_AND_VALUE(ReplicaID(main));
+ FileDescriptor master =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(kMasterPath, O_RDWR | O_TRUNC));
+ int n = ASSERT_NO_ERRNO_AND_VALUE(ReplicaID(master));
std::string spath = absl::StrCat("/dev/pts/", n);
FileDescriptor replica =
ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC));
- EXPECT_THAT(truncate(kMainPath, 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL));
EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(ftruncate(main.get(), 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(ftruncate(master.get(), 0), SyscallFailsWithErrno(EINVAL));
EXPECT_THAT(ftruncate(replica.get(), 0), SyscallFailsWithErrno(EINVAL));
}
-TEST(BasicPtyTest, StatUnopenedMain) {
+TEST(BasicPtyTest, StatUnopenedMaster) {
struct stat s;
- ASSERT_THAT(stat(kMainPath, &s), SyscallSucceeds());
+ ASSERT_THAT(stat(kMasterPath, &s), SyscallSucceeds());
EXPECT_EQ(s.st_rdev, makedev(TTYAUX_MAJOR, kPtmxMinor));
EXPECT_EQ(s.st_size, 0);
@@ -454,41 +454,41 @@ void ExpectReadable(const FileDescriptor& fd, int expected, char* buf) {
EXPECT_EQ(expected, n);
}
-TEST(BasicPtyTest, OpenMainReplica) {
- FileDescriptor main = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
- FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main));
+TEST(BasicPtyTest, OpenMasterReplica) {
+ FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
+ FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master));
}
-// The replica entry in /dev/pts/ disappears when the main is closed, even if
+// The replica entry in /dev/pts/ disappears when the master is closed, even if
// the replica is still open.
-TEST(BasicPtyTest, ReplicaEntryGoneAfterMainClose) {
- FileDescriptor main = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
- FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main));
+TEST(BasicPtyTest, ReplicaEntryGoneAfterMasterClose) {
+ FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
+ FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master));
// Get pty index.
int index = -1;
- ASSERT_THAT(ioctl(main.get(), TIOCGPTN, &index), SyscallSucceeds());
+ ASSERT_THAT(ioctl(master.get(), TIOCGPTN, &index), SyscallSucceeds());
std::string path = absl::StrCat("/dev/pts/", index);
struct stat st;
EXPECT_THAT(stat(path.c_str(), &st), SyscallSucceeds());
- main.reset();
+ master.reset();
EXPECT_THAT(stat(path.c_str(), &st), SyscallFailsWithErrno(ENOENT));
}
TEST(BasicPtyTest, Getdents) {
- FileDescriptor main1 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
+ FileDescriptor master1 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
int index1 = -1;
- ASSERT_THAT(ioctl(main1.get(), TIOCGPTN, &index1), SyscallSucceeds());
- FileDescriptor replica1 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main1));
+ ASSERT_THAT(ioctl(master1.get(), TIOCGPTN, &index1), SyscallSucceeds());
+ FileDescriptor replica1 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master1));
- FileDescriptor main2 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
+ FileDescriptor master2 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR));
int index2 = -1;
- ASSERT_THAT(ioctl(main2.get(), TIOCGPTN, &index2), SyscallSucceeds());
- FileDescriptor replica2 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main2));
+ ASSERT_THAT(ioctl(master2.get(), TIOCGPTN, &index2), SyscallSucceeds());
+ FileDescriptor replica2 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master2));
// The directory contains ptmx, index1, and index2. (Plus any additional PTYs
// unrelated to this test.)
@@ -498,9 +498,9 @@ TEST(BasicPtyTest, Getdents) {
EXPECT_THAT(contents, Contains(absl::StrCat(index1)));
EXPECT_THAT(contents, Contains(absl::StrCat(index2)));
- main2.reset();
+ master2.reset();
- // The directory contains ptmx and index1, but not index2 since the main is
+ // The directory contains ptmx and index1, but not index2 since the master is
// closed. (Plus any additional PTYs unrelated to this test.)
contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/dev/pts/", true));
@@ -519,8 +519,8 @@ TEST(BasicPtyTest, Getdents) {
class PtyTest : public ::testing::Test {
protected:
void SetUp() override {
- main_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK));
- replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main_));
+ master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK));
+ replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master_));
}
void DisableCanonical() {
@@ -537,21 +537,22 @@ class PtyTest : public ::testing::Test {
EXPECT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds());
}
- // Main and replica ends of the PTY. Non-blocking.
- FileDescriptor main_;
+ // Master and replica ends of the PTY. Non-blocking.
+ FileDescriptor master_;
FileDescriptor replica_;
};
-// Main to replica sanity test.
-TEST_F(PtyTest, WriteMainToReplica) {
- // N.B. by default, the replica reads nothing until the main writes a newline.
+// Master to replica sanity test.
+TEST_F(PtyTest, WriteMasterToReplica) {
+ // N.B. by default, the replica reads nothing until the master writes a
+ // newline.
constexpr char kBuf[] = "hello\n";
- EXPECT_THAT(WriteFd(main_.get(), kBuf, sizeof(kBuf) - 1),
+ EXPECT_THAT(WriteFd(master_.get(), kBuf, sizeof(kBuf) - 1),
SyscallSucceedsWithValue(sizeof(kBuf) - 1));
- // Linux moves data from the main to the replica via async work scheduled via
- // tty_flip_buffer_push. Since it is asynchronous, the data may not be
+ // Linux moves data from the master to the replica via async work scheduled
+ // via tty_flip_buffer_push. Since it is asynchronous, the data may not be
// available for reading immediately. Instead we must poll and assert that it
// becomes available "soon".
@@ -561,63 +562,63 @@ TEST_F(PtyTest, WriteMainToReplica) {
EXPECT_EQ(memcmp(buf, kBuf, sizeof(kBuf)), 0);
}
-// Replica to main sanity test.
-TEST_F(PtyTest, WriteReplicaToMain) {
- // N.B. by default, the main reads nothing until the replica writes a newline,
- // and the main gets a carriage return.
+// Replica to master sanity test.
+TEST_F(PtyTest, WriteReplicaToMaster) {
+ // N.B. by default, the master reads nothing until the replica writes a
+ // newline, and the master gets a carriage return.
constexpr char kInput[] = "hello\n";
constexpr char kExpected[] = "hello\r\n";
EXPECT_THAT(WriteFd(replica_.get(), kInput, sizeof(kInput) - 1),
SyscallSucceedsWithValue(sizeof(kInput) - 1));
- // Linux moves data from the main to the replica via async work scheduled via
- // tty_flip_buffer_push. Since it is asynchronous, the data may not be
+ // Linux moves data from the master to the replica via async work scheduled
+ // via tty_flip_buffer_push. Since it is asynchronous, the data may not be
// available for reading immediately. Instead we must poll and assert that it
// becomes available "soon".
char buf[sizeof(kExpected)] = {};
- ExpectReadable(main_, sizeof(buf) - 1, buf);
+ ExpectReadable(master_, sizeof(buf) - 1, buf);
EXPECT_EQ(memcmp(buf, kExpected, sizeof(kExpected)), 0);
}
TEST_F(PtyTest, WriteInvalidUTF8) {
char c = 0xff;
- ASSERT_THAT(syscall(__NR_write, main_.get(), &c, sizeof(c)),
+ ASSERT_THAT(syscall(__NR_write, master_.get(), &c, sizeof(c)),
SyscallSucceedsWithValue(sizeof(c)));
}
-// Both the main and replica report the standard default termios settings.
+// Both the master and replica report the standard default termios settings.
//
-// Note that TCGETS on the main actually redirects to the replica (see comment
-// on MainTermiosUnchangable).
+// Note that TCGETS on the master actually redirects to the replica (see comment
+// on MasterTermiosUnchangable).
TEST_F(PtyTest, DefaultTermios) {
struct kernel_termios t = {};
EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds());
EXPECT_EQ(t, DefaultTermios());
- EXPECT_THAT(ioctl(main_.get(), TCGETS, &t), SyscallSucceeds());
+ EXPECT_THAT(ioctl(master_.get(), TCGETS, &t), SyscallSucceeds());
EXPECT_EQ(t, DefaultTermios());
}
-// Changing termios from the main actually affects the replica.
+// Changing termios from the master actually affects the replica.
//
-// TCSETS on the main actually redirects to the replica (see comment on
-// MainTermiosUnchangable).
+// TCSETS on the master actually redirects to the replica (see comment on
+// MasterTermiosUnchangable).
TEST_F(PtyTest, TermiosAffectsReplica) {
- struct kernel_termios main_termios = {};
- EXPECT_THAT(ioctl(main_.get(), TCGETS, &main_termios), SyscallSucceeds());
- main_termios.c_lflag ^= ICANON;
- EXPECT_THAT(ioctl(main_.get(), TCSETS, &main_termios), SyscallSucceeds());
+ struct kernel_termios master_termios = {};
+ EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds());
+ master_termios.c_lflag ^= ICANON;
+ EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds());
struct kernel_termios replica_termios = {};
EXPECT_THAT(ioctl(replica_.get(), TCGETS, &replica_termios),
SyscallSucceeds());
- EXPECT_EQ(main_termios, replica_termios);
+ EXPECT_EQ(master_termios, replica_termios);
}
-// The main end of the pty has termios:
+// The master end of the pty has termios:
//
// struct kernel_termios t = {
// .c_iflag = 0;
@@ -629,25 +630,25 @@ TEST_F(PtyTest, TermiosAffectsReplica) {
//
// (From drivers/tty/pty.c:unix98_pty_init)
//
-// All termios control ioctls on the main actually redirect to the replica
+// All termios control ioctls on the master actually redirect to the replica
// (drivers/tty/tty_ioctl.c:tty_mode_ioctl), making it impossible to change the
-// main termios.
+// master termios.
//
// Verify this by setting ICRNL (which rewrites input \r to \n) and verify that
-// it has no effect on the main.
-TEST_F(PtyTest, MainTermiosUnchangable) {
- struct kernel_termios main_termios = {};
- EXPECT_THAT(ioctl(main_.get(), TCGETS, &main_termios), SyscallSucceeds());
- main_termios.c_lflag |= ICRNL;
- EXPECT_THAT(ioctl(main_.get(), TCSETS, &main_termios), SyscallSucceeds());
+// it has no effect on the master.
+TEST_F(PtyTest, MasterTermiosUnchangable) {
+ struct kernel_termios master_termios = {};
+ EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds());
+ master_termios.c_lflag |= ICRNL;
+ EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds());
char c = '\r';
ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1));
- ExpectReadable(main_, 1, &c);
+ ExpectReadable(master_, 1, &c);
EXPECT_EQ(c, '\r'); // ICRNL had no effect!
- ExpectFinished(main_);
+ ExpectFinished(master_);
}
// ICRNL rewrites input \r to \n.
@@ -658,7 +659,7 @@ TEST_F(PtyTest, TermiosICRNL) {
ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds());
char c = '\r';
- ASSERT_THAT(WriteFd(main_.get(), &c, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1));
ExpectReadable(replica_, 1, &c);
EXPECT_EQ(c, '\n');
@@ -678,7 +679,7 @@ TEST_F(PtyTest, TermiosONLCR) {
// Extra byte for NUL for EXPECT_STREQ.
char buf[3] = {};
- ExpectReadable(main_, 2, buf);
+ ExpectReadable(master_, 2, buf);
EXPECT_STREQ(buf, "\r\n");
ExpectFinished(replica_);
@@ -691,7 +692,7 @@ TEST_F(PtyTest, TermiosIGNCR) {
ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds());
char c = '\r';
- ASSERT_THAT(WriteFd(main_.get(), &c, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1));
// Nothing to read.
ASSERT_THAT(PollAndReadFd(replica_.get(), &c, 1, kTimeout),
@@ -725,18 +726,18 @@ TEST_F(PtyTest, TermiosPollReplica) {
absl::SleepFor(absl::Seconds(1));
char s[] = "foo\n";
- ASSERT_THAT(WriteFd(main_.get(), s, strlen(s) + 1), SyscallSucceeds());
+ ASSERT_THAT(WriteFd(master_.get(), s, strlen(s) + 1), SyscallSucceeds());
}
-// Test that we can successfully poll for readable data from the main.
-TEST_F(PtyTest, TermiosPollMain) {
+// Test that we can successfully poll for readable data from the master.
+TEST_F(PtyTest, TermiosPollMaster) {
struct kernel_termios t = DefaultTermios();
t.c_iflag |= IGNCR;
t.c_lflag &= ~ICANON; // for byte-by-byte reading.
- ASSERT_THAT(ioctl(main_.get(), TCSETS, &t), SyscallSucceeds());
+ ASSERT_THAT(ioctl(master_.get(), TCSETS, &t), SyscallSucceeds());
absl::Notification notify;
- int mfd = main_.get();
+ int mfd = master_.get();
ScopedThread th([mfd, &notify]() {
notify.Notify();
@@ -765,7 +766,7 @@ TEST_F(PtyTest, TermiosINLCR) {
ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds());
char c = '\n';
- ASSERT_THAT(WriteFd(main_.get(), &c, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1));
ExpectReadable(replica_, 1, &c);
EXPECT_EQ(c, '\r');
@@ -784,7 +785,7 @@ TEST_F(PtyTest, TermiosONOCR) {
ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1));
// Nothing to read.
- ASSERT_THAT(PollAndReadFd(main_.get(), &c, 1, kTimeout),
+ ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout),
PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out")));
// This time the column is greater than 0, so we should be able to read the CR
@@ -795,17 +796,17 @@ TEST_F(PtyTest, TermiosONOCR) {
SyscallSucceedsWithValue(kInputSize));
char buf[kInputSize] = {};
- ExpectReadable(main_, kInputSize, buf);
+ ExpectReadable(master_, kInputSize, buf);
EXPECT_EQ(memcmp(buf, kInput, kInputSize), 0);
- ExpectFinished(main_);
+ ExpectFinished(master_);
// Terminal should be at column 0 again, so no CR can be read.
ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1));
// Nothing to read.
- ASSERT_THAT(PollAndReadFd(main_.get(), &c, 1, kTimeout),
+ ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout),
PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out")));
}
@@ -819,10 +820,10 @@ TEST_F(PtyTest, TermiosOCRNL) {
char c = '\r';
ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1));
- ExpectReadable(main_, 1, &c);
+ ExpectReadable(master_, 1, &c);
EXPECT_EQ(c, '\n');
- ExpectFinished(main_);
+ ExpectFinished(master_);
}
// Tests that VEOL is disabled when we start, and that we can set it to enable
@@ -830,7 +831,7 @@ TEST_F(PtyTest, TermiosOCRNL) {
TEST_F(PtyTest, VEOLTermination) {
// Write a few bytes ending with '\0', and confirm that we can't read.
constexpr char kInput[] = "hello";
- ASSERT_THAT(WriteFd(main_.get(), kInput, sizeof(kInput)),
+ ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)),
SyscallSucceedsWithValue(sizeof(kInput)));
char buf[sizeof(kInput)] = {};
ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(kInput), kTimeout),
@@ -841,7 +842,7 @@ TEST_F(PtyTest, VEOLTermination) {
struct kernel_termios t = DefaultTermios();
t.c_cc[VEOL] = delim;
ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds());
- ASSERT_THAT(WriteFd(main_.get(), &delim, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1));
// Now we can read, as sending EOL caused the line to become available.
ExpectReadable(replica_, sizeof(kInput), buf);
@@ -861,7 +862,7 @@ TEST_F(PtyTest, CanonBigWrite) {
char input[kWriteLen];
memset(input, 'M', kWriteLen - 1);
input[kWriteLen - 1] = '\n';
- ASSERT_THAT(WriteFd(main_.get(), input, kWriteLen),
+ ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen),
SyscallSucceedsWithValue(kWriteLen));
// We can read the line.
@@ -877,7 +878,7 @@ TEST_F(PtyTest, SwitchCanonToNoncanon) {
// Write a few bytes without a terminating character, switch to noncanonical
// mode, and read them.
constexpr char kInput[] = "hello";
- ASSERT_THAT(WriteFd(main_.get(), kInput, sizeof(kInput)),
+ ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)),
SyscallSucceedsWithValue(sizeof(kInput)));
// Nothing available yet.
@@ -896,7 +897,7 @@ TEST_F(PtyTest, SwitchCanonToNoncanon) {
TEST_F(PtyTest, SwitchCanonToNonCanonNewline) {
// Write a few bytes with a terminating character.
constexpr char kInput[] = "hello\n";
- ASSERT_THAT(WriteFd(main_.get(), kInput, sizeof(kInput)),
+ ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)),
SyscallSucceedsWithValue(sizeof(kInput)));
DisableCanonical();
@@ -916,12 +917,12 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) {
constexpr int kWriteLen = 4100;
char input[kWriteLen];
memset(input, 'M', kWriteLen);
- ASSERT_THAT(WriteFd(main_.get(), input, kWriteLen),
+ ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen),
SyscallSucceedsWithValue(kWriteLen));
// Wait for the input queue to fill.
ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1));
constexpr char delim = '\n';
- ASSERT_THAT(WriteFd(main_.get(), &delim, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1));
EnableCanonical();
@@ -941,7 +942,7 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNoNewline) {
// Write a few bytes without a terminating character.
// mode, and read them.
constexpr char kInput[] = "hello";
- ASSERT_THAT(WriteFd(main_.get(), kInput, sizeof(kInput) - 1),
+ ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput) - 1),
SyscallSucceedsWithValue(sizeof(kInput) - 1));
ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(kInput) - 1));
@@ -963,7 +964,7 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNoNewlineBig) {
constexpr int kWriteLen = 4100;
char input[kWriteLen];
memset(input, 'M', kWriteLen);
- ASSERT_THAT(WriteFd(main_.get(), input, kWriteLen),
+ ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen),
SyscallSucceedsWithValue(kWriteLen));
ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1));
@@ -987,12 +988,12 @@ TEST_F(PtyTest, NoncanonBigWrite) {
for (int i = 0; i < kInputSize; i++) {
// This makes too many syscalls for save/restore.
const DisableSave ds;
- ASSERT_THAT(WriteFd(main_.get(), &kInput, sizeof(kInput)),
+ ASSERT_THAT(WriteFd(master_.get(), &kInput, sizeof(kInput)),
SyscallSucceedsWithValue(sizeof(kInput)));
}
// We should be able to read out everything. Sleep a bit so that Linux has a
- // chance to move data from the main to the replica.
+ // chance to move data from the master to the replica.
ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1));
for (int i = 0; i < kInputSize; i++) {
// This makes too many syscalls for save/restore.
@@ -1010,7 +1011,7 @@ TEST_F(PtyTest, NoncanonBigWrite) {
// Test newline.
TEST_F(PtyTest, TermiosICANONNewline) {
char input[3] = {'a', 'b', 'c'};
- ASSERT_THAT(WriteFd(main_.get(), input, sizeof(input)),
+ ASSERT_THAT(WriteFd(master_.get(), input, sizeof(input)),
SyscallSucceedsWithValue(sizeof(input)));
// Extra bytes for newline (written later) and NUL for EXPECT_STREQ.
@@ -1021,7 +1022,7 @@ TEST_F(PtyTest, TermiosICANONNewline) {
PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out")));
char delim = '\n';
- ASSERT_THAT(WriteFd(main_.get(), &delim, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1));
// Now it is available.
ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(input) + 1));
@@ -1036,7 +1037,7 @@ TEST_F(PtyTest, TermiosICANONNewline) {
// Test EOF (^D).
TEST_F(PtyTest, TermiosICANONEOF) {
char input[3] = {'a', 'b', 'c'};
- ASSERT_THAT(WriteFd(main_.get(), input, sizeof(input)),
+ ASSERT_THAT(WriteFd(master_.get(), input, sizeof(input)),
SyscallSucceedsWithValue(sizeof(input)));
// Extra byte for NUL for EXPECT_STREQ.
@@ -1046,7 +1047,7 @@ TEST_F(PtyTest, TermiosICANONEOF) {
ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(input), kTimeout),
PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out")));
char delim = ControlCharacter('D');
- ASSERT_THAT(WriteFd(main_.get(), &delim, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1));
// Now it is available. Note that ^D is not included.
ExpectReadable(replica_, sizeof(input), buf);
@@ -1069,10 +1070,10 @@ TEST_F(PtyTest, CanonDiscard) {
// This makes too many syscalls for save/restore.
const DisableSave ds;
for (int i = 0; i < kInputSize; i++) {
- ASSERT_THAT(WriteFd(main_.get(), &kInput, sizeof(kInput)),
+ ASSERT_THAT(WriteFd(master_.get(), &kInput, sizeof(kInput)),
SyscallSucceedsWithValue(sizeof(kInput)));
}
- ASSERT_THAT(WriteFd(main_.get(), &delim, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1));
}
// There should be multiple truncated lines available to read.
@@ -1091,9 +1092,9 @@ TEST_F(PtyTest, CanonMultiline) {
constexpr char kInput2[] = "BLUE\n";
// Write both lines.
- ASSERT_THAT(WriteFd(main_.get(), kInput1, sizeof(kInput1) - 1),
+ ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1),
SyscallSucceedsWithValue(sizeof(kInput1) - 1));
- ASSERT_THAT(WriteFd(main_.get(), kInput2, sizeof(kInput2) - 1),
+ ASSERT_THAT(WriteFd(master_.get(), kInput2, sizeof(kInput2) - 1),
SyscallSucceedsWithValue(sizeof(kInput2) - 1));
// Get the first line.
@@ -1117,9 +1118,9 @@ TEST_F(PtyTest, SwitchNoncanonToCanonMultiline) {
constexpr char kExpected[] = "GO\nBLUE\n";
// Write both lines.
- ASSERT_THAT(WriteFd(main_.get(), kInput1, sizeof(kInput1) - 1),
+ ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1),
SyscallSucceedsWithValue(sizeof(kInput1) - 1));
- ASSERT_THAT(WriteFd(main_.get(), kInput2, sizeof(kInput2) - 1),
+ ASSERT_THAT(WriteFd(master_.get(), kInput2, sizeof(kInput2) - 1),
SyscallSucceedsWithValue(sizeof(kInput2) - 1));
ASSERT_NO_ERRNO(
@@ -1140,7 +1141,7 @@ TEST_F(PtyTest, SwitchTwiceMultiline) {
// Write each line.
for (const std::string& input : kInputs) {
- ASSERT_THAT(WriteFd(main_.get(), input.c_str(), input.size()),
+ ASSERT_THAT(WriteFd(master_.get(), input.c_str(), input.size()),
SyscallSucceedsWithValue(input.size()));
}
@@ -1162,7 +1163,7 @@ TEST_F(PtyTest, SwitchTwiceMultiline) {
TEST_F(PtyTest, QueueSize) {
// Write the line.
constexpr char kInput1[] = "GO\n";
- ASSERT_THAT(WriteFd(main_.get(), kInput1, sizeof(kInput1) - 1),
+ ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1),
SyscallSucceedsWithValue(sizeof(kInput1) - 1));
ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(kInput1) - 1));
@@ -1170,7 +1171,7 @@ TEST_F(PtyTest, QueueSize) {
// readable size.
char input[kMaxLineSize];
memset(input, 'M', kMaxLineSize);
- ASSERT_THAT(WriteFd(main_.get(), input, kMaxLineSize),
+ ASSERT_THAT(WriteFd(master_.get(), input, kMaxLineSize),
SyscallSucceedsWithValue(kMaxLineSize));
int inputBufSize = ASSERT_NO_ERRNO_AND_VALUE(
WaitUntilReceived(replica_.get(), sizeof(kInput1) - 1));
@@ -1192,10 +1193,11 @@ TEST_F(PtyTest, PartialBadBuffer) {
// Leave only one free byte in the buffer.
char* bad_buffer = buf + kPageSize - 1;
- // Write to the main.
+ // Write to the master.
constexpr char kBuf[] = "hello\n";
constexpr size_t size = sizeof(kBuf) - 1;
- EXPECT_THAT(WriteFd(main_.get(), kBuf, size), SyscallSucceedsWithValue(size));
+ EXPECT_THAT(WriteFd(master_.get(), kBuf, size),
+ SyscallSucceedsWithValue(size));
// Read from the replica into bad_buffer.
ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), size));
@@ -1207,14 +1209,14 @@ TEST_F(PtyTest, PartialBadBuffer) {
TEST_F(PtyTest, SimpleEcho) {
constexpr char kInput[] = "Mr. Eko";
- EXPECT_THAT(WriteFd(main_.get(), kInput, strlen(kInput)),
+ EXPECT_THAT(WriteFd(master_.get(), kInput, strlen(kInput)),
SyscallSucceedsWithValue(strlen(kInput)));
char buf[100] = {};
- ExpectReadable(main_, strlen(kInput), buf);
+ ExpectReadable(master_, strlen(kInput), buf);
EXPECT_STREQ(buf, kInput);
- ExpectFinished(main_);
+ ExpectFinished(master_);
}
TEST_F(PtyTest, GetWindowSize) {
@@ -1231,16 +1233,17 @@ TEST_F(PtyTest, SetReplicaWindowSize) {
ASSERT_THAT(ioctl(replica_.get(), TIOCSWINSZ, &ws), SyscallSucceeds());
struct winsize retrieved_ws = {};
- ASSERT_THAT(ioctl(main_.get(), TIOCGWINSZ, &retrieved_ws), SyscallSucceeds());
+ ASSERT_THAT(ioctl(master_.get(), TIOCGWINSZ, &retrieved_ws),
+ SyscallSucceeds());
EXPECT_EQ(retrieved_ws.ws_row, kRows);
EXPECT_EQ(retrieved_ws.ws_col, kCols);
}
-TEST_F(PtyTest, SetMainWindowSize) {
+TEST_F(PtyTest, SetMasterWindowSize) {
constexpr uint16_t kRows = 343;
constexpr uint16_t kCols = 2401;
struct winsize ws = {.ws_row = kRows, .ws_col = kCols};
- ASSERT_THAT(ioctl(main_.get(), TIOCSWINSZ, &ws), SyscallSucceeds());
+ ASSERT_THAT(ioctl(master_.get(), TIOCSWINSZ, &ws), SyscallSucceeds());
struct winsize retrieved_ws = {};
ASSERT_THAT(ioctl(replica_.get(), TIOCGWINSZ, &retrieved_ws),
@@ -1252,8 +1255,8 @@ TEST_F(PtyTest, SetMainWindowSize) {
class JobControlTest : public ::testing::Test {
protected:
void SetUp() override {
- main_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK));
- replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main_));
+ master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK));
+ replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master_));
// Make this a session leader, which also drops the controlling terminal.
// In the gVisor test environment, this test will be run as the session
@@ -1277,15 +1280,15 @@ class JobControlTest : public ::testing::Test {
return PosixError(wstatus, "process returned");
}
- // Main and replica ends of the PTY. Non-blocking.
- FileDescriptor main_;
+ // Master and replica ends of the PTY. Non-blocking.
+ FileDescriptor master_;
FileDescriptor replica_;
};
-TEST_F(JobControlTest, SetTTYMain) {
+TEST_F(JobControlTest, SetTTYMaster) {
auto res = RunInChild([=]() {
TEST_PCHECK(setsid() >= 0);
- TEST_PCHECK(!ioctl(main_.get(), TIOCSCTTY, 0));
+ TEST_PCHECK(!ioctl(master_.get(), TIOCSCTTY, 0));
});
ASSERT_NO_ERRNO(res);
}
@@ -1360,7 +1363,7 @@ TEST_F(JobControlTest, ReleaseWrongTTY) {
auto res = RunInChild([=]() {
TEST_PCHECK(setsid() >= 0);
TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0));
- TEST_PCHECK(ioctl(main_.get(), TIOCNOTTY) < 0 && errno == ENOTTY);
+ TEST_PCHECK(ioctl(master_.get(), TIOCNOTTY) < 0 && errno == ENOTTY);
});
ASSERT_NO_ERRNO(res);
}
diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc
index a534cf0bb..4ac648729 100644
--- a/test/syscalls/linux/pty_root.cc
+++ b/test/syscalls/linux/pty_root.cc
@@ -48,9 +48,9 @@ TEST(JobControlRootTest, StealTTY) {
ASSERT_THAT(setsid(), SyscallSucceeds());
}
- FileDescriptor main =
+ FileDescriptor master =
ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK));
- FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main));
+ FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master));
// Make replica the controlling terminal.
ASSERT_THAT(ioctl(replica.get(), TIOCSCTTY, 0), SyscallSucceeds());
diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc
index 3de898df7..1b9dbc584 100644
--- a/test/syscalls/linux/raw_socket_icmp.cc
+++ b/test/syscalls/linux/raw_socket_icmp.cc
@@ -416,6 +416,28 @@ TEST_F(RawSocketICMPTest, BindConnectSendAndReceive) {
ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp));
}
+// Set and get SO_LINGER.
+TEST_F(RawSocketICMPTest, SetAndGetSocketLinger) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int level = SOL_SOCKET;
+ int type = SO_LINGER;
+
+ struct linger sl;
+ sl.l_onoff = 1;
+ sl.l_linger = 5;
+ ASSERT_THAT(setsockopt(s_, level, type, &sl, sizeof(sl)),
+ SyscallSucceedsWithValue(0));
+
+ struct linger got_linger = {};
+ socklen_t length = sizeof(sl);
+ ASSERT_THAT(getsockopt(s_, level, type, &got_linger, &length),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_EQ(length, sizeof(got_linger));
+ EXPECT_EQ(0, memcmp(&sl, &got_linger, length));
+}
+
void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) {
// We're going to receive both the echo request and reply, but the order is
// indeterminate.
diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc
index 833c0dc4f..5458f54ad 100644
--- a/test/syscalls/linux/rename.cc
+++ b/test/syscalls/linux/rename.cc
@@ -170,6 +170,9 @@ TEST(RenameTest, FileOverwritesFile) {
}
TEST(RenameTest, DirectoryOverwritesDirectoryLinkCount) {
+ // Directory link counts are synthetic on overlay filesystems.
+ SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir())));
+
auto parent1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
EXPECT_THAT(Links(parent1.path()), IsPosixErrorOkAndHolds(2));
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 425084228..11fcec443 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -1158,33 +1158,37 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
SyscallSucceeds());
ASSERT_THAT(close(conn_fd.release()), SyscallSucceeds());
- // TODO(gvisor.dev/issue/3780): Remove this.
if (IsRunningOnGvisor()) {
- // Wait for the RST to be observed.
+ // Gvisor packet procssing is asynchronous and can take a bit of time in
+ // some cases so we give it a bit of time to process the RST packet before
+ // calling accept.
+ //
+ // There is nothing to poll() on so we have no choice but to use a sleep
+ // here.
absl::SleepFor(absl::Milliseconds(100));
}
sockaddr_storage accept_addr;
socklen_t addrlen = sizeof(accept_addr);
- // TODO(gvisor.dev/issue/3780): Remove this.
- if (IsRunningOnGvisor()) {
- ASSERT_THAT(accept(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&accept_addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
- return;
- }
-
- conn_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept(
+ auto accept_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept(
listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen));
ASSERT_EQ(addrlen, listener.addr_len);
- int err;
- socklen_t optlen = sizeof(err);
- ASSERT_THAT(getsockopt(conn_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
- SyscallSucceeds());
- ASSERT_EQ(err, ECONNRESET);
- ASSERT_EQ(optlen, sizeof(err));
+ // TODO(gvisor.dev/issue/3812): Remove after SO_ERROR is fixed.
+ if (IsRunningOnGvisor()) {
+ char buf[10];
+ ASSERT_THAT(ReadFd(accept_fd.get(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(ECONNRESET));
+ } else {
+ int err;
+ socklen_t optlen = sizeof(err);
+ ASSERT_THAT(
+ getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(err, ECONNRESET);
+ ASSERT_EQ(optlen, sizeof(err));
+ }
}
// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index 6e4ecd680..3f2c0fdf2 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -451,7 +451,7 @@ TEST_P(UDPSocketPairTest, TClassRecvMismatch) {
}
// Test the SO_LINGER option can be set/get on udp socket.
-TEST_P(UDPSocketPairTest, SoLingerFail) {
+TEST_P(UDPSocketPairTest, SetAndGetSocketLinger) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
int level = SOL_SOCKET;
int type = SO_LINGER;
@@ -469,15 +469,7 @@ TEST_P(UDPSocketPairTest, SoLingerFail) {
SyscallSucceedsWithValue(0));
ASSERT_EQ(length, sizeof(got_linger));
- // Linux returns the values which are set in the SetSockOpt for SO_LINGER.
- // In gVisor, we do not store the linger values for UDP as SO_LINGER for UDP
- // is a no-op.
- if (IsRunningOnGvisor()) {
- struct linger want_linger = {};
- EXPECT_EQ(0, memcmp(&want_linger, &got_linger, length));
- } else {
- EXPECT_EQ(0, memcmp(&sl, &got_linger, length));
- }
+ EXPECT_EQ(0, memcmp(&sl, &got_linger, length));
}
} // namespace testing
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index 02ea05e22..a72c76c97 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -27,6 +27,7 @@
#include "absl/memory/memory.h"
#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/posix_error.h"
#include "test/util/test_util.h"
namespace gvisor {
@@ -73,9 +74,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallFailsWithErrno(EAGAIN));
+ EXPECT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ PosixErrorIs(EAGAIN, ::testing::_));
}
// Check that not setting a default send interface prevents multicast packets
@@ -207,8 +208,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -262,8 +264,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -317,8 +320,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -372,8 +376,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -431,8 +436,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -490,8 +496,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -545,8 +552,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -600,8 +608,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -659,9 +668,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallFailsWithErrno(EAGAIN));
+ EXPECT_THAT(
+ RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ PosixErrorIs(EAGAIN, ::testing::_));
}
// Check that multicast works when the default send interface is configured by
@@ -717,9 +726,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallFailsWithErrno(EAGAIN));
+ EXPECT_THAT(
+ RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ PosixErrorIs(EAGAIN, ::testing::_));
}
// Check that multicast works when the default send interface is configured by
@@ -775,8 +784,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -834,8 +844,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -907,9 +918,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallFailsWithErrno(EAGAIN));
+ EXPECT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ PosixErrorIs(EAGAIN, ::testing::_));
}
// Check that dropping a group membership prevents multicast packets from being
@@ -965,9 +976,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallFailsWithErrno(EAGAIN));
+ EXPECT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ PosixErrorIs(EAGAIN, ::testing::_));
}
TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfZero) {
@@ -1319,9 +1330,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) {
// Check that we received the multicast packet on both sockets.
for (auto& sockets : socket_pairs) {
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf,
+ sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
}
@@ -1398,9 +1409,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) {
// Check that we received the multicast packet on both sockets.
for (auto& sockets : socket_pairs) {
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf,
+ sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
}
@@ -1421,9 +1432,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) {
char recv_buf[sizeof(send_buf)] = {};
for (auto& sockets : socket_pairs) {
- ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf,
- sizeof(recv_buf), MSG_DONTWAIT),
- SyscallFailsWithErrno(EAGAIN));
+ ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf,
+ sizeof(recv_buf), 1 /*timeout*/),
+ PosixErrorIs(EAGAIN, ::testing::_));
}
}
}
@@ -1474,9 +1485,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1518,9 +1529,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) {
// Check that we don't receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallFailsWithErrno(EAGAIN));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ PosixErrorIs(EAGAIN, ::testing::_));
}
// Check that a socket can bind to a multicast address and still send out
@@ -1568,9 +1579,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) {
// Check that we received the packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1615,9 +1626,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1666,9 +1677,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) {
// Check that we received the packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(
+ RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1726,17 +1737,17 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
// of the other sockets to have received it, but we will check that later.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RetryEINTR(recv)(last->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ RecvMsgTimeout(last->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(send_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
// Verify that no other messages were received.
for (auto& socket : sockets) {
char recv_buf[kMessageSize] = {};
- EXPECT_THAT(RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallFailsWithErrno(EAGAIN));
+ EXPECT_THAT(RecvMsgTimeout(socket->get(), recv_buf, sizeof(recv_buf),
+ 1 /*timeout*/),
+ PosixErrorIs(EAGAIN, ::testing::_));
}
}
@@ -2113,12 +2124,12 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
// balancing (REUSEPORT) instead of the most recently bound socket
// (REUSEADDR).
char recv_buf[kMessageSize] = {};
- EXPECT_THAT(RetryEINTR(recv)(receiver1->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallSucceedsWithValue(kMessageSize));
- EXPECT_THAT(RetryEINTR(recv)(receiver2->get(), recv_buf, sizeof(recv_buf),
- MSG_DONTWAIT),
- SyscallSucceedsWithValue(kMessageSize));
+ EXPECT_THAT(RecvMsgTimeout(receiver1->get(), recv_buf, sizeof(recv_buf),
+ 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kMessageSize));
+ EXPECT_THAT(RecvMsgTimeout(receiver2->get(), recv_buf, sizeof(recv_buf),
+ 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kMessageSize));
}
// Test that socket will receive packet info control message.
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
index 79eb48afa..49a0f06d9 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
@@ -15,10 +15,12 @@
#include "test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h"
#include <arpa/inet.h>
+#include <poll.h>
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_netlink_route_util.h"
#include "test/util/capability_util.h"
+#include "test/util/cleanup.h"
namespace gvisor {
namespace testing {
@@ -33,9 +35,23 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) {
// Add an IP address to the loopback interface.
Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
struct in_addr addr;
- EXPECT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr));
- EXPECT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET,
+ ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr));
+ ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET,
/*prefixlen=*/24, &addr, sizeof(addr)));
+ Cleanup defer_addr_removal = Cleanup(
+ [loopback_link = std::move(loopback_link), addr = std::move(addr)] {
+ if (IsRunningOnGvisor()) {
+ // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses
+ // via netlink is supported.
+ EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr, sizeof(addr)),
+ PosixErrorIs(EOPNOTSUPP, ::testing::_));
+ } else {
+ EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr,
+ sizeof(addr)));
+ }
+ });
auto snd_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto rcv_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -45,10 +61,10 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) {
TestAddress sender_addr("V4NotAssignd1");
sender_addr.addr.ss_family = AF_INET;
sender_addr.addr_len = sizeof(sockaddr_in);
- EXPECT_EQ(1, inet_pton(AF_INET, "192.0.2.2",
+ ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.2",
&(reinterpret_cast<sockaddr_in*>(&sender_addr.addr)
->sin_addr.s_addr)));
- EXPECT_THAT(
+ ASSERT_THAT(
bind(snd_sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
sender_addr.addr_len),
SyscallSucceeds());
@@ -58,10 +74,10 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) {
TestAddress receiver_addr("V4NotAssigned2");
receiver_addr.addr.ss_family = AF_INET;
receiver_addr.addr_len = sizeof(sockaddr_in);
- EXPECT_EQ(1, inet_pton(AF_INET, "192.0.2.254",
+ ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.254",
&(reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)
->sin_addr.s_addr)));
- EXPECT_THAT(
+ ASSERT_THAT(
bind(rcv_sock->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
receiver_addr.addr_len),
SyscallSucceeds());
@@ -70,10 +86,10 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) {
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
- EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ ASSERT_EQ(receiver_addr_len, receiver_addr.addr_len);
char send_buf[kSendBufSize];
RandomizeBuffer(send_buf, kSendBufSize);
- EXPECT_THAT(
+ ASSERT_THAT(
RetryEINTR(sendto)(snd_sock->get(), send_buf, kSendBufSize, 0,
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
receiver_addr.addr_len),
@@ -83,7 +99,126 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) {
char recv_buf[kSendBufSize] = {};
ASSERT_THAT(RetryEINTR(recv)(rcv_sock->get(), recv_buf, kSendBufSize, 0),
SyscallSucceedsWithValue(kSendBufSize));
- EXPECT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize));
+ ASSERT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize));
+}
+
+// Tests that broadcast packets are delivered to all interested sockets
+// (wildcard and broadcast address specified sockets).
+//
+// Note, we cannot test the IPv4 Broadcast (255.255.255.255) because we do
+// not have a route to it.
+TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) {
+ constexpr uint16_t kPort = 9876;
+ // Wait up to 20 seconds for the data.
+ constexpr int kPollTimeoutMs = 20000;
+ // Number of sockets per socket type.
+ constexpr int kNumSocketsPerType = 2;
+
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ // Add an IP address to the loopback interface.
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+ struct in_addr addr;
+ ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr));
+ ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET,
+ 24 /* prefixlen */, &addr, sizeof(addr)));
+ Cleanup defer_addr_removal = Cleanup(
+ [loopback_link = std::move(loopback_link), addr = std::move(addr)] {
+ if (IsRunningOnGvisor()) {
+ // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses
+ // via netlink is supported.
+ EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr, sizeof(addr)),
+ PosixErrorIs(EOPNOTSUPP, ::testing::_));
+ } else {
+ EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr,
+ sizeof(addr)));
+ }
+ });
+
+ TestAddress broadcast_address("SubnetBroadcastAddress");
+ broadcast_address.addr.ss_family = AF_INET;
+ broadcast_address.addr_len = sizeof(sockaddr_in);
+ auto broadcast_address_in =
+ reinterpret_cast<sockaddr_in*>(&broadcast_address.addr);
+ ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.255",
+ &broadcast_address_in->sin_addr.s_addr));
+ broadcast_address_in->sin_port = htons(kPort);
+
+ TestAddress any_address = V4Any();
+ reinterpret_cast<sockaddr_in*>(&any_address.addr)->sin_port = htons(kPort);
+
+ // We create sockets bound to both the wildcard address and the broadcast
+ // address to make sure both of these types of "broadcast interested" sockets
+ // receive broadcast packets.
+ std::vector<std::unique_ptr<FileDescriptor>> socks;
+ for (bool bind_wildcard : {false, true}) {
+ // Create multiple sockets for each type of "broadcast interested"
+ // socket so we can test that all sockets receive the broadcast packet.
+ for (int i = 0; i < kNumSocketsPerType; i++) {
+ auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto idx = socks.size();
+
+ ASSERT_THAT(setsockopt(sock->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0))
+ << "socks[" << idx << "]";
+
+ ASSERT_THAT(setsockopt(sock->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0))
+ << "socks[" << idx << "]";
+
+ if (bind_wildcard) {
+ ASSERT_THAT(
+ bind(sock->get(), reinterpret_cast<sockaddr*>(&any_address.addr),
+ any_address.addr_len),
+ SyscallSucceeds())
+ << "socks[" << idx << "]";
+ } else {
+ ASSERT_THAT(bind(sock->get(),
+ reinterpret_cast<sockaddr*>(&broadcast_address.addr),
+ broadcast_address.addr_len),
+ SyscallSucceeds())
+ << "socks[" << idx << "]";
+ }
+
+ socks.push_back(std::move(sock));
+ }
+ }
+
+ char send_buf[kSendBufSize];
+ RandomizeBuffer(send_buf, kSendBufSize);
+
+ // Broadcasts from each socket should be received by every socket (including
+ // the sending socket).
+ for (int w = 0; w < socks.size(); w++) {
+ auto& w_sock = socks[w];
+ ASSERT_THAT(
+ RetryEINTR(sendto)(w_sock->get(), send_buf, kSendBufSize, 0,
+ reinterpret_cast<sockaddr*>(&broadcast_address.addr),
+ broadcast_address.addr_len),
+ SyscallSucceedsWithValue(kSendBufSize))
+ << "write socks[" << w << "]";
+
+ // Check that we received the packet on all sockets.
+ for (int r = 0; r < socks.size(); r++) {
+ auto& r_sock = socks[r];
+
+ struct pollfd poll_fd = {r_sock->get(), POLLIN, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1))
+ << "write socks[" << w << "] & read socks[" << r << "]";
+
+ char recv_buf[kSendBufSize] = {};
+ EXPECT_THAT(RetryEINTR(recv)(r_sock->get(), recv_buf, kSendBufSize, 0),
+ SyscallSucceedsWithValue(kSendBufSize))
+ << "write socks[" << w << "] & read socks[" << r << "]";
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize))
+ << "write socks[" << w << "] & read socks[" << r << "]";
+ }
+ }
}
} // namespace testing
diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc
index bde1dbb4d..7a0bad4cb 100644
--- a/test/syscalls/linux/socket_netlink_route_util.cc
+++ b/test/syscalls/linux/socket_netlink_route_util.cc
@@ -26,6 +26,62 @@ namespace {
constexpr uint32_t kSeq = 12345;
+// Types of address modifications that may be performed on an interface.
+enum class LinkAddrModification {
+ kAdd,
+ kDelete,
+};
+
+// Populates |hdr| with appripriate values for the modification type.
+PosixError PopulateNlmsghdr(LinkAddrModification modification,
+ struct nlmsghdr* hdr) {
+ switch (modification) {
+ case LinkAddrModification::kAdd:
+ hdr->nlmsg_type = RTM_NEWADDR;
+ hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ return NoError();
+ case LinkAddrModification::kDelete:
+ hdr->nlmsg_type = RTM_DELADDR;
+ hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ return NoError();
+ }
+
+ return PosixError(EINVAL);
+}
+
+// Adds or removes the specified address from the specified interface.
+PosixError LinkModifyLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen,
+ LinkAddrModification modification) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifaddrmsg ifaddr;
+ char attrbuf[512];
+ };
+
+ struct request req = {};
+ PosixError err = PopulateNlmsghdr(modification, &req.hdr);
+ if (!err.ok()) {
+ return err;
+ }
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr));
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifaddr.ifa_index = index;
+ req.ifaddr.ifa_family = family;
+ req.ifaddr.ifa_prefixlen = prefixlen;
+
+ struct rtattr* rta = reinterpret_cast<struct rtattr*>(
+ reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
+ rta->rta_type = IFA_LOCAL;
+ rta->rta_len = RTA_LENGTH(addrlen);
+ req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
+ memcpy(RTA_DATA(rta), addr, addrlen);
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
} // namespace
PosixError DumpLinks(
@@ -84,31 +140,14 @@ PosixErrorOr<Link> LoopbackLink() {
PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
const void* addr, int addrlen) {
- ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
-
- struct request {
- struct nlmsghdr hdr;
- struct ifaddrmsg ifaddr;
- char attrbuf[512];
- };
-
- struct request req = {};
- req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr));
- req.hdr.nlmsg_type = RTM_NEWADDR;
- req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
- req.hdr.nlmsg_seq = kSeq;
- req.ifaddr.ifa_index = index;
- req.ifaddr.ifa_family = family;
- req.ifaddr.ifa_prefixlen = prefixlen;
-
- struct rtattr* rta = reinterpret_cast<struct rtattr*>(
- reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
- rta->rta_type = IFA_LOCAL;
- rta->rta_len = RTA_LENGTH(addrlen);
- req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
- memcpy(RTA_DATA(rta), addr, addrlen);
+ return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen,
+ LinkAddrModification::kAdd);
+}
- return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+PosixError LinkDelLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen) {
+ return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen,
+ LinkAddrModification::kDelete);
}
PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) {
diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h
index 149c4a7f6..e5badca70 100644
--- a/test/syscalls/linux/socket_netlink_route_util.h
+++ b/test/syscalls/linux/socket_netlink_route_util.h
@@ -43,6 +43,10 @@ PosixErrorOr<Link> LoopbackLink();
PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
const void* addr, int addrlen);
+// LinkDelLocalAddr removes IFA_LOCAL attribute on the interface.
+PosixError LinkDelLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen);
+
// LinkChangeFlags changes interface flags. E.g. IFF_UP.
PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change);
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index 53b678e94..e11792309 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -753,6 +753,20 @@ PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) {
return ret;
}
+PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size,
+ int timeout) {
+ fd_set rfd;
+ struct timeval to = {.tv_sec = timeout, .tv_usec = 0};
+ FD_ZERO(&rfd);
+ FD_SET(sock, &rfd);
+
+ int ret;
+ RETURN_ERROR_IF_SYSCALL_FAIL(ret = select(1, &rfd, NULL, NULL, &to));
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ ret = RetryEINTR(recv)(sock, buf, buf_size, MSG_DONTWAIT));
+ return ret;
+}
+
void RecvNoData(int sock) {
char data = 0;
struct iovec iov;
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index 734b48b96..468bc96e0 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -467,6 +467,10 @@ PosixError FreeAvailablePort(int port);
// SendMsg converts a buffer to an iovec and adds it to msg before sending it.
PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size);
+// RecvMsgTimeout calls select on sock with timeout and then calls recv on sock.
+PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size,
+ int timeout);
+
// RecvNoData checks that no data is receivable on sock.
void RecvNoData(int sock);
diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc
index 99e77b89e..1edcb15a7 100644
--- a/test/syscalls/linux/socket_unix_stream.cc
+++ b/test/syscalls/linux/socket_unix_stream.cc
@@ -103,6 +103,24 @@ TEST_P(StreamUnixSocketPairTest, Sendto) {
SyscallFailsWithErrno(EISCONN));
}
+TEST_P(StreamUnixSocketPairTest, SetAndGetSocketLinger) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct linger sl = {1, 5};
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)),
+ SyscallSucceedsWithValue(0));
+
+ struct linger got_linger = {};
+ socklen_t length = sizeof(sl);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER,
+ &got_linger, &length),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_EQ(length, sizeof(got_linger));
+ EXPECT_EQ(0, memcmp(&got_linger, &sl, length));
+}
+
INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, StreamUnixSocketPairTest,
::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>(
diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc
index 2503960f3..92260b1e1 100644
--- a/test/syscalls/linux/stat.cc
+++ b/test/syscalls/linux/stat.cc
@@ -97,6 +97,11 @@ TEST_F(StatTest, FstatatSymlink) {
}
TEST_F(StatTest, Nlinks) {
+ // Skip this test if we are testing overlayfs because overlayfs does not
+ // (intentionally) return the correct nlink value for directories.
+ // See fs/overlayfs/inode.c:ovl_getattr().
+ SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir())));
+
TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
// Directory is initially empty, it should contain 2 links (one from itself,
@@ -328,20 +333,23 @@ TEST_F(StatTest, LeadingDoubleSlash) {
// Test that a rename doesn't change the underlying file.
TEST_F(StatTest, StatDoesntChangeAfterRename) {
- const TempPath old_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath old_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
const TempPath new_path(NewTempAbsPath());
struct stat st_old = {};
struct stat st_new = {};
- ASSERT_THAT(stat(old_dir.path().c_str(), &st_old), SyscallSucceeds());
- ASSERT_THAT(rename(old_dir.path().c_str(), new_path.path().c_str()),
+ ASSERT_THAT(stat(old_file.path().c_str(), &st_old), SyscallSucceeds());
+ ASSERT_THAT(rename(old_file.path().c_str(), new_path.path().c_str()),
SyscallSucceeds());
ASSERT_THAT(stat(new_path.path().c_str(), &st_new), SyscallSucceeds());
EXPECT_EQ(st_old.st_nlink, st_new.st_nlink);
EXPECT_EQ(st_old.st_dev, st_new.st_dev);
- EXPECT_EQ(st_old.st_ino, st_new.st_ino);
+ // Overlay filesystems may synthesize directory inode numbers on the fly.
+ if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) {
+ EXPECT_EQ(st_old.st_ino, st_new.st_ino);
+ }
EXPECT_EQ(st_old.st_mode, st_new.st_mode);
EXPECT_EQ(st_old.st_uid, st_new.st_uid);
EXPECT_EQ(st_old.st_gid, st_new.st_gid);
@@ -378,7 +386,9 @@ TEST_F(StatTest, LinkCountsWithRegularFileChild) {
// This test verifies that inodes remain around when there is an open fd
// after link count hits 0.
-TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) {
+//
+// It is marked NoSave because we don't support saving unlinked files.
+TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoSave) {
// Setting the enviornment variable GVISOR_GOFER_UNCACHED to any value
// will prevent this test from running, see the tmpfs lifecycle.
//
@@ -387,9 +397,6 @@ TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) {
const char* uncached_gofer = getenv("GVISOR_GOFER_UNCACHED");
SKIP_IF(uncached_gofer != nullptr);
- // We don't support saving unlinked files.
- const DisableSave ds;
-
const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const TempPath child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
dir.path(), "hello", TempPath::kDefaultFileMode));
@@ -432,6 +439,11 @@ TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) {
// Test link counts with a directory as the child.
TEST_F(StatTest, LinkCountsWithDirChild) {
+ // Skip this test if we are testing overlayfs because overlayfs does not
+ // (intentionally) return the correct nlink value for directories.
+ // See fs/overlayfs/inode.c:ovl_getattr().
+ SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir())));
+
const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
// Before a child is added the two links are "." and the link from the parent.
diff --git a/test/syscalls/linux/statfs.cc b/test/syscalls/linux/statfs.cc
index 49f2f156c..f0fb166bd 100644
--- a/test/syscalls/linux/statfs.cc
+++ b/test/syscalls/linux/statfs.cc
@@ -27,10 +27,6 @@ namespace testing {
namespace {
-// From linux/magic.h. For some reason, not defined in the headers for some
-// build environments.
-#define OVERLAYFS_SUPER_MAGIC 0x794c7630
-
TEST(StatfsTest, CannotStatBadPath) {
auto temp_file = NewTempAbsPathInDir("/tmp");
@@ -43,9 +39,6 @@ TEST(StatfsTest, InternalTmpfs) {
struct statfs st;
EXPECT_THAT(statfs(temp_file.path().c_str(), &st), SyscallSucceeds());
- // Note: We could be an overlay or goferfs on some configurations.
- EXPECT_TRUE(st.f_type == TMPFS_MAGIC || st.f_type == OVERLAYFS_SUPER_MAGIC ||
- st.f_type == V9FS_MAGIC);
}
TEST(StatfsTest, InternalDevShm) {
diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc
index f4f6ff490..4d9eba7f0 100644
--- a/test/syscalls/linux/symlink.cc
+++ b/test/syscalls/linux/symlink.cc
@@ -327,6 +327,16 @@ TEST(SymlinkTest, FollowUpdatesATime) {
EXPECT_LT(st_before_follow.st_atime, st_after_follow.st_atime);
}
+TEST(SymlinkTest, SymlinkAtEmptyPath) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY, 0666));
+ EXPECT_THAT(symlinkat(file.path().c_str(), fd.get(), ""),
+ SyscallFailsWithErrno(ENOENT));
+}
+
class ParamSymlinkTest : public ::testing::TestWithParam<std::string> {};
// Test that creating an existing symlink with creat will create the target.
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index ab731db1d..e0981e28a 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -1643,6 +1643,36 @@ TEST_P(SimpleTcpSocketTest, GetSocketDetachFilter) {
SyscallFailsWithErrno(ENOPROTOOPT));
}
+TEST_P(SimpleTcpSocketTest, CloseNonConnectedLingerOption) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ constexpr int kLingerTimeout = 10; // Seconds.
+
+ // Set the SO_LINGER option.
+ struct linger sl = {
+ .l_onoff = 1,
+ .l_linger = kLingerTimeout,
+ };
+ ASSERT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)),
+ SyscallSucceeds());
+
+ struct pollfd poll_fd = {
+ .fd = s.get(),
+ .events = POLLHUP,
+ };
+ constexpr int kPollTimeoutMs = 0;
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ auto const start_time = absl::Now();
+ EXPECT_THAT(close(s.release()), SyscallSucceeds());
+ auto const end_time = absl::Now();
+
+ // Close() should not linger and return immediately.
+ ASSERT_LT((end_time - start_time), absl::Seconds(kLingerTimeout));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));
diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc
index c988c6380..bfc95ed38 100644
--- a/test/syscalls/linux/truncate.cc
+++ b/test/syscalls/linux/truncate.cc
@@ -196,6 +196,26 @@ TEST(TruncateTest, FtruncateNonWriteable) {
EXPECT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL));
}
+// ftruncate(2) should succeed as long as the file descriptor is writeable,
+// regardless of whether the file permissions allow writing.
+TEST(TruncateTest, FtruncateWithoutWritePermission_NoRandomSave) {
+ // Drop capabilities that allow us to override file permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ // The only time we can open a file with flags forbidden by its permissions
+ // is when we are creating the file. We cannot re-open with the same flags,
+ // so we cannot restore an fd obtained from such an operation.
+ const DisableSave ds;
+ auto path = NewTempAbsPath();
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR | O_CREAT, 0444));
+
+ // In goferfs, ftruncate may be converted to a remote truncate operation that
+ // unavoidably requires write permission.
+ SKIP_IF(IsRunningOnGvisor() && !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(path)));
+ ASSERT_THAT(ftruncate(fd.get(), 100), SyscallSucceeds());
+}
+
TEST(TruncateTest, TruncateNonExist) {
EXPECT_THAT(truncate("/foo/bar", 0), SyscallFailsWithErrno(ENOENT));
}
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index 97db2b321..1a7673317 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -14,6 +14,9 @@
#include <arpa/inet.h>
#include <fcntl.h>
+
+#include <ctime>
+
#ifdef __linux__
#include <linux/errqueue.h>
#include <linux/filter.h>
@@ -834,8 +837,9 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) {
// Receive the data. It works because it was sent before the connect.
char received[sizeof(buf)];
- EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0),
- SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_THAT(
+ RecvMsgTimeout(bind_.get(), received, sizeof(received), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(sizeof(received)));
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
// Send again. This time it should not be received.
@@ -924,7 +928,9 @@ TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) {
SyscallSucceedsWithValue(1));
// We should get the data even though read has been shutdown.
- EXPECT_THAT(recv(bind_.get(), received, 2, 0), SyscallSucceedsWithValue(2));
+ EXPECT_THAT(
+ RecvMsgTimeout(bind_.get(), received, 2 /*buf_size*/, 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(2));
// Because we read less than the entire packet length, since it's a packet
// based socket any subsequent reads should return EWOULDBLOCK.
@@ -1692,9 +1698,9 @@ TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) {
sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
SyscallSucceedsWithValue(buf.size()));
std::vector<char> received(buf.size());
- EXPECT_THAT(
- recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT),
- SyscallSucceedsWithValue(received.size()));
+ EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(),
+ 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(received.size()));
}
{
@@ -1708,9 +1714,9 @@ TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) {
SyscallSucceedsWithValue(buf.size()));
std::vector<char> received(buf.size());
- EXPECT_THAT(
- recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT),
- SyscallSucceedsWithValue(received.size()));
+ ASSERT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(),
+ 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(received.size()));
}
}
@@ -1779,9 +1785,9 @@ TEST_P(UdpSocketTest, RecvBufLimits) {
for (int i = 0; i < sent - 1; i++) {
// Receive the data.
std::vector<char> received(buf.size());
- EXPECT_THAT(
- recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT),
- SyscallSucceedsWithValue(received.size()));
+ EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(),
+ 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(received.size()));
EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0);
}
diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc
index 2040375c9..061e2e0f1 100644
--- a/test/syscalls/linux/unlink.cc
+++ b/test/syscalls/linux/unlink.cc
@@ -208,6 +208,20 @@ TEST(RmdirTest, CanRemoveWithTrailingSlashes) {
ASSERT_THAT(rmdir(slashslash.c_str()), SyscallSucceeds());
}
+TEST(UnlinkTest, UnlinkAtEmptyPath) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666));
+ EXPECT_THAT(unlinkat(fd.get(), "", 0), SyscallFailsWithErrno(ENOENT));
+
+ auto dirInDir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path()));
+ auto dirFD = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(dirInDir.path(), O_RDONLY | O_DIRECTORY, 0666));
+ EXPECT_THAT(unlinkat(dirFD.get(), "", AT_REMOVEDIR),
+ SyscallFailsWithErrno(ENOENT));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/vdso_clock_gettime.cc b/test/syscalls/linux/vdso_clock_gettime.cc
index ce1899f45..2a8699a7b 100644
--- a/test/syscalls/linux/vdso_clock_gettime.cc
+++ b/test/syscalls/linux/vdso_clock_gettime.cc
@@ -38,8 +38,6 @@ std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) {
switch (info.param) {
case CLOCK_MONOTONIC:
return "CLOCK_MONOTONIC";
- case CLOCK_REALTIME:
- return "CLOCK_REALTIME";
case CLOCK_BOOTTIME:
return "CLOCK_BOOTTIME";
default:
@@ -47,59 +45,36 @@ std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) {
}
}
-class CorrectVDSOClockTest : public ::testing::TestWithParam<clockid_t> {};
+class MonotonicVDSOClockTest : public ::testing::TestWithParam<clockid_t> {};
-TEST_P(CorrectVDSOClockTest, IsCorrect) {
+TEST_P(MonotonicVDSOClockTest, IsCorrect) {
+ // The VDSO implementation of clock_gettime() uses the TSC. On KVM, sentry and
+ // application TSCs can be very desynchronized; see
+ // sentry/platform/kvm/kvm.vCPU.setSystemTime().
+ SKIP_IF(GvisorPlatform() == Platform::kKVM);
+
+ // Check that when we alternate readings from the clock_gettime syscall and
+ // the VDSO's implementation, we observe the combined sequence as being
+ // monotonic.
struct timespec tvdso, tsys;
absl::Time vdso_time, sys_time;
- uint64_t total_calls = 0;
-
- // It is expected that 82.5% of clock_gettime calls will be less than 100us
- // skewed from the system time.
- // Unfortunately this is not only influenced by the VDSO clock skew, but also
- // by arbitrary scheduling delays and the like. The test is therefore
- // regularly disabled.
- std::map<absl::Duration, std::tuple<double, uint64_t, uint64_t>> confidence =
- {
- {absl::Microseconds(100), std::make_tuple(0.825, 0, 0)},
- {absl::Microseconds(250), std::make_tuple(0.94, 0, 0)},
- {absl::Milliseconds(1), std::make_tuple(0.999, 0, 0)},
- };
-
- absl::Time start = absl::Now();
- while (absl::Now() < start + absl::Seconds(30)) {
- EXPECT_THAT(clock_gettime(GetParam(), &tvdso), SyscallSucceeds());
- EXPECT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys),
- SyscallSucceeds());
-
+ ASSERT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys),
+ SyscallSucceeds());
+ sys_time = absl::TimeFromTimespec(tsys);
+ auto end = absl::Now() + absl::Seconds(10);
+ while (absl::Now() < end) {
+ ASSERT_THAT(clock_gettime(GetParam(), &tvdso), SyscallSucceeds());
vdso_time = absl::TimeFromTimespec(tvdso);
-
- for (auto const& conf : confidence) {
- std::get<1>(confidence[conf.first]) +=
- (sys_time - vdso_time) < conf.first;
- }
-
+ EXPECT_LE(sys_time, vdso_time);
+ ASSERT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys),
+ SyscallSucceeds());
sys_time = absl::TimeFromTimespec(tsys);
-
- for (auto const& conf : confidence) {
- std::get<2>(confidence[conf.first]) +=
- (vdso_time - sys_time) < conf.first;
- }
-
- ++total_calls;
- }
-
- for (auto const& conf : confidence) {
- EXPECT_GE(std::get<1>(conf.second) / static_cast<double>(total_calls),
- std::get<0>(conf.second));
- EXPECT_GE(std::get<2>(conf.second) / static_cast<double>(total_calls),
- std::get<0>(conf.second));
+ EXPECT_LE(vdso_time, sys_time);
}
}
-INSTANTIATE_TEST_SUITE_P(ClockGettime, CorrectVDSOClockTest,
- ::testing::Values(CLOCK_MONOTONIC, CLOCK_REALTIME,
- CLOCK_BOOTTIME),
+INSTANTIATE_TEST_SUITE_P(ClockGettime, MonotonicVDSOClockTest,
+ ::testing::Values(CLOCK_MONOTONIC, CLOCK_BOOTTIME),
PrintClockId);
} // namespace
diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc
index 5510a87a0..bd3f829c4 100644
--- a/test/syscalls/linux/xattr.cc
+++ b/test/syscalls/linux/xattr.cc
@@ -232,7 +232,7 @@ TEST_F(XattrTest, XattrOnInvalidFileTypes) {
EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM));
}
-TEST_F(XattrTest, SetxattrSizeSmallerThanValue) {
+TEST_F(XattrTest, SetXattrSizeSmallerThanValue) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
@@ -247,7 +247,7 @@ TEST_F(XattrTest, SetxattrSizeSmallerThanValue) {
EXPECT_EQ(buf, expected_buf);
}
-TEST_F(XattrTest, SetxattrZeroSize) {
+TEST_F(XattrTest, SetXattrZeroSize) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
char val = 'a';
@@ -259,7 +259,7 @@ TEST_F(XattrTest, SetxattrZeroSize) {
EXPECT_EQ(buf, '-');
}
-TEST_F(XattrTest, SetxattrSizeTooLarge) {
+TEST_F(XattrTest, SetXattrSizeTooLarge) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
@@ -274,7 +274,7 @@ TEST_F(XattrTest, SetxattrSizeTooLarge) {
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
}
-TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) {
+TEST_F(XattrTest, SetXattrNullValueAndNonzeroSize) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 1, /*flags=*/0),
@@ -283,7 +283,7 @@ TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) {
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
}
-TEST_F(XattrTest, SetxattrNullValueAndZeroSize) {
+TEST_F(XattrTest, SetXattrNullValueAndZeroSize) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
@@ -291,7 +291,7 @@ TEST_F(XattrTest, SetxattrNullValueAndZeroSize) {
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0));
}
-TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) {
+TEST_F(XattrTest, SetXattrValueTooLargeButOKSize) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
std::vector<char> val(XATTR_SIZE_MAX + 1);
@@ -307,7 +307,7 @@ TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) {
EXPECT_EQ(buf, expected_buf);
}
-TEST_F(XattrTest, SetxattrReplaceWithSmaller) {
+TEST_F(XattrTest, SetXattrReplaceWithSmaller) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
@@ -322,7 +322,7 @@ TEST_F(XattrTest, SetxattrReplaceWithSmaller) {
EXPECT_EQ(buf, expected_buf);
}
-TEST_F(XattrTest, SetxattrReplaceWithLarger) {
+TEST_F(XattrTest, SetXattrReplaceWithLarger) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
@@ -336,7 +336,7 @@ TEST_F(XattrTest, SetxattrReplaceWithLarger) {
EXPECT_EQ(buf, val);
}
-TEST_F(XattrTest, SetxattrCreateFlag) {
+TEST_F(XattrTest, SetXattrCreateFlag) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE),
@@ -347,7 +347,7 @@ TEST_F(XattrTest, SetxattrCreateFlag) {
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0));
}
-TEST_F(XattrTest, SetxattrReplaceFlag) {
+TEST_F(XattrTest, SetXattrReplaceFlag) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE),
@@ -359,14 +359,14 @@ TEST_F(XattrTest, SetxattrReplaceFlag) {
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0));
}
-TEST_F(XattrTest, SetxattrInvalidFlags) {
+TEST_F(XattrTest, SetXattrInvalidFlags) {
const char* path = test_file_name_.c_str();
int invalid_flags = 0xff;
EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, invalid_flags),
SyscallFailsWithErrno(EINVAL));
}
-TEST_F(XattrTest, Getxattr) {
+TEST_F(XattrTest, GetXattr) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
int val = 1234;
@@ -378,7 +378,7 @@ TEST_F(XattrTest, Getxattr) {
EXPECT_EQ(buf, val);
}
-TEST_F(XattrTest, GetxattrSizeSmallerThanValue) {
+TEST_F(XattrTest, GetXattrSizeSmallerThanValue) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
@@ -390,7 +390,7 @@ TEST_F(XattrTest, GetxattrSizeSmallerThanValue) {
EXPECT_EQ(buf, '-');
}
-TEST_F(XattrTest, GetxattrSizeLargerThanValue) {
+TEST_F(XattrTest, GetXattrSizeLargerThanValue) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
char val = 'a';
@@ -405,7 +405,7 @@ TEST_F(XattrTest, GetxattrSizeLargerThanValue) {
EXPECT_EQ(buf, expected_buf);
}
-TEST_F(XattrTest, GetxattrZeroSize) {
+TEST_F(XattrTest, GetXattrZeroSize) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
char val = 'a';
@@ -418,7 +418,7 @@ TEST_F(XattrTest, GetxattrZeroSize) {
EXPECT_EQ(buf, '-');
}
-TEST_F(XattrTest, GetxattrSizeTooLarge) {
+TEST_F(XattrTest, GetXattrSizeTooLarge) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
char val = 'a';
@@ -434,7 +434,7 @@ TEST_F(XattrTest, GetxattrSizeTooLarge) {
EXPECT_EQ(buf, expected_buf);
}
-TEST_F(XattrTest, GetxattrNullValue) {
+TEST_F(XattrTest, GetXattrNullValue) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
char val = 'a';
@@ -445,7 +445,7 @@ TEST_F(XattrTest, GetxattrNullValue) {
SyscallFailsWithErrno(EFAULT));
}
-TEST_F(XattrTest, GetxattrNullValueAndZeroSize) {
+TEST_F(XattrTest, GetXattrNullValueAndZeroSize) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
char val = 'a';
@@ -461,13 +461,13 @@ TEST_F(XattrTest, GetxattrNullValueAndZeroSize) {
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(size));
}
-TEST_F(XattrTest, GetxattrNonexistentName) {
+TEST_F(XattrTest, GetXattrNonexistentName) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
}
-TEST_F(XattrTest, Listxattr) {
+TEST_F(XattrTest, ListXattr) {
const char* path = test_file_name_.c_str();
const std::string name = "user.test";
const std::string name2 = "user.test2";
@@ -493,7 +493,7 @@ TEST_F(XattrTest, Listxattr) {
EXPECT_EQ(got, expected);
}
-TEST_F(XattrTest, ListxattrNoXattrs) {
+TEST_F(XattrTest, ListXattrNoXattrs) {
const char* path = test_file_name_.c_str();
std::vector<char> list, expected;
@@ -501,13 +501,13 @@ TEST_F(XattrTest, ListxattrNoXattrs) {
SyscallSucceedsWithValue(0));
EXPECT_EQ(list, expected);
- // Listxattr should succeed if there are no attributes, even if the buffer
+ // ListXattr should succeed if there are no attributes, even if the buffer
// passed in is a nullptr.
EXPECT_THAT(listxattr(path, nullptr, sizeof(list)),
SyscallSucceedsWithValue(0));
}
-TEST_F(XattrTest, ListxattrNullBuffer) {
+TEST_F(XattrTest, ListXattrNullBuffer) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
@@ -516,7 +516,7 @@ TEST_F(XattrTest, ListxattrNullBuffer) {
SyscallFailsWithErrno(EFAULT));
}
-TEST_F(XattrTest, ListxattrSizeTooSmall) {
+TEST_F(XattrTest, ListXattrSizeTooSmall) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
@@ -526,7 +526,7 @@ TEST_F(XattrTest, ListxattrSizeTooSmall) {
SyscallFailsWithErrno(ERANGE));
}
-TEST_F(XattrTest, ListxattrZeroSize) {
+TEST_F(XattrTest, ListXattrZeroSize) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
@@ -615,12 +615,18 @@ TEST_F(XattrTest, TrustedNamespaceWithCapSysAdmin) {
SKIP_IF(IsRunningOnGvisor() &&
!ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(test_file_name_)));
- // Setting/Getting in the trusted namespace requires CAP_SYS_ADMIN.
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
-
const char* path = test_file_name_.c_str();
const char name[] = "trusted.test";
+ // Writing to the trusted.* xattr namespace requires CAP_SYS_ADMIN in the root
+ // user namespace. There's no easy way to check that, other than trying the
+ // operation and seeing what happens. We'll call removexattr because it's
+ // simplest.
+ if (removexattr(path, name) < 0) {
+ SKIP_IF(errno == EPERM);
+ FAIL() << "unexpected errno from removexattr: " << errno;
+ }
+
// Set.
char val = 'a';
size_t size = sizeof(val);
diff --git a/test/util/BUILD b/test/util/BUILD
index 2a17c33ee..fc5fb3a8d 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -46,6 +46,13 @@ cc_library(
)
cc_library(
+ name = "fuse_util",
+ testonly = 1,
+ srcs = ["fuse_util.cc"],
+ hdrs = ["fuse_util.h"],
+)
+
+cc_library(
name = "proc_util",
testonly = 1,
srcs = ["proc_util.cc"],
diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc
index 572675622..b16055dd8 100644
--- a/test/util/fs_util.cc
+++ b/test/util/fs_util.cc
@@ -649,5 +649,19 @@ PosixErrorOr<bool> IsTmpfs(const std::string& path) {
}
#endif // __linux__
+PosixErrorOr<bool> IsOverlayfs(const std::string& path) {
+ struct statfs stat;
+ if (statfs(path.c_str(), &stat)) {
+ if (errno == ENOENT) {
+ // Nothing at path, don't raise this as an error. Instead, just report no
+ // overlayfs at path.
+ return false;
+ }
+ return PosixError(errno,
+ absl::StrFormat("statfs(\"%s\", %#p)", path, &stat));
+ }
+ return stat.f_type == OVERLAYFS_SUPER_MAGIC;
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/util/fs_util.h b/test/util/fs_util.h
index 314637de0..c99cf5eb7 100644
--- a/test/util/fs_util.h
+++ b/test/util/fs_util.h
@@ -38,6 +38,10 @@ constexpr int kOLargeFile = 00400000;
#error "Unknown architecture"
#endif
+// From linux/magic.h. For some reason, not defined in the headers for some
+// build environments.
+#define OVERLAYFS_SUPER_MAGIC 0x794c7630
+
// Returns a status or the current working directory.
PosixErrorOr<std::string> GetCWD();
@@ -184,6 +188,9 @@ PosixErrorOr<std::string> ProcessExePath(int pid);
PosixErrorOr<bool> IsTmpfs(const std::string& path);
#endif // __linux__
+// IsOverlayfs returns true if the file at path is backed by overlayfs.
+PosixErrorOr<bool> IsOverlayfs(const std::string& path);
+
namespace internal {
// Not part of the public API.
std::string JoinPathImpl(std::initializer_list<absl::string_view> paths);
diff --git a/test/util/fuse_util.cc b/test/util/fuse_util.cc
new file mode 100644
index 000000000..027f8386c
--- /dev/null
+++ b/test/util/fuse_util.cc
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/fuse_util.h"
+
+#include <sys/stat.h>
+#include <sys/types.h>
+
+#include <string>
+
+namespace gvisor {
+namespace testing {
+
+// Create a default FuseAttr struct with specified mode, inode, and size.
+fuse_attr DefaultFuseAttr(mode_t mode, uint64_t inode, uint64_t size) {
+ const int time_sec = 1595436289;
+ const int time_nsec = 134150844;
+ return (struct fuse_attr){
+ .ino = inode,
+ .size = size,
+ .blocks = 4,
+ .atime = time_sec,
+ .mtime = time_sec,
+ .ctime = time_sec,
+ .atimensec = time_nsec,
+ .mtimensec = time_nsec,
+ .ctimensec = time_nsec,
+ .mode = mode,
+ .nlink = 2,
+ .uid = 1234,
+ .gid = 4321,
+ .rdev = 12,
+ .blksize = 4096,
+ };
+}
+
+// Create response body with specified mode, nodeID, and size.
+fuse_entry_out DefaultEntryOut(mode_t mode, uint64_t node_id, uint64_t size) {
+ struct fuse_entry_out default_entry_out = {
+ .nodeid = node_id,
+ .generation = 0,
+ .entry_valid = 0,
+ .attr_valid = 0,
+ .entry_valid_nsec = 0,
+ .attr_valid_nsec = 0,
+ .attr = DefaultFuseAttr(mode, node_id, size),
+ };
+ return default_entry_out;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/fuse_util.h b/test/util/fuse_util.h
new file mode 100644
index 000000000..544fe1b38
--- /dev/null
+++ b/test/util/fuse_util.h
@@ -0,0 +1,75 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_FUSE_UTIL_H_
+#define GVISOR_TEST_UTIL_FUSE_UTIL_H_
+
+#include <linux/fuse.h>
+#include <sys/uio.h>
+
+#include <string>
+#include <vector>
+
+namespace gvisor {
+namespace testing {
+
+// The fundamental generation function with a single argument. If passed by
+// std::string or std::vector<char>, it will call specialized versions as
+// implemented below.
+template <typename T>
+std::vector<struct iovec> FuseGenerateIovecs(T &first) {
+ return {(struct iovec){.iov_base = &first, .iov_len = sizeof(first)}};
+}
+
+// If an argument is of type std::string, it must be used in read-only scenario.
+// Because we are setting up iovec, which contains the original address of a
+// data structure, we have to drop const qualification. Usually used with
+// variable-length payload data.
+template <typename T = std::string>
+std::vector<struct iovec> FuseGenerateIovecs(std::string &first) {
+ // Pad one byte for null-terminate c-string.
+ return {(struct iovec){.iov_base = const_cast<char *>(first.c_str()),
+ .iov_len = first.size() + 1}};
+}
+
+// If an argument is of type std::vector<char>, it must be used in write-only
+// scenario and the size of the variable must be greater than or equal to the
+// size of the expected data. Usually used with variable-length payload data.
+template <typename T = std::vector<char>>
+std::vector<struct iovec> FuseGenerateIovecs(std::vector<char> &first) {
+ return {(struct iovec){.iov_base = first.data(), .iov_len = first.size()}};
+}
+
+// A helper function to set up an array of iovec struct for testing purpose.
+// Use variadic class template to generalize different numbers and different
+// types of FUSE structs.
+template <typename T, typename... Types>
+std::vector<struct iovec> FuseGenerateIovecs(T &first, Types &...args) {
+ auto first_iovec = FuseGenerateIovecs(first);
+ auto iovecs = FuseGenerateIovecs(args...);
+ first_iovec.insert(std::end(first_iovec), std::begin(iovecs),
+ std::end(iovecs));
+ return first_iovec;
+}
+
+// Create a fuse_attr filled with the specified mode and inode.
+fuse_attr DefaultFuseAttr(mode_t mode, uint64_t inode, uint64_t size = 512);
+
+// Return a fuse_entry_out FUSE server response body.
+fuse_entry_out DefaultEntryOut(mode_t mode, uint64_t node_id,
+ uint64_t size = 512);
+
+} // namespace testing
+} // namespace gvisor
+#endif // GVISOR_TEST_UTIL_FUSE_UTIL_H_
diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc
index 5fa622922..2cf0bea74 100644
--- a/test/util/pty_util.cc
+++ b/test/util/pty_util.cc
@@ -23,25 +23,25 @@
namespace gvisor {
namespace testing {
-PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& main) {
- PosixErrorOr<int> n = ReplicaID(main);
+PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& master) {
+ PosixErrorOr<int> n = ReplicaID(master);
if (!n.ok()) {
return PosixErrorOr<FileDescriptor>(n.error());
}
return Open(absl::StrCat("/dev/pts/", n.ValueOrDie()), O_RDWR | O_NONBLOCK);
}
-PosixErrorOr<int> ReplicaID(const FileDescriptor& main) {
+PosixErrorOr<int> ReplicaID(const FileDescriptor& master) {
// Get pty index.
int n;
- int ret = ioctl(main.get(), TIOCGPTN, &n);
+ int ret = ioctl(master.get(), TIOCGPTN, &n);
if (ret < 0) {
return PosixError(errno, "ioctl(TIOCGPTN) failed");
}
// Unlock pts.
int unlock = 0;
- ret = ioctl(main.get(), TIOCSPTLCK, &unlock);
+ ret = ioctl(master.get(), TIOCSPTLCK, &unlock);
if (ret < 0) {
return PosixError(errno, "ioctl(TIOSPTLCK) failed");
}
diff --git a/test/util/pty_util.h b/test/util/pty_util.h
index dff6adab5..ed7658868 100644
--- a/test/util/pty_util.h
+++ b/test/util/pty_util.h
@@ -21,11 +21,11 @@
namespace gvisor {
namespace testing {
-// Opens the replica end of the passed main as R/W and nonblocking.
-PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& main);
+// Opens the replica end of the passed master as R/W and nonblocking.
+PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& master);
-// Get the number of the replica end of the main.
-PosixErrorOr<int> ReplicaID(const FileDescriptor& main);
+// Get the number of the replica end of the master.
+PosixErrorOr<int> ReplicaID(const FileDescriptor& master);
} // namespace testing
} // namespace gvisor
diff --git a/tools/bazel.mk b/tools/bazel.mk
index 3e27af7d1..5e129b2ed 100644
--- a/tools/bazel.mk
+++ b/tools/bazel.mk
@@ -19,6 +19,7 @@ SHELL=/bin/bash -o pipefail
BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \
git rev-parse --abbrev-ref HEAD 2>/dev/null) | \
xargs -n 1 basename 2>/dev/null)
+BUILD_ROOT := $(CURDIR)/bazel-bin/
# Bazel container configuration (see below).
USER ?= gvisor
@@ -31,6 +32,7 @@ DOCKER_PRIVILEGED ?= --privileged
BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/)
GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/)
DOCKER_SOCKET := /var/run/docker.sock
+DOCKER_CONFIG := /etc/docker/daemon.json
# Bazel flags.
BAZEL := bazel $(STARTUP_OPTIONS)
@@ -56,6 +58,9 @@ endif
# Add docker passthrough options.
ifneq ($(DOCKER_PRIVILEGED),)
FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)"
+# TODO(gvisor.dev/issue/1624): Remove docker config volume. This is required
+# temporarily for checking VFS1 vs VFS2 by some tests.
+FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_CONFIG):$(DOCKER_CONFIG)"
FULL_DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED)
FULL_DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED)
DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET))
@@ -127,7 +132,7 @@ bazel-server-start: bazel-image ## Starts the bazel server.
--workdir "$(CURDIR)" \
$(FULL_DOCKER_RUN_OPTIONS) \
$(BUILDER_IMAGE) \
- sh -c "tail -f --pid=\$$($(BAZEL) info server_pid)"
+ sh -c "tail -f --pid=\$$($(BAZEL) info server_pid) /dev/null"
.PHONY: bazel-server-start
bazel-shutdown: ## Shuts down a running bazel server.
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 290d564f2..60c9c9d8c 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -214,7 +214,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
for (suffix, _) in marshal_sets.items():
_go_test(
name = name + suffix + "_abi_autogen_test",
- srcs = [name + suffix + "_abi_autogen_test.go"],
+ srcs = [
+ name + suffix + "_abi_autogen_test.go",
+ name + suffix + "_abi_autogen_unconditional_test.go",
+ ],
library = ":" + name,
deps = marshal_test_deps,
**kwargs
diff --git a/tools/github/BUILD b/tools/github/BUILD
new file mode 100644
index 000000000..aad088d13
--- /dev/null
+++ b/tools/github/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "github",
+ srcs = ["main.go"],
+ nogo = False,
+ deps = [
+ "//tools/github/nogo",
+ "//tools/github/reviver",
+ "@com_github_google_go_github_v28//github:go_default_library",
+ "@org_golang_x_oauth2//:go_default_library",
+ ],
+)
diff --git a/tools/github/main.go b/tools/github/main.go
new file mode 100644
index 000000000..7a74dc033
--- /dev/null
+++ b/tools/github/main.go
@@ -0,0 +1,162 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary github is the entry point for GitHub utilities.
+package main
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "strings"
+
+ "github.com/google/go-github/github"
+ "golang.org/x/oauth2"
+ "gvisor.dev/gvisor/tools/github/nogo"
+ "gvisor.dev/gvisor/tools/github/reviver"
+)
+
+var (
+ owner string
+ repo string
+ tokenFile string
+ path string
+ commit string
+ dryRun bool
+)
+
+// Keep the options simple for now. Supports only a single path and repo.
+func init() {
+ flag.StringVar(&owner, "owner", "", "GitHub project org/owner (required, except nogo dry-run)")
+ flag.StringVar(&repo, "repo", "", "GitHub repo (required, except nogo dry-run)")
+ flag.StringVar(&tokenFile, "oauth-token-file", "", "file containing the GitHub token (or GITHUB_TOKEN is set)")
+ flag.StringVar(&path, "path", ".", "path to scan (required for revive and nogo)")
+ flag.StringVar(&commit, "commit", "", "commit to associated (required for nogo, except dry-run)")
+ flag.BoolVar(&dryRun, "dry-run", false, "just print changes to be made")
+}
+
+func main() {
+ // Set defaults from the environment.
+ repository := os.Getenv("GITHUB_REPOSITORY")
+ if parts := strings.SplitN(repository, "/", 2); len(parts) == 2 {
+ owner = parts[0]
+ repo = parts[1]
+ }
+
+ // Parse flags.
+ flag.Usage = func() {
+ fmt.Fprintf(flag.CommandLine.Output(), "usage: %s [options] <command>\n", os.Args[0])
+ fmt.Fprintf(flag.CommandLine.Output(), "commands: revive, nogo\n")
+ flag.PrintDefaults()
+ }
+ flag.Parse()
+ args := flag.Args()
+ if len(args) != 1 {
+ fmt.Fprintf(flag.CommandLine.Output(), "extra arguments: %s\n", strings.Join(args[1:], ", "))
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ // Check for mandatory parameters.
+ command := args[0]
+ if len(owner) == 0 && (command != "nogo" || !dryRun) {
+ fmt.Fprintln(flag.CommandLine.Output(), "missing --owner option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+ if len(repo) == 0 && (command != "nogo" || !dryRun) {
+ fmt.Fprintln(flag.CommandLine.Output(), "missing --repo option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+ if len(path) == 0 {
+ fmt.Fprintln(flag.CommandLine.Output(), "missing --path option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ // The access token may be passed as a file so it doesn't show up in
+ // command line arguments. It also may be provided through the
+ // environment to faciliate use through GitHub's CI system.
+ token := os.Getenv("GITHUB_TOKEN")
+ if len(tokenFile) != 0 {
+ bytes, err := ioutil.ReadFile(tokenFile)
+ if err != nil {
+ fmt.Println(err.Error())
+ os.Exit(1)
+ }
+ token = string(bytes)
+ }
+ var client *github.Client
+ if len(token) == 0 {
+ // Client is unauthenticated.
+ client = github.NewClient(nil)
+ } else {
+ // Using the above token.
+ ts := oauth2.StaticTokenSource(
+ &oauth2.Token{AccessToken: token},
+ )
+ tc := oauth2.NewClient(context.Background(), ts)
+ client = github.NewClient(tc)
+ }
+
+ switch command {
+ case "revive":
+ // Load existing GitHub bugs.
+ bugger, err := reviver.NewGitHubBugger(client, owner, repo, dryRun)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error getting github issues: %v\n", err)
+ os.Exit(1)
+ }
+ // Scan the provided path.
+ rev := reviver.New([]string{path}, []reviver.Bugger{bugger})
+ if errs := rev.Run(); len(errs) > 0 {
+ fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs))
+ for _, err := range errs {
+ fmt.Fprintf(os.Stderr, "\t%v\n", err)
+ }
+ os.Exit(1)
+ }
+ case "nogo":
+ // Did we get a commit? Try to extract one.
+ if len(commit) == 0 && !dryRun {
+ cmd := exec.Command("git", "rev-parse", "HEAD")
+ revBytes, err := cmd.Output()
+ if err != nil {
+ fmt.Fprintf(flag.CommandLine.Output(), "missing --commit option, unable to infer: %v\n", err)
+ flag.Usage()
+ os.Exit(1)
+ }
+ commit = strings.TrimSpace(string(revBytes))
+ }
+ // Scan all findings.
+ poster := nogo.NewFindingsPoster(client, owner, repo, commit, dryRun)
+ if err := poster.Walk(path); err != nil {
+ fmt.Fprintln(os.Stderr, "Error finding nogo findings:", err)
+ os.Exit(1)
+ }
+ // Post to GitHub.
+ if err := poster.Post(); err != nil {
+ fmt.Fprintln(os.Stderr, "Error posting nogo findings:", err)
+ }
+ default:
+ // Not a known command.
+ fmt.Fprintf(flag.CommandLine.Output(), "unknown command: %s\n", command)
+ flag.Usage()
+ os.Exit(1)
+ }
+}
diff --git a/tools/github/nogo/BUILD b/tools/github/nogo/BUILD
new file mode 100644
index 000000000..0633eaf19
--- /dev/null
+++ b/tools/github/nogo/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "nogo",
+ srcs = ["nogo.go"],
+ nogo = False,
+ visibility = [
+ "//tools/github:__subpackages__",
+ ],
+ deps = [
+ "//tools/nogo/util",
+ "@com_github_google_go_github_v28//github:go_default_library",
+ ],
+)
diff --git a/tools/github/nogo/nogo.go b/tools/github/nogo/nogo.go
new file mode 100644
index 000000000..b70dfe63b
--- /dev/null
+++ b/tools/github/nogo/nogo.go
@@ -0,0 +1,126 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package nogo provides nogo-related utilities.
+package nogo
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/google/go-github/github"
+ "gvisor.dev/gvisor/tools/nogo/util"
+)
+
+// FindingsPoster is a simple wrapper around the GitHub api.
+type FindingsPoster struct {
+ owner string
+ repo string
+ commit string
+ dryRun bool
+ startTime time.Time
+
+ findings map[util.Finding]struct{}
+ client *github.Client
+}
+
+// NewFindingsPoster returns a object that can post findings.
+func NewFindingsPoster(client *github.Client, owner, repo, commit string, dryRun bool) *FindingsPoster {
+ return &FindingsPoster{
+ owner: owner,
+ repo: repo,
+ commit: commit,
+ dryRun: dryRun,
+ startTime: time.Now(),
+ findings: make(map[util.Finding]struct{}),
+ client: client,
+ }
+}
+
+// Walk walks the given path tree for findings files.
+func (p *FindingsPoster) Walk(path string) error {
+ return filepath.Walk(path, func(filename string, info os.FileInfo, err error) error {
+ if err != nil {
+ return err
+ }
+ // Skip any directories or files not ending in .findings.
+ if !strings.HasSuffix(filename, ".findings") || info.IsDir() {
+ return nil
+ }
+ findings, err := util.ExtractFindingsFromFile(filename)
+ if err != nil {
+ return err
+ }
+ // Add all findings to the list. We use a map to ensure
+ // that each finding is unique.
+ for _, finding := range findings {
+ p.findings[finding] = struct{}{}
+ }
+ return nil
+ })
+}
+
+// Post posts all results to the GitHub API as a check run.
+func (p *FindingsPoster) Post() error {
+ // Just show results?
+ if p.dryRun {
+ for finding, _ := range p.findings {
+ // Pretty print, so that this is useful for debugging.
+ fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Path, finding.Line, finding.Message)
+ }
+ return nil
+ }
+
+ // Construct the message.
+ title := "nogo"
+ count := len(p.findings)
+ status := "completed"
+ conclusion := "success"
+ if count > 0 {
+ conclusion = "failure" // Contains errors.
+ }
+ summary := fmt.Sprintf("%d findings.", count)
+ opts := github.CreateCheckRunOptions{
+ Name: title,
+ HeadSHA: p.commit,
+ Status: &status,
+ Conclusion: &conclusion,
+ StartedAt: &github.Timestamp{p.startTime},
+ CompletedAt: &github.Timestamp{time.Now()},
+ Output: &github.CheckRunOutput{
+ Title: &title,
+ Summary: &summary,
+ AnnotationsCount: &count,
+ },
+ }
+ annotationLevel := "failure" // Always.
+ for finding, _ := range p.findings {
+ opts.Output.Annotations = append(opts.Output.Annotations, &github.CheckRunAnnotation{
+ Path: &finding.Path,
+ StartLine: &finding.Line,
+ EndLine: &finding.Line,
+ Message: &finding.Message,
+ Title: &finding.Category,
+ AnnotationLevel: &annotationLevel,
+ })
+ }
+
+ // Post to GitHub.
+ _, _, err := p.client.Checks.CreateCheckRun(context.Background(), p.owner, p.repo, opts)
+ return err
+}
diff --git a/tools/github/reviver/BUILD b/tools/github/reviver/BUILD
new file mode 100644
index 000000000..7d78480a7
--- /dev/null
+++ b/tools/github/reviver/BUILD
@@ -0,0 +1,27 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "reviver",
+ srcs = [
+ "github.go",
+ "reviver.go",
+ ],
+ nogo = False,
+ visibility = [
+ "//tools/github:__subpackages__",
+ ],
+ deps = ["@com_github_google_go_github_v28//github:go_default_library"],
+)
+
+go_test(
+ name = "reviver_test",
+ size = "small",
+ srcs = [
+ "github_test.go",
+ "reviver_test.go",
+ ],
+ library = ":reviver",
+ nogo = False,
+)
diff --git a/tools/issue_reviver/github/github.go b/tools/github/reviver/github.go
index 8ffd7e606..a95df0fb6 100644
--- a/tools/issue_reviver/github/github.go
+++ b/tools/github/reviver/github.go
@@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package github implements reviver.Bugger interface on top of Github issues.
-package github
+package reviver
import (
"context"
@@ -23,12 +22,10 @@ import (
"time"
"github.com/google/go-github/github"
- "golang.org/x/oauth2"
- "gvisor.dev/gvisor/tools/issue_reviver/reviver"
)
-// Bugger implements reviver.Bugger interface for github issues.
-type Bugger struct {
+// GitHubBugger implements Bugger interface for github issues.
+type GitHubBugger struct {
owner string
repo string
dryRun bool
@@ -37,36 +34,25 @@ type Bugger struct {
issues map[int]*github.Issue
}
-// NewBugger creates a new Bugger.
-func NewBugger(token, owner, repo string, dryRun bool) (*Bugger, error) {
- b := &Bugger{
+// NewGitHubBugger creates a new GitHubBugger.
+func NewGitHubBugger(client *github.Client, owner, repo string, dryRun bool) (*GitHubBugger, error) {
+ b := &GitHubBugger{
owner: owner,
repo: repo,
dryRun: dryRun,
issues: map[int]*github.Issue{},
+ client: client,
}
- if err := b.load(token); err != nil {
+ if err := b.load(); err != nil {
return nil, err
}
return b, nil
}
-func (b *Bugger) load(token string) error {
- ctx := context.Background()
- if len(token) == 0 {
- fmt.Print("No OAUTH token provided, using unauthenticated account.\n")
- b.client = github.NewClient(nil)
- } else {
- ts := oauth2.StaticTokenSource(
- &oauth2.Token{AccessToken: token},
- )
- tc := oauth2.NewClient(ctx, ts)
- b.client = github.NewClient(tc)
- }
-
+func (b *GitHubBugger) load() error {
err := processAllPages(func(listOpts github.ListOptions) (*github.Response, error) {
opts := &github.IssueListByRepoOptions{State: "open", ListOptions: listOpts}
- tmps, resp, err := b.client.Issues.ListByRepo(ctx, b.owner, b.repo, opts)
+ tmps, resp, err := b.client.Issues.ListByRepo(context.Background(), b.owner, b.repo, opts)
if err != nil {
return resp, err
}
@@ -83,8 +69,8 @@ func (b *Bugger) load(token string) error {
return nil
}
-// Activate implements reviver.Bugger.
-func (b *Bugger) Activate(todo *reviver.Todo) (bool, error) {
+// Activate implements Bugger.Activate.
+func (b *GitHubBugger) Activate(todo *Todo) (bool, error) {
id, err := parseIssueNo(todo.Issue)
if err != nil {
return true, err
diff --git a/tools/issue_reviver/github/github_test.go b/tools/github/reviver/github_test.go
index a78b230ef..5df7e3624 100644
--- a/tools/issue_reviver/github/github_test.go
+++ b/tools/github/reviver/github_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package github
+package reviver
import (
"testing"
diff --git a/tools/issue_reviver/reviver/reviver.go b/tools/github/reviver/reviver.go
index 2af7f0d59..2af7f0d59 100644
--- a/tools/issue_reviver/reviver/reviver.go
+++ b/tools/github/reviver/reviver.go
diff --git a/tools/issue_reviver/reviver/reviver_test.go b/tools/github/reviver/reviver_test.go
index a9fb1f9f1..a9fb1f9f1 100644
--- a/tools/issue_reviver/reviver/reviver_test.go
+++ b/tools/github/reviver/reviver_test.go
diff --git a/tools/go_generics/go_merge/main.go b/tools/go_generics/go_merge/main.go
index f6a331123..e0345500f 100644
--- a/tools/go_generics/go_merge/main.go
+++ b/tools/go_generics/go_merge/main.go
@@ -77,6 +77,7 @@ func main() {
// Create a new declaration slice with all imports at the top, merging any
// redundant imports.
imports := make(map[string]*ast.ImportSpec)
+ var importNames []string // Keep imports in the original order to get deterministic output.
var anonImports []*ast.ImportSpec
for _, d := range f.Decls {
if g, ok := d.(*ast.GenDecl); ok && g.Tok == token.IMPORT {
@@ -98,6 +99,7 @@ func main() {
}
} else {
imports[n] = i
+ importNames = append(importNames, n)
}
}
}
@@ -112,8 +114,8 @@ func main() {
Lparen: token.NoPos + 1,
Specs: make([]ast.Spec, 0, l),
}
- for _, i := range imports {
- d.Specs = append(d.Specs, i)
+ for _, i := range importNames {
+ d.Specs = append(d.Specs, imports[i])
}
for _, i := range anonImports {
d.Specs = append(d.Specs, i)
diff --git a/tools/go_generics/imports.go b/tools/go_generics/imports.go
index 148dc7216..90d3aa1e0 100644
--- a/tools/go_generics/imports.go
+++ b/tools/go_generics/imports.go
@@ -21,6 +21,7 @@ import (
"go/format"
"go/parser"
"go/token"
+ "sort"
"strconv"
"gvisor.dev/gvisor/tools/go_generics/globals"
@@ -132,10 +133,17 @@ func updateImports(maps []mapValue, imports mapValue) (ast.Decl, error) {
if len(importsUsed) == 0 {
return nil, nil
}
+ var names []string
+ for n := range importsUsed {
+ names = append(names, n)
+ }
+ // Sort the new imports for deterministic build outputs.
+ sort.Strings(names)
// Create spec array for each new import.
specs := make([]ast.Spec, 0, len(importsUsed))
- for _, i := range importsUsed {
+ for _, n := range names {
+ i := importsUsed[n]
specs = append(specs, &ast.ImportSpec{
Name: &ast.Ident{Name: i.newName},
Path: &ast.BasicLit{Value: i.path},
diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md
index 68d759083..d8045c295 100644
--- a/tools/go_marshal/README.md
+++ b/tools/go_marshal/README.md
@@ -3,18 +3,19 @@ This package implements the go_marshal utility.
# Overview
`go_marshal` is a code generation utility similar to `go_stateify` for
-automatically generating code to marshal go data structures to memory.
+marshalling go data structures to and from memory.
`go_marshal` attempts to improve on `binary.Write` and the sentry's
-`binary.Marshal` by moving the go runtime reflection necessary to marshal a
-struct to compile-time.
+`binary.Marshal` by moving the expensive use of reflection from runtime to
+compile-time.
`go_marshal` automatically generates implementations for `marshal.Marshallable`
-and `safemem.{Reader,Writer}`. Data structures that require custom serialization
-will have manual implementations for these interfaces.
+interface. Data structures that require custom serialization can be accomodated
+through a manual implementation this interface.
Data structures can be flagged for code generation by adding a struct-level
-comment `// +marshal`.
+comment `// +marshal`. For additional details and options, see the documentation
+for the `marshal.Marshallable` interface.
# Usage
@@ -74,7 +75,7 @@ intended for ABI structs, which have these additional restrictions:
dependent native pointer size.
- Fields must either be a primitive integer type (`byte`,
- `[u]int{8,16,32,64}`), or of a type that implements abi.Marshallable.
+ `[u]int{8,16,32,64}`), or of a type that implements `marshal.Marshallable`.
- `int` and `uint` fields are not allowed. Use an explicitly-sized numeric
type.
@@ -112,3 +113,18 @@ The following are some guidelines for modifying the `go_marshal` tool:
- No runtime reflection in the code generated for the marshallable interface.
The entire point of the tool is to avoid runtime reflection. The generated
tests may use reflection.
+
+## Debugging
+
+To enable debugging output from the go-marshal tool, use one of the following
+options, depending on how go-marshal is being invoked:
+
+- Pass `--define gomarshal=verbose` to the bazel command. Note that this can
+ generate a lot of output depending on what's being compiled, as this will
+ enable debugging for all packages built by the command.
+
+- Set `marshal_debug = True` on the top-level `go_library` BUILD rule.
+
+- Set `debug = True` on the `go_marshal` BUILD rule.
+
+- Pass `-debug` to the go-marshal tool invocation.
diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl
index 323e33882..f44f83eab 100644
--- a/tools/go_marshal/defs.bzl
+++ b/tools/go_marshal/defs.bzl
@@ -4,11 +4,13 @@ def _go_marshal_impl(ctx):
"""Execute the go_marshal tool."""
output = ctx.outputs.lib
output_test = ctx.outputs.test
+ output_test_unconditional = ctx.outputs.test_unconditional
# Run the marshal command.
args = ["-output=%s" % output.path]
- args += ["-pkg=%s" % ctx.attr.package]
- args += ["-output_test=%s" % output_test.path]
+ args.append("-pkg=%s" % ctx.attr.package)
+ args.append("-output_test=%s" % output_test.path)
+ args.append("-output_test_unconditional=%s" % output_test_unconditional.path)
if ctx.attr.debug:
args += ["-debug"]
@@ -18,7 +20,7 @@ def _go_marshal_impl(ctx):
args += [f.path for f in src.files.to_list()]
ctx.actions.run(
inputs = ctx.files.srcs,
- outputs = [output, output_test],
+ outputs = [output, output_test, output_test_unconditional],
mnemonic = "GoMarshal",
progress_message = "go_marshal: %s" % ctx.label,
arguments = args,
@@ -48,6 +50,7 @@ go_marshal = rule(
outputs = {
"lib": "%{name}_unsafe.go",
"test": "%{name}_test.go",
+ "test_unconditional": "%{name}_unconditional_test.go",
},
)
@@ -56,7 +59,7 @@ marshal_deps = [
"//pkg/gohacks",
"//pkg/safecopy",
"//pkg/usermem",
- "//tools/go_marshal/marshal",
+ "//pkg/marshal",
]
# marshal_test_deps are required by test targets.
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 19bcd4e6a..56fbcb5d2 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -38,8 +38,8 @@ import (
// All recievers are single letters, so we don't allow import aliases to be a
// single letter.
var badIdents = []string{
- "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner",
- "length", "limit", "ptr", "size", "src", "srcs", "task", "val",
+ "addr", "blk", "buf", "cc", "dst", "dsts", "count", "err", "hdr", "idx",
+ "inner", "length", "limit", "ptr", "size", "src", "srcs", "val",
// All single-letter identifiers.
}
@@ -68,6 +68,8 @@ type Generator struct {
output *os.File
// Output file to write generated tests.
outputTest *os.File
+ // Output file to write unconditionally generated tests.
+ outputTestUC *os.File
// Package name for the generated file.
pkg string
// Set of extra packages to import in the generated file.
@@ -75,7 +77,7 @@ type Generator struct {
}
// NewGenerator creates a new code Generator.
-func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) {
+func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, imports []string) (*Generator, error) {
f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err)
@@ -84,12 +86,17 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G
if err != nil {
return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err)
}
+ fTestUC, err := os.OpenFile(outTestUnconditional, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return nil, fmt.Errorf("Couldn't open unconditional test output file %q: %v", out, err)
+ }
g := Generator{
- inputs: srcs,
- output: f,
- outputTest: fTest,
- pkg: pkg,
- imports: newImportTable(),
+ inputs: srcs,
+ output: f,
+ outputTest: fTest,
+ outputTestUC: fTestUC,
+ pkg: pkg,
+ imports: newImportTable(),
}
for _, i := range imports {
// All imports on the extra imports list are unconditionally marked as
@@ -107,7 +114,7 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G
g.imports.add("gvisor.dev/gvisor/pkg/gohacks")
g.imports.add("gvisor.dev/gvisor/pkg/safecopy")
g.imports.add("gvisor.dev/gvisor/pkg/usermem")
- g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal")
+ g.imports.add("gvisor.dev/gvisor/pkg/marshal")
return &g, nil
}
@@ -454,6 +461,46 @@ func (g *Generator) Run() error {
// source file.
func (g *Generator) writeTests(ts []*testGenerator) error {
var b sourceBuffer
+
+ // Write the unconditional test file. This file is always compiled,
+ // regardless of what build tags were specified on the original input
+ // files. We use this file to guarantee we never end up with an empty test
+ // file, as that causes the build to fail with "no tests/benchmarks/examples
+ // found".
+ //
+ // There's no easy way to determine ahead of time if we'll end up with an
+ // empty build file since build constraints can arbitrarily cause some of
+ // the original types to be not defined. We also have no way to tell bazel
+ // to omit the entire test suite since the output files are already defined
+ // before go-marshal is called.
+ b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n")
+ b.emit("package %s\n\n", g.pkg)
+ b.emit("func Example() {\n")
+ b.inIndent(func() {
+ b.emit("// This example is intentionally empty, and ensures this package contains at\n")
+ b.emit("// least one testable entity. go-marshal is forced to emit a test package if the\n")
+ b.emit("// input package is marked marshallable, but emitting no testable entities \n")
+ b.emit("// results in a build failure.\n")
+ })
+ b.emit("}\n")
+ if err := b.write(g.outputTestUC); err != nil {
+ return err
+ }
+
+ // Now generate the real test file that contains the real types we
+ // processed. These need to be conditionally compiled according to the build
+ // tags, as the original types may not be defined under all build
+ // configurations.
+
+ b.reset()
+ b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n")
+
+ // Emit build tags.
+ if t := tags.Aggregate(g.inputs); len(t) > 0 {
+ b.emit(strings.Join(t.Lines(), "\n"))
+ b.emit("\n\n")
+ }
+
b.emit("package %s\n\n", g.pkg)
if err := b.write(g.outputTest); err != nil {
return err
@@ -470,26 +517,6 @@ func (g *Generator) writeTests(ts []*testGenerator) error {
}
// Write test functions.
-
- // If we didn't generate any Marshallable implementations, we can't just
- // emit an empty test file, since that causes the build to fail with "no
- // tests/benchmarks/examples found". Unfortunately we can't signal bazel to
- // omit the entire package since the outputs are already defined before
- // go-marshal is called. If we'd otherwise emit an empty test suite, emit an
- // empty example instead.
- if len(ts) == 0 {
- b.reset()
- b.emit("func Example() {\n")
- b.inIndent(func() {
- b.emit("// This example is intentionally empty to ensure this file contains at least\n")
- b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n")
- b.emit("// is marked marshallable, but emitting a test file with no entities results\n")
- b.emit("// in a build failure.\n")
- })
- b.emit("}\n")
- return b.write(g.outputTest)
- }
-
for _, t := range ts {
if err := t.write(g.outputTest); err != nil {
return err
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index cf76b5241..36447b86b 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -43,8 +43,8 @@ type interfaceGenerator struct {
// of t's interfaces.
ms map[string]struct{}
- // as records embedded fields in t that are potentially not packed. The key
- // is the accessor for the field.
+ // as records fields in t that are potentially not packed. The key is the
+ // accessor for the field.
as map[string]struct{}
}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
index 72ef03a22..7525b52da 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
@@ -102,11 +102,11 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as
g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
- g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
g.emitKeepAlive(g.r)
g.emit("return length, err\n")
})
@@ -114,19 +114,19 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as
g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
})
g.emit("}\n\n")
g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
- g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
g.emitKeepAlive(g.r)
g.emit("return length, err\n")
})
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
index 39f654ea8..7edaf666c 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
@@ -154,11 +154,11 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident)
g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
- g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
g.emitKeepAlive(g.r)
g.emit("return length, err\n")
})
@@ -166,19 +166,19 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident)
g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
})
g.emit("}\n\n")
g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
- g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
g.emitKeepAlive(g.r)
g.emit("return length, err\n")
})
@@ -211,7 +211,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id
g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType)
g.emit("//go:nosplit\n")
- g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType)
+ g.emit("func Copy%sIn(cc marshal.CopyContext, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType)
g.inIndent(func() {
g.emit("count := len(dst)\n")
g.emit("if count == 0 {\n")
@@ -223,7 +223,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id
g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
- g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
g.emitKeepAlive("dst")
g.emit("return length, err\n")
})
@@ -231,7 +231,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id
g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType)
g.emit("//go:nosplit\n")
- g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType)
+ g.emit("func Copy%sOut(cc marshal.CopyContext, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType)
g.inIndent(func() {
g.emit("count := len(src)\n")
g.emit("if count == 0 {\n")
@@ -243,7 +243,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id
g.emitCastSliceToByteSlice("&src", "buf", "size * count")
- g.emit("length, err := task.CopyOutBytes(addr, buf) // escapes: okay.\n")
+ g.emit("length, err := cc.CopyOutBytes(addr, buf) // escapes: okay.\n")
g.emitKeepAlive("src")
g.emit("return length, err\n")
})
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
index 44fbb425c..fe76d3785 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -20,6 +20,7 @@ package gomarshal
import (
"fmt"
"go/ast"
+ "sort"
"strings"
)
@@ -40,6 +41,8 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
for accessor, _ := range g.as {
cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
}
+ // Sort expressions for determinstic build outputs.
+ sort.Strings(cs)
return strings.Join(cs, " && "), true
}
@@ -48,12 +51,6 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
// later.
func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) {
forEachStructField(st, func(f *ast.Field) {
- if len(f.Names) == 0 {
- g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
- }
- })
-
- forEachStructField(st, func(f *ast.Field) {
fieldDispatcher{
primitive: func(_, t *ast.Ident) {
g.validatePrimitiveNewtype(t)
@@ -98,7 +95,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
var dynamicSizeTerms []string
forEachStructField(st, fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
+ primitive: func(_, t *ast.Ident) {
if size, dynamic := g.scalarSize(t); !dynamic {
primitiveSize += size
} else {
@@ -106,13 +103,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
}
},
- selector: func(n, tX, tSel *ast.Ident) {
+ selector: func(_, tX, tSel *ast.Ident) {
tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
g.recordUsedImport(tX.Name)
g.recordUsedMarshallable(tName)
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
},
- array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
lenExpr := g.arrayLenExpr(a)
if size, dynamic := g.scalarSize(t); !dynamic {
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
@@ -323,13 +320,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("//go:nosplit\n")
g.recordUsedImport("marshal")
g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
+ g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r)
- g.emit("return task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emit("return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
}
if thisPacked {
g.recordUsedImport("reflect")
@@ -343,7 +340,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
// Fast serialization.
g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
- g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
g.emitKeepAlive(g.r)
g.emit("return length, err\n")
} else {
@@ -356,9 +353,9 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("//go:nosplit\n")
g.recordUsedImport("marshal")
g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
})
g.emit("}\n\n")
@@ -366,12 +363,12 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("//go:nosplit\n")
g.recordUsedImport("marshal")
g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
- g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
+ g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n")
g.emit("// partially unmarshalled struct.\n")
g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r)
@@ -389,7 +386,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
// Fast deserialization.
g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
- g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
g.emitKeepAlive(g.r)
g.emit("return length, err\n")
} else {
@@ -442,7 +439,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType,
g.recordUsedImport("usermem")
g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName())
- g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.emit("func Copy%sIn(cc marshal.CopyContext, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName())
g.inIndent(func() {
g.emit("count := len(dst)\n")
g.emit("if count == 0 {\n")
@@ -454,8 +451,8 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType,
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("buf := task.CopyScratchBuffer(size * count)\n")
- g.emit("length, err := task.CopyInBytes(addr, buf)\n\n")
+ g.emit("buf := cc.CopyScratchBuffer(size * count)\n")
+ g.emit("length, err := cc.CopyInBytes(addr, buf)\n\n")
g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n")
g.emit("limit := length/size\n")
@@ -489,7 +486,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType,
// Fast deserialization.
g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
- g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("length, err := cc.CopyInBytes(addr, buf)\n")
g.emitKeepAlive("dst")
g.emit("return length, err\n")
} else {
@@ -499,7 +496,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType,
g.emit("}\n\n")
g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName())
- g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.emit("func Copy%sOut(cc marshal.CopyContext, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName())
g.inIndent(func() {
g.emit("count := len(src)\n")
g.emit("if count == 0 {\n")
@@ -511,13 +508,13 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType,
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := task.CopyScratchBuffer(size * count)\n")
+ g.emit("buf := cc.CopyScratchBuffer(size * count)\n")
g.emit("for idx := 0; idx < count; idx++ {\n")
g.inIndent(func() {
g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n")
})
g.emit("}\n")
- g.emit("return task.CopyOutBytes(addr, buf)\n")
+ g.emit("return cc.CopyOutBytes(addr, buf)\n")
}
if thisPacked {
g.recordUsedImport("reflect")
@@ -531,7 +528,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType,
// Fast serialization.
g.emitCastSliceToByteSlice("&src", "buf", "size * count")
- g.emit("length, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("length, err := cc.CopyOutBytes(addr, buf)\n")
g.emitKeepAlive("src")
g.emit("return length, err\n")
} else {
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
index d94314302..6a42691cd 100644
--- a/tools/go_marshal/gomarshal/util.go
+++ b/tools/go_marshal/gomarshal/util.go
@@ -79,7 +79,7 @@ type fieldDispatcher struct {
}
// Precondition: All dispatch callbacks that will be invoked must be
-// provided. Embedded fields are not allowed, len(f.Names) >= 1.
+// provided.
func (fd fieldDispatcher) dispatch(f *ast.Field) {
// Each field declaration may actually be multiple declarations of the same
// type. For example, consider:
@@ -88,12 +88,24 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) {
// x, y, z int
// }
//
- // We invoke the call-backs once per such instance. Embedded fields are not
- // allowed, and results in a panic.
+ // We invoke the call-backs once per such instance.
+
+ // Handle embedded fields. Embedded fields have no names, but can be
+ // referenced by the type name.
if len(f.Names) < 1 {
- panic("Precondition not met: attempted to dispatch on embedded field")
+ switch v := f.Type.(type) {
+ case *ast.Ident:
+ fd.primitive(v, v)
+ case *ast.SelectorExpr:
+ fd.selector(v.Sel, v.X.(*ast.Ident), v.Sel)
+ default:
+ // Note: Arrays can't be embedded, which is handled here.
+ panic(fmt.Sprintf("Attempted to dispatch on embedded field of unsupported kind: %#v", f.Type))
+ }
+ return
}
+ // Non-embedded field.
for _, name := range f.Names {
switch v := f.Type.(type) {
case *ast.Ident:
diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go
index f74be5c29..6e4a3e8c4 100644
--- a/tools/go_marshal/main.go
+++ b/tools/go_marshal/main.go
@@ -31,10 +31,11 @@ import (
)
var (
- pkg = flag.String("pkg", "", "output package")
- output = flag.String("output", "", "output file")
- outputTest = flag.String("output_test", "", "output file for tests")
- imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code")
+ pkg = flag.String("pkg", "", "output package")
+ output = flag.String("output", "", "output file")
+ outputTest = flag.String("output_test", "", "output file for tests")
+ outputTestUnconditional = flag.String("output_test_unconditional", "", "output file for unconditional tests")
+ imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code")
)
func main() {
@@ -61,7 +62,7 @@ func main() {
// as an import.
extraImports = strings.Split(*imports, ",")
}
- g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, extraImports)
+ g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *outputTestUnconditional, *pkg, extraImports)
if err != nil {
panic(err)
}
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
index 3d989823a..4b27773c2 100644
--- a/tools/go_marshal/test/BUILD
+++ b/tools/go_marshal/test/BUILD
@@ -35,10 +35,10 @@ go_test(
srcs = ["marshal_test.go"],
deps = [
":test",
+ "//pkg/marshal",
"//pkg/syserror",
"//pkg/usermem",
"//tools/go_marshal/analysis",
- "//tools/go_marshal/marshal",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/tools/go_marshal/test/escape/BUILD b/tools/go_marshal/test/escape/BUILD
index f74e6ffae..2981ef196 100644
--- a/tools/go_marshal/test/escape/BUILD
+++ b/tools/go_marshal/test/escape/BUILD
@@ -7,8 +7,8 @@ go_library(
testonly = 1,
srcs = ["escape.go"],
deps = [
+ "//pkg/marshal",
"//pkg/usermem",
- "//tools/go_marshal/marshal",
"//tools/go_marshal/test",
],
)
diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go
index 3a1a64e9c..7f62b0a2b 100644
--- a/tools/go_marshal/test/escape/escape.go
+++ b/tools/go_marshal/test/escape/escape.go
@@ -15,34 +15,34 @@
package escape
import (
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
"gvisor.dev/gvisor/tools/go_marshal/test"
)
-// dummyTask implements marshal.Task.
-type dummyTask struct {
+// dummyCopyContext implements marshal.CopyContext.
+type dummyCopyContext struct {
}
-func (*dummyTask) CopyScratchBuffer(size int) []byte {
+func (*dummyCopyContext) CopyScratchBuffer(size int) []byte {
return make([]byte, size)
}
-func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) {
+func (*dummyCopyContext) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) {
return len(b), nil
}
-func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) {
+func (*dummyCopyContext) CopyInBytes(addr usermem.Addr, b []byte) (int, error) {
return len(b), nil
}
-func (t *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) {
+func (t *dummyCopyContext) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) {
buf := t.CopyScratchBuffer(marshallable.SizeBytes())
marshallable.MarshalBytes(buf)
t.CopyOutBytes(addr, buf)
}
-func (t *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) {
+func (t *dummyCopyContext) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) {
buf := t.CopyScratchBuffer(marshallable.SizeBytes())
marshallable.MarshalUnsafe(buf)
t.CopyOutBytes(addr, buf)
@@ -50,14 +50,14 @@ func (t *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marsha
// +checkescape:all
//go:nosplit
-func doCopyIn(t *dummyTask) {
+func doCopyIn(t *dummyCopyContext) {
var stat test.Stat
stat.CopyIn(t, usermem.Addr(0xf000ba12))
}
// +checkescape:all
//go:nosplit
-func doCopyOut(t *dummyTask) {
+func doCopyOut(t *dummyCopyContext) {
var stat test.Stat
stat.CopyOut(t, usermem.Addr(0xf000ba12))
}
@@ -65,7 +65,7 @@ func doCopyOut(t *dummyTask) {
// +mustescape:builtin
// +mustescape:stack
//go:nosplit
-func doMarshalBytesDirect(t *dummyTask) {
+func doMarshalBytesDirect(t *dummyCopyContext) {
var stat test.Stat
buf := t.CopyScratchBuffer(stat.SizeBytes())
stat.MarshalBytes(buf)
@@ -75,7 +75,7 @@ func doMarshalBytesDirect(t *dummyTask) {
// +mustescape:builtin
// +mustescape:stack
//go:nosplit
-func doMarshalUnsafeDirect(t *dummyTask) {
+func doMarshalUnsafeDirect(t *dummyCopyContext) {
var stat test.Stat
buf := t.CopyScratchBuffer(stat.SizeBytes())
stat.MarshalUnsafe(buf)
@@ -85,7 +85,7 @@ func doMarshalUnsafeDirect(t *dummyTask) {
// +mustescape:local,heap
// +mustescape:stack
//go:nosplit
-func doMarshalBytesViaMarshallable(t *dummyTask) {
+func doMarshalBytesViaMarshallable(t *dummyCopyContext) {
var stat test.Stat
t.MarshalBytes(usermem.Addr(0xf000ba12), &stat)
}
@@ -93,7 +93,7 @@ func doMarshalBytesViaMarshallable(t *dummyTask) {
// +mustescape:local,heap
// +mustescape:stack
//go:nosplit
-func doMarshalUnsafeViaMarshallable(t *dummyTask) {
+func doMarshalUnsafeViaMarshallable(t *dummyCopyContext) {
var stat test.Stat
t.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat)
}
diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go
index 16829ee45..a00f9a684 100644
--- a/tools/go_marshal/test/marshal_test.go
+++ b/tools/go_marshal/test/marshal_test.go
@@ -27,22 +27,22 @@ import (
"unsafe"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/tools/go_marshal/analysis"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
"gvisor.dev/gvisor/tools/go_marshal/test"
)
var simulatedErr error = syserror.EFAULT
-// mockTask implements marshal.Task.
-type mockTask struct {
+// mockCopyContext implements marshal.CopyContext.
+type mockCopyContext struct {
taskMem usermem.BytesIO
}
// populate fills the task memory with the contents of val.
-func (t *mockTask) populate(val interface{}) {
+func (t *mockCopyContext) populate(val interface{}) {
var buf bytes.Buffer
// Use binary.Write so we aren't testing go-marshal against its own
// potentially buggy implementation.
@@ -52,7 +52,7 @@ func (t *mockTask) populate(val interface{}) {
t.taskMem.Bytes = buf.Bytes()
}
-func (t *mockTask) setLimit(n int) {
+func (t *mockCopyContext) setLimit(n int) {
if len(t.taskMem.Bytes) < n {
grown := make([]byte, n)
copy(grown, t.taskMem.Bytes)
@@ -62,22 +62,22 @@ func (t *mockTask) setLimit(n int) {
t.taskMem.Bytes = t.taskMem.Bytes[:n]
}
-// CopyScratchBuffer implements marshal.Task.CopyScratchBuffer.
-func (t *mockTask) CopyScratchBuffer(size int) []byte {
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (t *mockCopyContext) CopyScratchBuffer(size int) []byte {
return make([]byte, size)
}
-// CopyOutBytes implements marshal.Task.CopyOutBytes. The implementation
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. The implementation
// completely ignores the target address and stores a copy of b in its
// internally buffer, overriding any previous contents.
-func (t *mockTask) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) {
+func (t *mockCopyContext) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) {
return t.taskMem.CopyOut(nil, 0, b, usermem.IOOpts{})
}
-// CopyInBytes implements marshal.Task.CopyInBytes. The implementation
+// CopyInBytes implements marshal.CopyContext.CopyInBytes. The implementation
// completely ignores the source address and always fills b from the begining of
// its internal buffer.
-func (t *mockTask) CopyInBytes(_ usermem.Addr, b []byte) (int, error) {
+func (t *mockCopyContext) CopyInBytes(_ usermem.Addr, b []byte) (int, error) {
return t.taskMem.CopyIn(nil, 0, b, usermem.IOOpts{})
}
@@ -171,11 +171,11 @@ func compareMemory(t *testing.T, expected, actual []byte, n int) {
// dst. The task signals an error at limit bytes during copy-in, which should
// result in a truncated unmarshalling.
func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) {
- var task mockTask
- task.populate(src)
- task.setLimit(limit)
+ var cc mockCopyContext
+ cc.populate(src)
+ cc.setLimit(limit)
- n, err := dst.CopyIn(&task, usermem.Addr(0))
+ n, err := dst.CopyIn(&cc, usermem.Addr(0))
if n != limit {
t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
}
@@ -202,10 +202,10 @@ func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) {
// limitedCopyOut marshals src to task memory. The task signals an error at
// limit bytes during copy-out, which should result in a truncated marshalling.
func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) {
- var task mockTask
- task.setLimit(limit)
+ var cc mockCopyContext
+ cc.setLimit(limit)
- n, err := src.CopyOut(&task, usermem.Addr(0))
+ n, err := src.CopyOut(&cc, usermem.Addr(0))
if n != limit {
t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n)
}
@@ -215,7 +215,7 @@ func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) {
expectedMem := unsafeMemory(src)
defer runtime.KeepAlive(src)
- actualMem := task.taskMem.Bytes
+ actualMem := cc.taskMem.Bytes
compareMemory(t, expectedMem, actualMem, n)
}
@@ -223,10 +223,10 @@ func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) {
// copyOutN marshals src to task memory, requesting the marshalling to be
// limited to limit bytes.
func copyOutN(t *testing.T, src marshal.Marshallable, limit int) {
- var task mockTask
- task.setLimit(limit)
+ var cc mockCopyContext
+ cc.setLimit(limit)
- n, err := src.CopyOutN(&task, usermem.Addr(0), limit)
+ n, err := src.CopyOutN(&cc, usermem.Addr(0), limit)
if err != nil {
t.Errorf("CopyOut returned unexpected error: %v", err)
}
@@ -236,7 +236,7 @@ func copyOutN(t *testing.T, src marshal.Marshallable, limit int) {
expectedMem := unsafeMemory(src)
defer runtime.KeepAlive(src)
- actualMem := task.taskMem.Bytes
+ actualMem := cc.taskMem.Bytes
t.Logf("Expected: %v + %v\n", expectedMem[:n], expectedMem[n:])
t.Logf("Actual : %v + %v\n", actualMem[:n], actualMem[n:])
@@ -303,20 +303,20 @@ func TestLimitedMarshalling(t *testing.T) {
func TestLimitedSliceMarshalling(t *testing.T) {
types := []struct {
arrayPtrType reflect.Type
- copySliceIn func(task marshal.Task, addr usermem.Addr, dstSlice interface{}) (int, error)
- copySliceOut func(task marshal.Task, addr usermem.Addr, srcSlice interface{}) (int, error)
+ copySliceIn func(cc marshal.CopyContext, addr usermem.Addr, dstSlice interface{}) (int, error)
+ copySliceOut func(cc marshal.CopyContext, addr usermem.Addr, srcSlice interface{}) (int, error)
unsafeMemory func(arrPtr interface{}) []byte
}{
// Packed types.
{
reflect.TypeOf((*[20]test.Stat)(nil)),
- func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
slice := dst.(*[20]test.Stat)[:]
- return test.CopyStatSliceIn(task, addr, slice)
+ return test.CopyStatSliceIn(cc, addr, slice)
},
- func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
slice := src.(*[20]test.Stat)[:]
- return test.CopyStatSliceOut(task, addr, slice)
+ return test.CopyStatSliceOut(cc, addr, slice)
},
func(a interface{}) []byte {
slice := a.(*[20]test.Stat)[:]
@@ -325,13 +325,13 @@ func TestLimitedSliceMarshalling(t *testing.T) {
},
{
reflect.TypeOf((*[1]test.Stat)(nil)),
- func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
slice := dst.(*[1]test.Stat)[:]
- return test.CopyStatSliceIn(task, addr, slice)
+ return test.CopyStatSliceIn(cc, addr, slice)
},
- func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
slice := src.(*[1]test.Stat)[:]
- return test.CopyStatSliceOut(task, addr, slice)
+ return test.CopyStatSliceOut(cc, addr, slice)
},
func(a interface{}) []byte {
slice := a.(*[1]test.Stat)[:]
@@ -340,13 +340,13 @@ func TestLimitedSliceMarshalling(t *testing.T) {
},
{
reflect.TypeOf((*[5]test.SignalSetAlias)(nil)),
- func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
slice := dst.(*[5]test.SignalSetAlias)[:]
- return test.CopySignalSetAliasSliceIn(task, addr, slice)
+ return test.CopySignalSetAliasSliceIn(cc, addr, slice)
},
- func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
slice := src.(*[5]test.SignalSetAlias)[:]
- return test.CopySignalSetAliasSliceOut(task, addr, slice)
+ return test.CopySignalSetAliasSliceOut(cc, addr, slice)
},
func(a interface{}) []byte {
slice := a.(*[5]test.SignalSetAlias)[:]
@@ -356,13 +356,13 @@ func TestLimitedSliceMarshalling(t *testing.T) {
// Non-packed types.
{
reflect.TypeOf((*[20]test.Type1)(nil)),
- func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
slice := dst.(*[20]test.Type1)[:]
- return test.CopyType1SliceIn(task, addr, slice)
+ return test.CopyType1SliceIn(cc, addr, slice)
},
- func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
slice := src.(*[20]test.Type1)[:]
- return test.CopyType1SliceOut(task, addr, slice)
+ return test.CopyType1SliceOut(cc, addr, slice)
},
func(a interface{}) []byte {
slice := a.(*[20]test.Type1)[:]
@@ -371,13 +371,13 @@ func TestLimitedSliceMarshalling(t *testing.T) {
},
{
reflect.TypeOf((*[1]test.Type1)(nil)),
- func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
slice := dst.(*[1]test.Type1)[:]
- return test.CopyType1SliceIn(task, addr, slice)
+ return test.CopyType1SliceIn(cc, addr, slice)
},
- func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
slice := src.(*[1]test.Type1)[:]
- return test.CopyType1SliceOut(task, addr, slice)
+ return test.CopyType1SliceOut(cc, addr, slice)
},
func(a interface{}) []byte {
slice := a.(*[1]test.Type1)[:]
@@ -386,13 +386,13 @@ func TestLimitedSliceMarshalling(t *testing.T) {
},
{
reflect.TypeOf((*[7]test.Type8)(nil)),
- func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
slice := dst.(*[7]test.Type8)[:]
- return test.CopyType8SliceIn(task, addr, slice)
+ return test.CopyType8SliceIn(cc, addr, slice)
},
- func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
slice := src.(*[7]test.Type8)[:]
- return test.CopyType8SliceOut(task, addr, slice)
+ return test.CopyType8SliceOut(cc, addr, slice)
},
func(a interface{}) []byte {
slice := a.(*[7]test.Type8)[:]
@@ -439,11 +439,11 @@ func TestLimitedSliceMarshalling(t *testing.T) {
limit += elem.SizeBytes() / 2
analysis.RandomizeValue(expected)
- var task mockTask
- task.populate(expected)
- task.setLimit(limit)
+ var cc mockCopyContext
+ cc.populate(expected)
+ cc.setLimit(limit)
- n, err := tt.copySliceIn(&task, usermem.Addr(0), actual)
+ n, err := tt.copySliceIn(&cc, usermem.Addr(0), actual)
if n != limit {
t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
}
@@ -493,11 +493,11 @@ func TestLimitedSliceMarshalling(t *testing.T) {
limit += elem.SizeBytes() / 2
analysis.RandomizeValue(expected)
- var task mockTask
- task.populate(expected)
- task.setLimit(limit)
+ var cc mockCopyContext
+ cc.populate(expected)
+ cc.setLimit(limit)
- n, err := tt.copySliceOut(&task, usermem.Addr(0), expected)
+ n, err := tt.copySliceOut(&cc, usermem.Addr(0), expected)
if n != limit {
t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
}
@@ -507,7 +507,7 @@ func TestLimitedSliceMarshalling(t *testing.T) {
expectedMem := tt.unsafeMemory(expected)
defer runtime.KeepAlive(expected)
- actualMem := task.taskMem.Bytes
+ actualMem := cc.taskMem.Bytes
compareMemory(t, expectedMem, actualMem, n)
})
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index f75ca1b7f..d9e9f341b 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -174,3 +174,27 @@ type Type9 struct {
x int64
y [sizeA]int32
}
+
+// Type10Embed is a test data type which is be embedded into another type.
+//
+// +marshal
+type Type10Embed struct {
+ x int64
+}
+
+// Type10 is a test data type which contains an embedded struct.
+//
+// +marshal
+type Type10 struct {
+ Type10Embed
+ y int64
+}
+
+// Type11 is a test data type which contains an embedded struct from an external
+// package.
+//
+// +marshal
+type Type11 struct {
+ ex.External
+ y int64
+}
diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD
deleted file mode 100644
index 35b0111ca..000000000
--- a/tools/issue_reviver/BUILD
+++ /dev/null
@@ -1,13 +0,0 @@
-load("//tools:defs.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "issue_reviver",
- srcs = ["main.go"],
- nogo = False,
- deps = [
- "//tools/issue_reviver/github",
- "//tools/issue_reviver/reviver",
- ],
-)
diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD
deleted file mode 100644
index 555abd296..000000000
--- a/tools/issue_reviver/github/BUILD
+++ /dev/null
@@ -1,25 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "github",
- srcs = ["github.go"],
- nogo = False,
- visibility = [
- "//tools/issue_reviver:__subpackages__",
- ],
- deps = [
- "//tools/issue_reviver/reviver",
- "@com_github_google_go_github_v28//github:go_default_library",
- "@org_golang_x_oauth2//:go_default_library",
- ],
-)
-
-go_test(
- name = "github_test",
- size = "small",
- srcs = ["github_test.go"],
- library = ":github",
- nogo = False,
-)
diff --git a/tools/issue_reviver/main.go b/tools/issue_reviver/main.go
deleted file mode 100644
index 47c796b8a..000000000
--- a/tools/issue_reviver/main.go
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package main is the entry point for issue_reviver.
-package main
-
-import (
- "flag"
- "fmt"
- "io/ioutil"
- "os"
- "strings"
-
- "gvisor.dev/gvisor/tools/issue_reviver/github"
- "gvisor.dev/gvisor/tools/issue_reviver/reviver"
-)
-
-var (
- owner string
- repo string
- tokenFile string
- path string
- dryRun bool
-)
-
-// Keep the options simple for now. Supports only a single path and repo.
-func init() {
- flag.StringVar(&owner, "owner", "", "Github project org/owner to look for issues")
- flag.StringVar(&repo, "repo", "", "Github repo to look for issues")
- flag.StringVar(&tokenFile, "oauth-token-file", "", "Path to file containing the OAUTH token to be used as credential to github")
- flag.StringVar(&path, "path", ".", "Path to scan for TODOs")
- flag.BoolVar(&dryRun, "dry-run", false, "If set to true, no changes are made to issues")
-}
-
-func main() {
- // Set defaults from the environment.
- repository := os.Getenv("GITHUB_REPOSITORY")
- if parts := strings.SplitN(repository, "/", 2); len(parts) == 2 {
- owner = parts[0]
- repo = parts[1]
- }
-
- // Parse flags.
- flag.Parse()
-
- // Check for mandatory parameters.
- if len(owner) == 0 {
- fmt.Println("missing --owner option.")
- flag.Usage()
- os.Exit(1)
- }
- if len(repo) == 0 {
- fmt.Println("missing --repo option.")
- flag.Usage()
- os.Exit(1)
- }
- if len(path) == 0 {
- fmt.Println("missing --path option.")
- flag.Usage()
- os.Exit(1)
- }
-
- // The access token may be passed as a file so it doesn't show up in
- // command line arguments. It also may be provided through the
- // environment to faciliate use through GitHub's CI system.
- token := os.Getenv("GITHUB_TOKEN")
- if len(tokenFile) != 0 {
- bytes, err := ioutil.ReadFile(tokenFile)
- if err != nil {
- fmt.Println(err.Error())
- os.Exit(1)
- }
- token = string(bytes)
- }
-
- bugger, err := github.NewBugger(token, owner, repo, dryRun)
- if err != nil {
- fmt.Fprintln(os.Stderr, "Error getting github issues:", err)
- os.Exit(1)
- }
- rev := reviver.New([]string{path}, []reviver.Bugger{bugger})
- if errs := rev.Run(); len(errs) > 0 {
- fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs))
- for _, err := range errs {
- fmt.Fprintf(os.Stderr, "\t%v\n", err)
- }
- os.Exit(1)
- }
-}
diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD
deleted file mode 100644
index d262932bd..000000000
--- a/tools/issue_reviver/reviver/BUILD
+++ /dev/null
@@ -1,18 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "reviver",
- srcs = ["reviver.go"],
- visibility = [
- "//tools/issue_reviver:__subpackages__",
- ],
-)
-
-go_test(
- name = "reviver_test",
- size = "small",
- srcs = ["reviver_test.go"],
- library = ":reviver",
-)
diff --git a/tools/make_apt.sh b/tools/make_apt.sh
index b47977ed5..13c5edd76 100755
--- a/tools/make_apt.sh
+++ b/tools/make_apt.sh
@@ -54,18 +54,22 @@ declare -r release="${root}/dists/${suite}"
mkdir -p "${release}"
# Create a temporary keyring, and ensure it is cleaned up.
+# Using separate homedir allows us to install apt repositories multiple times
+# using the same key. This is a limitation in GnuPG pre-2.1.
declare -r keyring=$(mktemp /tmp/keyringXXXXXX.gpg)
+declare -r homedir=$(mktemp -d /tmp/homedirXXXXXX)
+declare -r gpg_opts=("--no-default-keyring" "--secret-keyring" "${keyring}" "--homedir" "${homedir}")
cleanup() {
- rm -f "${keyring}"
+ rm -rf "${keyring}" "${homedir}"
}
trap cleanup EXIT
# We attempt the import twice because the first one will fail if the public key
# is not found. This isn't actually a failure for us, because we don't require
-# the public (this may be stored separately). The second import will succeed
+# the public key (this may be stored separately). The second import will succeed
# because, in reality, the first import succeeded and it's a no-op.
-gpg --no-default-keyring --keyring "${keyring}" --secret-keyring "${keyring}" --import "${private_key}" || \
- gpg --no-default-keyring --keyring "${keyring}" --secret-keyring "${keyring}" --import "${private_key}"
+gpg "${gpg_opts[@]}" --import "${private_key}" || \
+ gpg "${gpg_opts[@]}" --import "${private_key}"
# Copy the packages into the root.
for pkg in "$@"; do
@@ -100,7 +104,8 @@ for pkg in "$@"; do
cp -a "${pkg}" "${target}"
chmod 0644 "${target}"
if [[ "${ext}" == "deb" ]]; then
- dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${target}"
+ # We use [*] here to expand the gpg_opts array into a single shell-word.
+ dpkg-sig -g "${gpg_opts[*]}" --sign builder "${target}"
fi
done
@@ -135,5 +140,5 @@ rm "${release}"/apt.conf
# Sign the release.
declare -r digest_opts=("--digest-algo" "SHA512" "--cert-digest-algo" "SHA512")
(cd "${release}" && rm -f Release.gpg InRelease)
-(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign "${digest_opts[@]}" -o InRelease Release)
-(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" -abs "${digest_opts[@]}" -o Release.gpg Release)
+(cd "${release}" && gpg "${gpg_opts[@]}" --clearsign "${digest_opts[@]}" -o InRelease Release)
+(cd "${release}" && gpg "${gpg_opts[@]}" -abs "${digest_opts[@]}" -o Release.gpg Release)
diff --git a/tools/nogo/build.go b/tools/nogo/build.go
index 37947b5c3..39c2ae418 100644
--- a/tools/nogo/build.go
+++ b/tools/nogo/build.go
@@ -26,6 +26,9 @@ var (
// and should not have any special prefix applied.
internalPrefix = fmt.Sprintf("^")
+ // internalDefault is applied when no paths are provided.
+ internalDefault = fmt.Sprintf("%s/.*", notPath("external"))
+
// externalPrefix is external workspace packages.
externalPrefix = "^external/"
)
diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go
index 57a250501..5c39be630 100644
--- a/tools/nogo/matchers.go
+++ b/tools/nogo/matchers.go
@@ -16,7 +16,6 @@ package nogo
import (
"go/token"
- "path/filepath"
"regexp"
"strings"
@@ -44,11 +43,30 @@ type pathRegexps struct {
func buildRegexps(prefix string, args ...string) []*regexp.Regexp {
result := make([]*regexp.Regexp, 0, len(args))
for _, arg := range args {
- result = append(result, regexp.MustCompile(filepath.Join(prefix, arg)))
+ result = append(result, regexp.MustCompile(prefix+arg))
}
return result
}
+// notPath works around the lack of backtracking.
+//
+// It is used to construct a regular expression for non-matching components.
+func notPath(name string) string {
+ sb := strings.Builder{}
+ sb.WriteString("(")
+ for i := range name {
+ if i > 0 {
+ sb.WriteString("|")
+ }
+ sb.WriteString(name[:i])
+ sb.WriteString("[^")
+ sb.WriteByte(name[i])
+ sb.WriteString("/][^/]*")
+ }
+ sb.WriteString(")")
+ return sb.String()
+}
+
// ShouldReport implements matcher.ShouldReport.
func (p *pathRegexps) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool {
fullPos := fs.Position(d.Pos).String()
@@ -79,7 +97,7 @@ func externalExcluded(paths ...string) *pathRegexps {
// internalMatches returns a path matcher for internal packages.
func internalMatches() *pathRegexps {
return &pathRegexps{
- expr: buildRegexps(internalPrefix, ".*"),
+ expr: buildRegexps(internalPrefix, internalDefault),
include: true,
}
}
diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go
index 40e48540d..120fdcff5 100644
--- a/tools/nogo/nogo.go
+++ b/tools/nogo/nogo.go
@@ -202,29 +202,41 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
config.Srcs[i] = path.Clean(config.Srcs[i])
}
- // Calculate the root directory.
- longestPrefix := path.Dir(config.Srcs[0])
- for _, file := range config.Srcs[1:] {
- for i := 0; i < len(file) && i < len(longestPrefix); i++ {
- if file[i] != longestPrefix[i] {
- // Truncate here; will stop the loop.
- longestPrefix = longestPrefix[:i]
- break
- }
+ // Calculate the root source directory. This is always a directory
+ // named 'src', of which we simply take the first we find. This is a
+ // bit fragile, but works for all currently known Go source
+ // configurations.
+ //
+ // Note that there may be extra files outside of the root source
+ // directory; we simply ignore those.
+ rootSrcPrefix := ""
+ for _, file := range config.Srcs {
+ const src = "/src/"
+ i := strings.Index(file, src)
+ if i == -1 {
+ // Superfluous file.
+ continue
}
- }
- if len(longestPrefix) > 0 && longestPrefix[len(longestPrefix)-1] != '/' {
- longestPrefix += "/"
+
+ // Index of first character after /src/.
+ i += len(src)
+ rootSrcPrefix = file[:i]
+ break
}
// Aggregate all files by directory.
packages := make(map[string]*packageConfig)
for _, file := range config.Srcs {
+ if !strings.HasPrefix(file, rootSrcPrefix) {
+ // Superflouous file.
+ continue
+ }
+
d := path.Dir(file)
- if len(longestPrefix) >= len(d) {
+ if len(rootSrcPrefix) >= len(d) {
continue // Not a file.
}
- pkg := path.Dir(file)[len(longestPrefix):]
+ pkg := d[len(rootSrcPrefix):]
// Skip cmd packages and obvious test files: see above.
if strings.HasPrefix(pkg, "cmd/") || strings.HasSuffix(file, "_test.go") {
continue
@@ -303,6 +315,11 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
checkOne(pkg)
}
+ // Sanity check.
+ if len(stdlibFacts) == 0 {
+ return nil, nil, fmt.Errorf("no stdlib facts found: misconfiguration?")
+ }
+
// Write out all findings.
factData, err := json.Marshal(stdlibFacts)
if err != nil {
diff --git a/tools/nogo/util/BUILD b/tools/nogo/util/BUILD
new file mode 100644
index 000000000..7ab340b51
--- /dev/null
+++ b/tools/nogo/util/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "util",
+ srcs = ["util.go"],
+ visibility = ["//visibility:public"],
+)
diff --git a/tools/nogo/util/util.go b/tools/nogo/util/util.go
new file mode 100644
index 000000000..919fec799
--- /dev/null
+++ b/tools/nogo/util/util.go
@@ -0,0 +1,85 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package util contains nogo-related utilities.
+package util
+
+import (
+ "fmt"
+ "io/ioutil"
+ "regexp"
+ "strconv"
+ "strings"
+)
+
+// findingRegexp is used to parse findings.
+var findingRegexp = regexp.MustCompile(`([a-zA-Z0-9_\/\.-]+): (-|([a-zA-Z0-9_\/\.-]+):([0-9]+)(:([0-9]+))?): (.*)`)
+
+const (
+ categoryIndex = 1
+ fullPathAndLineIndex = 2
+ fullPathIndex = 3
+ lineIndex = 4
+ messageIndex = 7
+)
+
+// Finding is a single finding.
+type Finding struct {
+ Category string
+ Path string
+ Line int
+ Message string
+}
+
+// ExtractFindingsFromFile loads findings from a file.
+func ExtractFindingsFromFile(filename string) ([]Finding, error) {
+ content, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+ return ExtractFindingsFromBytes(content)
+}
+
+// ExtractFindingsFromBytes loads findings from bytes.
+func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) {
+ lines := strings.Split(string(content), "\n")
+ for _, singleLine := range lines {
+ // Skip blank lines.
+ singleLine = strings.TrimSpace(singleLine)
+ if singleLine == "" {
+ continue
+ }
+ m := findingRegexp.FindStringSubmatch(singleLine)
+ if m == nil {
+ // We shouldn't see findings like this.
+ return findings, fmt.Errorf("poorly formated line: %v", singleLine)
+ }
+ if m[fullPathAndLineIndex] == "-" {
+ continue // No source file available.
+ }
+ // Cleanup the message.
+ message := m[messageIndex]
+ message = strings.Replace(message, " → ", "\n → ", -1)
+ message = strings.Replace(message, " or ", "\n or ", -1)
+ // Construct a new annotation.
+ lineNumber, _ := strconv.ParseUint(m[lineIndex], 10, 32)
+ findings = append(findings, Finding{
+ Category: m[categoryIndex],
+ Path: m[fullPathIndex],
+ Line: int(lineNumber),
+ Message: message,
+ })
+ }
+ return findings, nil
+}
diff --git a/website/BUILD b/website/BUILD
index 7b61d13c8..6d92d9103 100644
--- a/website/BUILD
+++ b/website/BUILD
@@ -157,6 +157,7 @@ docs(
"//g3doc/user_guide/quick_start:oci",
"//g3doc/user_guide/tutorials:cni",
"//g3doc/user_guide/tutorials:docker",
+ "//g3doc/user_guide/tutorials:docker_compose",
"//g3doc/user_guide/tutorials:kubernetes",
],
)
diff --git a/website/_config.yml b/website/_config.yml
index b08602970..20fbb3d2d 100644
--- a/website/_config.yml
+++ b/website/_config.yml
@@ -34,3 +34,6 @@ authors:
igudger:
name: Ian Gudger
email: igudger@google.com
+ fvoznika:
+ name: Fabricio Voznika
+ email: fvoznika@google.com
diff --git a/website/_sass/front.scss b/website/_sass/front.scss
index 0e4208f3c..f1b060560 100644
--- a/website/_sass/front.scss
+++ b/website/_sass/front.scss
@@ -1,5 +1,5 @@
.jumbotron {
- background-image: url(/assets/images/background.jpg);
+ background-image: url(/assets/images/background_1080p.jpg);
background-position: center;
background-repeat: no-repeat;
background-size: cover;
diff --git a/website/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.png b/website/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.png
new file mode 100644
index 000000000..c750f0851
--- /dev/null
+++ b/website/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.png
Binary files differ
diff --git a/website/assets/images/background_1080p.jpg b/website/assets/images/background_1080p.jpg
new file mode 100644
index 000000000..d312595a6
--- /dev/null
+++ b/website/assets/images/background_1080p.jpg
Binary files differ
diff --git a/website/blog/2019-11-18-security-basics.md b/website/blog/2019-11-18-security-basics.md
index 2256ee9d5..b6cf57a77 100644
--- a/website/blog/2019-11-18-security-basics.md
+++ b/website/blog/2019-11-18-security-basics.md
@@ -279,8 +279,10 @@ weaknesses of each gVisor component.
We will also use it to introduce Google's Vulnerability Reward Program[^14], and
other ways the community can contribute to help make gVisor safe, fast and
stable.
+<br>
+<br>
-## Notes
+--------------------------------------------------------------------------------
[^1]: [https://en.wikipedia.org/wiki/Secure_by_design](https://en.wikipedia.org/wiki/Secure_by_design)
[^2]: [https://gvisor.dev/docs/architecture_guide](https://gvisor.dev/docs/architecture_guide/)
diff --git a/website/blog/2020-09-18-containing-a-real-vulnerability.md b/website/blog/2020-09-18-containing-a-real-vulnerability.md
new file mode 100644
index 000000000..c1b06a996
--- /dev/null
+++ b/website/blog/2020-09-18-containing-a-real-vulnerability.md
@@ -0,0 +1,223 @@
+# Containing a Real Vulnerability
+
+In the previous two posts we talked about gVisor's
+[security design principles](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/)
+as well as how those are applied in the
+[context of networking](https://gvisor.dev/blog/2020/04/02/gvisor-networking-security/).
+Recently, a new container escape vulnerability
+([CVE-2020-14386](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14386))
+was announced that ties these topics well together. gVisor is
+[not vulnerable](https://seclists.org/oss-sec/2020/q3/168) to this specific
+issue, but it provides an interesting case study to continue our exploration of
+gVisor's security. While gVisor is not immune to vulnerabilities,
+[we take several steps](https://gvisor.dev/security/) to minimize the impact and
+remediate if a vulnerability is found.
+
+## Escaping the Container
+
+First, let’s describe how the discovered vulnerability works. There are numerous
+ways one can send and receive bytes over the network with Linux. One of the most
+performant ways is to use a ring buffer, which is a memory region shared by the
+application and the kernel. These rings are created by calling
+[setsockopt(2)](https://man7.org/linux/man-pages/man2/setsockopt.2.html) with
+[`PACKET_RX_RING`](https://man7.org/linux/man-pages/man7/packet.7.html) for
+receiving and
+[`PACKET_TX_RING`](https://man7.org/linux/man-pages/man7/packet.7.html) for
+sending packets.
+
+The vulnerability is in the code that reads packets when `PACKET_RX_RING` is
+enabled. There is another option
+([`PACKET_RESERVE`](https://man7.org/linux/man-pages/man7/packet.7.html)) that
+asks the kernel to leave some space in the ring buffer before each packet for
+anything the application needs, e.g. control structures. When a packet is
+received, the kernel calculates where to copy the packet to, taking the amount
+reserved before each packet into consideration. If the amount reserved is large,
+the kernel performed an incorrect calculation which could cause an overflow
+leading to an out-of-bounds write of up to 10 bytes, controlled by the attacker.
+The data in the write is easily controlled using the loopback to send a crafted
+packet and receiving it using a `PACKET_RX_RING` with a carefully selected
+`PACKET_RESERVE` size.
+
+```c
+static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
+ struct packet_type *pt, struct net_device *orig_dev)
+{
+// ...
+ if (sk->sk_type == SOCK_DGRAM) {
+ macoff = netoff = TPACKET_ALIGN(po->tp_hdrlen) + 16 +
+ po->tp_reserve;
+ } else {
+ unsigned int maclen = skb_network_offset(skb);
+ // tp_reserve is unsigned int, netoff is unsigned short. Addition can overflow netoff
+ netoff = TPACKET_ALIGN(po->tp_hdrlen +
+ (maclen < 16 ? 16 : maclen)) +
+ po->tp_reserve;
+ if (po->has_vnet_hdr) {
+ netoff += sizeof(struct virtio_net_hdr);
+ do_vnet = true;
+ }
+ // Attacker controls netoff and can make macoff be smaller than sizeof(struct virtio_net_hdr)
+ macoff = netoff - maclen;
+ }
+// ...
+ // "macoff - sizeof(struct virtio_net_hdr)" can be negative, resulting in a pointer before h.raw
+ if (do_vnet &&
+ virtio_net_hdr_from_skb(skb, h.raw + macoff -
+ sizeof(struct virtio_net_hdr),
+ vio_le(), true, 0)) {
+// ...
+```
+
+The [`CAP_NET_RAW`](https://man7.org/linux/man-pages/man7/capabilities.7.html)
+capability is required to create the socket above. However, in order to support
+common debugging tools like `ping` and `tcpdump`, Docker containers, including
+those created for Kubernetes, are given `CAP_NET_RAW` by default and thus may be
+able to trigger this vulnerability to elevate privileges and escape the
+container.
+
+Next, we are going to explore why this vulnerability doesn’t work in gVisor, and
+how gVisor could prevent the escape even if a similar vulnerability existed
+inside gVisor’s kernel.
+
+## Default Protections
+
+gVisor does not implement `PACKET_RX_RING`, but **does** support raw sockets
+which are required for `PACKET_RX_RING`. Raw sockets are a controversial feature
+to support in a sandbox environment. While it allows great customizations for
+essential tools like `ping`, it may allow packets to be written to the network
+without any validation. In general, allowing an untrusted application to write
+crafted packets to the network is a questionable idea and a historical source of
+vulnerabilities. With that in mind, if `CAP_NET_RAW` is enabled by default, it
+would not be _secure by default_ to run untrusted applications.
+
+After multiple discussions when raw sockets were first implemented, we decided
+to disable raw sockets by default, **even if `CAP_NET_RAW` is given to the
+application**. Instead, enabling raw sockets in gVisor requires the admin to set
+`--net-raw` flag to runsc when configuring the runtime, in addition to requiring
+the `CAP_NET_RAW` capability in the application. It comes at the expense that
+some tools may not work out of the box, but as part of our
+[secure-by-default](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/#secure-by-default)
+principle, we felt that it was important for the “less secure” configuration to
+be explicit.
+
+Since this bug was due to an overflow in the specific Linux implementation of
+the packet ring, gVisor's raw socket implementation is not affected. However, if
+there were a vulnerability in gVisor, containers would not be allowed to exploit
+it by default.
+
+As an alternative way to implement this same constraint, Kubernetes allows
+[admission controllers](https://kubernetes.io/docs/reference/access-authn-authz/admission-controllers/)
+to be configured to customize requests. Cloud providers can use this to
+implement more stringent policies. For example, GKE implements an admission
+controller for gVisor that
+[removes `CAP_NET_RAW` from gVisor pods](https://cloud.google.com/kubernetes-engine/docs/concepts/sandbox-pods#capabilities)
+unless it has been explicitly set in the pod spec.
+
+## Isolated Kernel
+
+gVisor has its own application kernel, called the Sentry, that is distinct from
+the host kernel. Just like what you would expect from a kernel, gVisor has a
+memory management subsystem, virtual file system, and a full network stack. The
+host network is only used as a transport to carry packets in and out the
+sandbox[^1]. The loopback interface which is used in the exploit stays
+completely inside the sandbox, never reaching the host.
+
+Therefore, even if the Sentry was vulnerable to the attack, there would be two
+factors that would prevent a container escape from happening. First, the
+vulnerability would be limited to the Sentry, and the attacker would compromise
+only the application kernel, bound by a restricted set of
+[seccomp](https://en.wikipedia.org/wiki/Seccomp) filters, discussed more in
+depth below. Second, the Sentry is a distinct implementation of the API, written
+in Go, which provides bounds checking that would have likely prevented access
+past the bounds of the shared region (e.g. see
+[aio](https://cs.opensource.google/gvisor/gvisor/+/master:pkg/sentry/syscalls/linux/vfs2/aio.go;l=210;drc=a11061d78a58ed75b10606d1a770b035ed944b66?q=file:aio&ss=gvisor%2Fgvisor)
+or
+[kcov](https://cs.opensource.google/gvisor/gvisor/+/master:pkg/sentry/kernel/kcov.go;l=272?q=file:kcov&ss=gvisor%2Fgvisor),
+which have similar shared regions).
+
+Here, Kubernetes warrants slightly more explanation. gVisor makes pods the unit
+of isolation and a pod can run multiple containers. In other words, each pod is
+a gVisor instance, and each container is a set of processes running inside
+gVisor, isolated via Sentry-internal namespaces like regular containers inside a
+pod. If there were a vulnerability in gVisor, the privilege escalation would
+allow a container inside the pod to break out to other **containers inside the
+same pod**, but the container still **cannot break out of the pod**.
+
+## Defense in Depth
+
+gVisor follows a
+[common security principle used at Google](https://cloud.google.com/security/infrastructure/design/resources/google_infrastructure_whitepaper_fa.pdf)
+that the system should have two layers of protection, and those layers should
+require different compromises to be broken. We apply this principle by assuming
+that the Sentry (first layer of defense)
+[will be compromised and should not be trusted](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/#defense-in-depth).
+In order to protect the host kernel from a compromised Sentry, we wrap it around
+many security and isolations features to ensure only the minimal set of
+functionality from the host kernel is exposed.
+
+![Figure 1](/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.png "Protection layers.")
+
+First, the sandbox runs inside a cgroup that can limit and throttle host
+resources being used. Second, the sandbox joins empty namespaces, including user
+and mount, to further isolate from the host. Next, it changes the process root
+to a read-only directory that contains only `/proc` and nothing else. Then, it
+executes with the unprivileged user/group
+[`nobody`](https://en.wikipedia.org/wiki/Nobody_\(username\)) with all
+capabilities stripped. Last and most importantly, a seccomp filter is added to
+tightly restrict what parts of the Linux syscall surface that gVisor is allowed
+to access. The allowed host surface is a far smaller set of syscalls than the
+Sentry implements for applications to use. Not only restricting the syscall
+being called, but also checking that arguments to these syscalls are within the
+expected set. Dangerous syscalls like <code>execve(2)</code>,
+<code>open(2)</code>, and <code>socket(2)</code> are prohibited, thus an
+attacker isn’t able to execute binaries or acquire new resources on the host.
+
+if there were a vulnerability in gVisor that allowed an attacker to execute code
+inside the Sentry, the attacker still has extremely limited privileges on the
+host. In fact, a compromised Sentry is much more restricted than a
+non-compromised regular container. For CVE-2020-14386 in particular, the attack
+would be blocked by more than one security layer: non-privileged user, no
+capability, and seccomp filters.
+
+Although the surface is drastically reduced, there is still a chance that there
+is a vulnerability in one of the allowed syscalls. That’s why it’s important to
+keep the surface small and carefully consider what syscalls are allowed. You can
+find the full set of allowed syscalls
+[here](https://cs.opensource.google/gvisor/gvisor/+/master:runsc/boot/filter/).
+
+Another possible attack vector is resources that are present in the Sentry, like
+open file descriptors. The Sentry has file descriptors that an attacker could
+potentially use, such as log files, platform files (e.g. `/dev/kvm`), an RPC
+endpoint that allows external communication with the Sentry, and a Netstack
+endpoint that connects the sandbox to the network. The Netstack endpoint in
+particular is a concern because it gives direct access to the network. It’s an
+`AF_PACKET` socket that allows arbitrary L2 packets to be written to the
+network. In the normal case, Netstack assembles packets that go out the network,
+giving the container control over only the payload. But if the Sentry is
+compromised, an attacker can craft packets to the network. In many ways this is
+similar to anyone sending random packets over the internet, but still this is a
+place where the host kernel surface exposed is larger than we would like it to
+be.
+
+## Conclusion
+
+Security comes with many tradeoffs that are often hard to make, such as the
+decision to disable raw sockets by default. However, these tradeoffs have served
+us well, and we've found them to have paid off over time. CVE-2020-14386 offers
+great insight into how multiple layers of protection can be effective against
+such an attack.
+
+We cannot guarantee that a container escape will never happen in gVisor, but we
+do our best to make it as hard as we possibly can.
+
+If you have not tried gVisor yet, it’s easier than you think. Just follow the
+steps [here](https://gvisor.dev/docs/user_guide/install/).
+<br>
+<br>
+
+--------------------------------------------------------------------------------
+
+[^1]: Those packets are eventually handled by the host, as it needs to route
+ them to local containers or send them out the NIC. The packet will be
+ handled by many switches, routers, proxies, servers, etc. along the way,
+ which may be subject to their own vulnerabilities.
diff --git a/website/blog/BUILD b/website/blog/BUILD
index 01c1f5a6e..865e403da 100644
--- a/website/blog/BUILD
+++ b/website/blog/BUILD
@@ -28,6 +28,16 @@ doc(
permalink = "/blog/2020/04/02/gvisor-networking-security/",
)
+doc(
+ name = "containing_a_real_vulnerability",
+ src = "2020-09-18-containing-a-real-vulnerability.md",
+ authors = [
+ "fvoznika",
+ ],
+ layout = "post",
+ permalink = "/blog/2020/09/18/containing-a-real-vulnerability/",
+)
+
docs(
name = "posts",
deps = [
diff --git a/website/css/main.scss b/website/css/main.scss
index 06106833f..4b3b7b500 100644
--- a/website/css/main.scss
+++ b/website/css/main.scss
@@ -1,5 +1,10 @@
-@import 'style.scss';
-@import 'front.scss';
-@import 'navbar.scss';
-@import 'sidebar.scss';
-@import 'footer.scss';
+// The main style sheet for gvisor.dev
+
+// NOTE: Do not include file extensions to import .sass and .css files seamlessly.
+@import "style";
+@import "front";
+@import "navbar";
+@import "sidebar";
+@import "footer";
+// syntax is generated by rougify.
+@import "syntax";
diff --git a/website/index.md b/website/index.md
index 84f877d49..c6cd477c2 100644
--- a/website/index.md
+++ b/website/index.md
@@ -5,7 +5,7 @@
<div class="col-md-6">
<p>gVisor is an <b>application kernel</b> for <b>containers</b> that provides efficient defense-in-depth anywhere.</p>
<p style="margin-top: 20px;">
- <a class="btn" href="/docs/user_guide/quick_start/docker/">Quick start&nbsp;<i class="fas fa-arrow-alt-circle-right ml-2"></i></a>
+ <a class="btn" href="/docs/user_guide/install/">Get started&nbsp;<i class="fas fa-arrow-alt-circle-right ml-2"></i></a>
<a class="btn" href="/docs/">Learn More&nbsp;<i class="fas fa-arrow-alt-circle-right ml-2"></i></a>
</p>
</div>