Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD5
-rw-r--r--pkg/abi/linux/aio.go60
-rw-r--r--pkg/abi/linux/dev.go4
-rw-r--r--pkg/abi/linux/fadvise.go24
-rw-r--r--pkg/abi/linux/fcntl.go2
-rw-r--r--pkg/abi/linux/file.go5
-rw-r--r--pkg/abi/linux/fuse.go248
-rw-r--r--pkg/abi/linux/futex.go18
-rw-r--r--pkg/abi/linux/ioctl.go27
-rw-r--r--pkg/abi/linux/ip.go10
-rw-r--r--pkg/abi/linux/netdevice.go4
-rw-r--r--pkg/abi/linux/netfilter.go154
-rw-r--r--pkg/abi/linux/netlink_route.go2
-rw-r--r--pkg/abi/linux/socket.go21
-rw-r--r--pkg/abi/linux/tcp.go1
-rw-r--r--pkg/bpf/interpreter_test.go2
-rw-r--r--pkg/buffer/safemem.go82
-rw-r--r--pkg/cleanup/BUILD17
-rw-r--r--pkg/cleanup/cleanup.go60
-rw-r--r--pkg/cleanup/cleanup_test.go66
-rw-r--r--pkg/compressio/compressio.go54
-rw-r--r--pkg/cpuid/cpuid_arm64.go5
-rw-r--r--pkg/cpuid/cpuid_x86.go7
-rw-r--r--pkg/flipcall/BUILD1
-rw-r--r--pkg/flipcall/flipcall.go2
-rw-r--r--pkg/flipcall/packet_window_mmap.go25
-rw-r--r--pkg/gohacks/BUILD1
-rw-r--r--pkg/ilist/list.go6
-rw-r--r--pkg/iovec/BUILD18
-rw-r--r--pkg/iovec/iovec.go75
-rw-r--r--pkg/iovec/iovec_test.go121
-rw-r--r--pkg/merkletree/BUILD16
-rw-r--r--pkg/merkletree/merkletree.go135
-rw-r--r--pkg/merkletree/merkletree_test.go122
-rw-r--r--pkg/p9/messages.go2
-rw-r--r--pkg/p9/p9.go13
-rw-r--r--pkg/p9/server.go4
-rw-r--r--pkg/procid/procid_amd64.s2
-rw-r--r--pkg/procid/procid_arm64.s2
-rw-r--r--pkg/seccomp/seccomp_rules.go4
-rw-r--r--pkg/sentry/arch/arch_aarch64.go33
-rw-r--r--pkg/sentry/arch/arch_amd64.go4
-rw-r--r--pkg/sentry/arch/arch_arm64.go4
-rw-r--r--pkg/sentry/arch/arch_x86.go16
-rw-r--r--pkg/sentry/control/BUILD5
-rw-r--r--pkg/sentry/control/logging.go4
-rw-r--r--pkg/sentry/control/proc.go128
-rw-r--r--pkg/sentry/device/device.go3
-rw-r--r--pkg/sentry/devices/memdev/full.go1
-rw-r--r--pkg/sentry/devices/memdev/null.go1
-rw-r--r--pkg/sentry/devices/memdev/random.go1
-rw-r--r--pkg/sentry/devices/memdev/zero.go1
-rw-r--r--pkg/sentry/devices/ttydev/BUILD16
-rw-r--r--pkg/sentry/devices/ttydev/ttydev.go91
-rw-r--r--pkg/sentry/devices/tundev/BUILD23
-rw-r--r--pkg/sentry/devices/tundev/tundev.go178
-rw-r--r--pkg/sentry/fdimport/fdimport.go5
-rw-r--r--pkg/sentry/fs/file.go11
-rw-r--r--pkg/sentry/fs/file_operations.go1
-rw-r--r--pkg/sentry/fs/filesystems.go14
-rw-r--r--pkg/sentry/fs/fs.go3
-rw-r--r--pkg/sentry/fs/fsutil/BUILD7
-rw-r--r--pkg/sentry/fs/fsutil/dirty_set.go7
-rw-r--r--pkg/sentry/fs/fsutil/file_range_set.go15
-rw-r--r--pkg/sentry/fs/fsutil/frame_ref_set.go10
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go5
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go19
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go25
-rw-r--r--pkg/sentry/fs/g3doc/.gitignore1
-rw-r--r--pkg/sentry/fs/g3doc/fuse.md263
-rw-r--r--pkg/sentry/fs/gofer/inode.go2
-rw-r--r--pkg/sentry/fs/host/BUILD1
-rw-r--r--pkg/sentry/fs/host/inode.go3
-rw-r--r--pkg/sentry/fs/host/socket.go10
-rw-r--r--pkg/sentry/fs/host/socket_iovec.go7
-rw-r--r--pkg/sentry/fs/host/tty.go6
-rw-r--r--pkg/sentry/fs/lock/lock.go41
-rw-r--r--pkg/sentry/fs/lock/lock_set_functions.go8
-rw-r--r--pkg/sentry/fs/lock/lock_test.go111
-rw-r--r--pkg/sentry/fs/mounts.go72
-rw-r--r--pkg/sentry/fs/user/BUILD8
-rw-r--r--pkg/sentry/fs/user/path.go170
-rw-r--r--pkg/sentry/fs/user/user.go2
-rw-r--r--pkg/sentry/fsimpl/devpts/BUILD1
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go4
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go21
-rw-r--r--pkg/sentry/fsimpl/devpts/slave.go23
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs.go16
-rw-r--r--pkg/sentry/fsimpl/eventfd/eventfd.go1
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD5
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_file.go9
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_test.go29
-rw-r--r--pkg/sentry/fsimpl/ext/dentry.go17
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go22
-rw-r--r--pkg/sentry/fsimpl/ext/extent_file.go5
-rw-r--r--pkg/sentry/fsimpl/ext/extent_test.go22
-rw-r--r--pkg/sentry/fsimpl/ext/file_description.go1
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go28
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go29
-rw-r--r--pkg/sentry/fsimpl/ext/symlink.go14
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD63
-rw-r--r--pkg/sentry/fsimpl/fuse/connection.go437
-rw-r--r--pkg/sentry/fsimpl/fuse/dev.go397
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go428
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go228
-rw-r--r--pkg/sentry/fsimpl/fuse/init.go166
-rw-r--r--pkg/sentry/fsimpl/fuse/register.go42
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD4
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go15
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go307
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go349
-rw-r--r--pkg/sentry/fsimpl/gofer/handle.go5
-rw-r--r--pkg/sentry/fsimpl/gofer/host_named_pipe.go97
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go153
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go172
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go15
-rw-r--r--pkg/sentry/fsimpl/host/BUILD3
-rw-r--r--pkg/sentry/fsimpl/host/host.go127
-rw-r--r--pkg/sentry/fsimpl/host/mmap.go21
-rw-r--r--pkg/sentry/fsimpl/host/socket.go9
-rw-r--r--pkg/sentry/fsimpl/host/socket_iovec.go7
-rw-r--r--pkg/sentry/fsimpl/host/tty.go17
-rw-r--r--pkg/sentry/fsimpl/host/util.go10
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD3
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go26
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go38
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go12
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go14
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go27
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go12
-rw-r--r--pkg/sentry/fsimpl/overlay/BUILD41
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go262
-rw-r--r--pkg/sentry/fsimpl/overlay/directory.go287
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go1364
-rw-r--r--pkg/sentry/fsimpl/overlay/non_directory.go266
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go627
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go7
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD1
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go10
-rw-r--r--pkg/sentry/fsimpl/proc/task.go8
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go9
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go86
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go8
-rw-r--r--pkg/sentry/fsimpl/signalfd/signalfd.go1
-rw-r--r--pkg/sentry/fsimpl/sys/BUILD2
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go8
-rw-r--r--pkg/sentry/fsimpl/sys/sys_test.go2
-rw-r--r--pkg/sentry/fsimpl/testutil/BUILD2
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go8
-rw-r--r--pkg/sentry/fsimpl/timerfd/timerfd.go1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/device_file.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go10
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go120
-rw-r--r--pkg/sentry/fsimpl/tmpfs/named_pipe.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go98
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file_test.go33
-rw-r--r--pkg/sentry/fsimpl/tmpfs/socket_file.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/symlink.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go216
-rw-r--r--pkg/sentry/kernel/BUILD3
-rw-r--r--pkg/sentry/kernel/aio.go81
-rw-r--r--pkg/sentry/kernel/auth/credentials.go28
-rw-r--r--pkg/sentry/kernel/context.go53
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go31
-rw-r--r--pkg/sentry/kernel/epoll/epoll_state.go3
-rw-r--r--pkg/sentry/kernel/fasync/BUILD1
-rw-r--r--pkg/sentry/kernel/fasync/fasync.go18
-rw-r--r--pkg/sentry/kernel/fd_table.go56
-rw-r--r--pkg/sentry/kernel/futex/futex.go8
-rw-r--r--pkg/sentry/kernel/kernel.go48
-rw-r--r--pkg/sentry/kernel/pipe/BUILD3
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go6
-rw-r--r--pkg/sentry/kernel/pipe/pipe_unsafe.go35
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go247
-rw-r--r--pkg/sentry/kernel/shm/BUILD1
-rw-r--r--pkg/sentry/kernel/shm/shm.go3
-rw-r--r--pkg/sentry/kernel/syslog.go9
-rw-r--r--pkg/sentry/kernel/task.go19
-rw-r--r--pkg/sentry/kernel/task_exec.go7
-rw-r--r--pkg/sentry/kernel/task_exit.go3
-rw-r--r--pkg/sentry/kernel/task_futex.go125
-rw-r--r--pkg/sentry/kernel/task_run.go17
-rw-r--r--pkg/sentry/kernel/task_work.go38
-rw-r--r--pkg/sentry/kernel/thread_group.go3
-rw-r--r--pkg/sentry/kernel/threads.go7
-rw-r--r--pkg/sentry/kernel/time/BUILD1
-rw-r--r--pkg/sentry/kernel/time/tcpip.go131
-rw-r--r--pkg/sentry/kernel/timekeeper.go9
-rw-r--r--pkg/sentry/kernel/vdso.go6
-rw-r--r--pkg/sentry/loader/BUILD4
-rw-r--r--pkg/sentry/loader/elf.go15
-rw-r--r--pkg/sentry/loader/loader.go21
-rw-r--r--pkg/sentry/loader/vdso.go61
-rw-r--r--pkg/sentry/memmap/BUILD14
-rw-r--r--pkg/sentry/memmap/memmap.go60
-rw-r--r--pkg/sentry/mm/BUILD4
-rw-r--r--pkg/sentry/mm/aio_context.go3
-rw-r--r--pkg/sentry/mm/mm.go10
-rw-r--r--pkg/sentry/mm/pma.go25
-rw-r--r--pkg/sentry/mm/special_mappable.go7
-rw-r--r--pkg/sentry/pgalloc/BUILD29
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go278
-rw-r--r--pkg/sentry/pgalloc/pgalloc_test.go206
-rw-r--r--pkg/sentry/pgalloc/save_restore.go13
-rw-r--r--pkg/sentry/platform/BUILD20
-rw-r--r--pkg/sentry/platform/kvm/BUILD4
-rw-r--r--pkg/sentry/platform/kvm/address_space.go76
-rw-r--r--pkg/sentry/platform/kvm/bluepill_allocator.go (renamed from pkg/sentry/platform/kvm/allocator.go)52
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.go12
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go32
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go15
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go34
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go70
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64_test.go51
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64.go15
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64_unsafe.go4
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go18
-rw-r--r--pkg/sentry/platform/kvm/kvm_const_arm64.go18
-rw-r--r--pkg/sentry/platform/kvm/kvm_test.go24
-rw-r--r--pkg/sentry/platform/kvm/machine.go52
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go34
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go2
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go80
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go37
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go4
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.go4
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.s12
-rw-r--r--pkg/sentry/platform/platform.go50
-rw-r--r--pkg/sentry/platform/ptrace/BUILD1
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go3
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_unsafe.go2
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s36
-rw-r--r--pkg/sentry/platform/ring0/kernel.go24
-rw-r--r--pkg/sentry/platform/ring0/kernel_amd64.go12
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go8
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go6
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.s86
-rw-r--r--pkg/sentry/platform/ring0/pagetables/allocator.go11
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables.go8
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/BUILD3
-rw-r--r--pkg/sentry/socket/hostinet/socket.go17
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go21
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go177
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/targets.go2
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go36
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go36
-rw-r--r--pkg/sentry/socket/netlink/BUILD3
-rw-r--r--pkg/sentry/socket/netlink/socket.go14
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go18
-rw-r--r--pkg/sentry/socket/netstack/BUILD4
-rw-r--r--pkg/sentry/socket/netstack/netstack.go470
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go29
-rw-r--r--pkg/sentry/socket/netstack/stack.go33
-rw-r--r--pkg/sentry/socket/socket.go4
-rw-r--r--pkg/sentry/socket/unix/BUILD2
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go5
-rw-r--r--pkg/sentry/socket/unix/unix.go42
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go20
-rw-r--r--pkg/sentry/strace/epoll.go10
-rw-r--r--pkg/sentry/strace/socket.go1
-rw-r--r--pkg/sentry/syscalls/linux/BUILD2
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go169
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go91
-rw-r--r--pkg/sentry/syscalls/linux/sys_futex.go48
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go17
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_xattr.go12
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD11
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/aio.go216
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go187
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go16
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/inotify.go137
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go72
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/lock.go64
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mount.go150
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go69
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go56
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go17
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go486
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sync.go42
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go132
-rw-r--r--pkg/sentry/time/muldiv_arm64.s3
-rw-r--r--pkg/sentry/time/parameters.go12
-rw-r--r--pkg/sentry/time/parameters_test.go15
-rw-r--r--pkg/sentry/vfs/BUILD16
-rw-r--r--pkg/sentry/vfs/README.md6
-rw-r--r--pkg/sentry/vfs/anonfs.go15
-rw-r--r--pkg/sentry/vfs/dentry.go53
-rw-r--r--pkg/sentry/vfs/epoll.go5
-rw-r--r--pkg/sentry/vfs/file_description.go151
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go86
-rw-r--r--pkg/sentry/vfs/file_description_impl_util_test.go9
-rw-r--r--pkg/sentry/vfs/filesystem.go2
-rw-r--r--pkg/sentry/vfs/g3doc/inotify.md210
-rw-r--r--pkg/sentry/vfs/genericfstree/genericfstree.go3
-rw-r--r--pkg/sentry/vfs/inotify.go774
-rw-r--r--pkg/sentry/vfs/lock.go (renamed from pkg/sentry/vfs/lock/lock.go)43
-rw-r--r--pkg/sentry/vfs/lock/BUILD13
-rw-r--r--pkg/sentry/vfs/mount.go102
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go2
-rw-r--r--pkg/sentry/vfs/options.go21
-rw-r--r--pkg/sentry/vfs/permissions.go53
-rw-r--r--pkg/sentry/vfs/vfs.go13
-rw-r--r--pkg/sentry/watchdog/watchdog.go14
-rw-r--r--pkg/shim/runsc/BUILD16
-rw-r--r--pkg/shim/runsc/runsc.go514
-rw-r--r--pkg/shim/runsc/utils.go44
-rw-r--r--pkg/shim/v1/proc/BUILD36
-rw-r--r--pkg/shim/v1/proc/deleted_state.go49
-rw-r--r--pkg/shim/v1/proc/exec.go281
-rw-r--r--pkg/shim/v1/proc/exec_state.go154
-rw-r--r--pkg/shim/v1/proc/init.go460
-rw-r--r--pkg/shim/v1/proc/init_state.go182
-rw-r--r--pkg/shim/v1/proc/io.go162
-rw-r--r--pkg/shim/v1/proc/process.go37
-rw-r--r--pkg/shim/v1/proc/types.go69
-rw-r--r--pkg/shim/v1/proc/utils.go90
-rw-r--r--pkg/shim/v1/shim/BUILD40
-rw-r--r--pkg/shim/v1/shim/api.go28
-rw-r--r--pkg/shim/v1/shim/platform.go106
-rw-r--r--pkg/shim/v1/shim/service.go573
-rw-r--r--pkg/shim/v1/utils/BUILD27
-rw-r--r--pkg/shim/v1/utils/annotations.go25
-rw-r--r--pkg/shim/v1/utils/utils.go56
-rw-r--r--pkg/shim/v1/utils/volumes.go155
-rw-r--r--pkg/shim/v1/utils/volumes_test.go308
-rw-r--r--pkg/shim/v2/BUILD43
-rw-r--r--pkg/shim/v2/api.go22
-rw-r--r--pkg/shim/v2/epoll.go129
-rw-r--r--pkg/shim/v2/options/BUILD11
-rw-r--r--pkg/shim/v2/options/options.go33
-rw-r--r--pkg/shim/v2/runtimeoptions/BUILD20
-rw-r--r--pkg/shim/v2/runtimeoptions/runtimeoptions.go27
-rw-r--r--pkg/shim/v2/runtimeoptions/runtimeoptions.proto25
-rw-r--r--pkg/shim/v2/service.go824
-rw-r--r--pkg/shim/v2/service_linux.go108
-rw-r--r--pkg/sleep/BUILD1
-rw-r--r--pkg/sleep/sleep_test.go20
-rw-r--r--pkg/sleep/sleep_unsafe.go9
-rw-r--r--pkg/state/BUILD68
-rw-r--r--pkg/state/README.md158
-rw-r--r--pkg/state/decode.go918
-rw-r--r--pkg/state/decode_unsafe.go27
-rw-r--r--pkg/state/encode.go1025
-rw-r--r--pkg/state/encode_unsafe.go48
-rw-r--r--pkg/state/map.go232
-rw-r--r--pkg/state/object.proto140
-rw-r--r--pkg/state/pretty/BUILD13
-rw-r--r--pkg/state/pretty/pretty.go273
-rw-r--r--pkg/state/printer.go251
-rw-r--r--pkg/state/state.go360
-rw-r--r--pkg/state/state_norace.go19
-rw-r--r--pkg/state/state_race.go19
-rw-r--r--pkg/state/state_test.go721
-rw-r--r--pkg/state/statefile/BUILD1
-rw-r--r--pkg/state/statefile/statefile.go15
-rw-r--r--pkg/state/stats.go117
-rw-r--r--pkg/state/tests/BUILD43
-rw-r--r--pkg/state/tests/array.go35
-rw-r--r--pkg/state/tests/array_test.go134
-rw-r--r--pkg/state/tests/bench.go24
-rw-r--r--pkg/state/tests/bench_test.go153
-rw-r--r--pkg/state/tests/bool_test.go31
-rw-r--r--pkg/state/tests/float_test.go118
-rw-r--r--pkg/state/tests/integer.go163
-rw-r--r--pkg/state/tests/integer_test.go94
-rw-r--r--pkg/state/tests/load.go61
-rw-r--r--pkg/state/tests/load_test.go70
-rw-r--r--pkg/state/tests/map.go28
-rw-r--r--pkg/state/tests/map_test.go90
-rw-r--r--pkg/state/tests/register.go21
-rw-r--r--pkg/state/tests/register_test.go167
-rw-r--r--pkg/state/tests/string_test.go34
-rw-r--r--pkg/state/tests/struct.go65
-rw-r--r--pkg/state/tests/struct_test.go89
-rw-r--r--pkg/state/tests/tests.go215
-rw-r--r--pkg/state/types.go361
-rw-r--r--pkg/state/wire/BUILD12
-rw-r--r--pkg/state/wire/wire.go970
-rw-r--r--pkg/sync/BUILD1
-rw-r--r--pkg/sync/memmove_unsafe.go2
-rw-r--r--pkg/sync/mutex_unsafe.go2
-rw-r--r--pkg/sync/nocopy.go28
-rw-r--r--pkg/sync/rwmutex_unsafe.go2
-rw-r--r--pkg/syncevent/waiter_unsafe.go2
-rw-r--r--pkg/syserr/netstack.go2
-rw-r--r--pkg/syserror/syserror.go2
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go5
-rw-r--r--pkg/tcpip/checker/checker.go16
-rw-r--r--pkg/tcpip/header/BUILD4
-rw-r--r--pkg/tcpip/header/arp.go77
-rw-r--r--pkg/tcpip/header/eth.go4
-rw-r--r--pkg/tcpip/header/icmpv4.go1
-rw-r--r--pkg/tcpip/header/icmpv6.go11
-rw-r--r--pkg/tcpip/header/ipv4.go5
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go7
-rw-r--r--pkg/tcpip/link/channel/BUILD1
-rw-r--r--pkg/tcpip/link/channel/channel.go20
-rw-r--r--pkg/tcpip/link/fdbased/BUILD1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go111
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go91
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go2
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go12
-rw-r--r--pkg/tcpip/link/loopback/loopback.go14
-rw-r--r--pkg/tcpip/link/muxed/BUILD1
-rw-r--r--pkg/tcpip/link/muxed/injectable.go16
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go4
-rw-r--r--pkg/tcpip/link/nested/BUILD32
-rw-r--r--pkg/tcpip/link/nested/nested.go152
-rw-r--r--pkg/tcpip/link/nested/nested_test.go109
-rw-r--r--pkg/tcpip/link/packetsocket/BUILD14
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go50
-rw-r--r--pkg/tcpip/link/qdisc/fifo/BUILD1
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go26
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go2
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go33
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go28
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go30
-rw-r--r--pkg/tcpip/link/sniffer/BUILD1
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go110
-rw-r--r--pkg/tcpip/link/tun/device.go55
-rw-r--r--pkg/tcpip/link/waitable/BUILD2
-rw-r--r--pkg/tcpip/link/waitable/waitable.go26
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go31
-rw-r--r--pkg/tcpip/network/arp/arp.go36
-rw-r--r--pkg/tcpip/network/arp/arp_test.go60
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go98
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go139
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go4
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go2
-rw-r--r--pkg/tcpip/network/ip_test.go80
-rw-r--r--pkg/tcpip/network/ipv4/BUILD2
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go18
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go153
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go231
-rw-r--r--pkg/tcpip/network/ipv6/BUILD3
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go27
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go109
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go141
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go365
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go43
-rw-r--r--pkg/tcpip/ports/ports.go308
-rw-r--r--pkg/tcpip/ports/ports_test.go58
-rw-r--r--pkg/tcpip/stack/BUILD45
-rw-r--r--pkg/tcpip/stack/conntrack.go756
-rw-r--r--pkg/tcpip/stack/fake_time_test.go209
-rw-r--r--pkg/tcpip/stack/forwarder.go4
-rw-r--r--pkg/tcpip/stack/forwarder_test.go143
-rw-r--r--pkg/tcpip/stack/iptables.go308
-rw-r--r--pkg/tcpip/stack/iptables_state.go40
-rw-r--r--pkg/tcpip/stack/iptables_targets.go28
-rw-r--r--pkg/tcpip/stack/iptables_types.go129
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go2
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go2
-rw-r--r--pkg/tcpip/stack/ndp.go252
-rw-r--r--pkg/tcpip/stack/ndp_test.go274
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go335
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go1752
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go482
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go2770
-rw-r--r--pkg/tcpip/stack/neighborstate_string.go44
-rw-r--r--pkg/tcpip/stack/nic.go231
-rw-r--r--pkg/tcpip/stack/nic_test.go269
-rw-r--r--pkg/tcpip/stack/nud.go466
-rw-r--r--pkg/tcpip/stack/nud_test.go795
-rw-r--r--pkg/tcpip/stack/packet_buffer.go27
-rw-r--r--pkg/tcpip/stack/registration.go93
-rw-r--r--pkg/tcpip/stack/route.go33
-rw-r--r--pkg/tcpip/stack/stack.go158
-rw-r--r--pkg/tcpip/stack/stack_options.go106
-rw-r--r--pkg/tcpip/stack/stack_test.go405
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go164
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go7
-rw-r--r--pkg/tcpip/stack/transport_test.go48
-rw-r--r--pkg/tcpip/tcpip.go248
-rw-r--r--pkg/tcpip/time_unsafe.go32
-rw-r--r--pkg/tcpip/timer.go168
-rw-r--r--pkg/tcpip/timer_test.go91
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go39
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go12
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go208
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go19
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go183
-rw-r--r--pkg/tcpip/transport/tcp/BUILD16
-rw-r--r--pkg/tcpip/transport/tcp/accept.go164
-rw-r--r--pkg/tcpip/transport/tcp/connect.go87
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go150
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go203
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go90
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go70
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go4
-rw-r--r--pkg/tcpip/transport/tcp/segment.go43
-rw-r--r--pkg/tcpip/transport/tcp/segment_unsafe.go23
-rw-r--r--pkg/tcpip/transport/tcp/snd.go65
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go32
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go14
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go751
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go22
-rw-r--r--pkg/tcpip/transport/tcp/timer.go1
-rw-r--r--pkg/tcpip/transport/tcp/timer_test.go47
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go5
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go228
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go7
-rw-r--r--pkg/tcpip/transport/udp/protocol.go72
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go702
-rw-r--r--pkg/test/criutil/criutil.go88
-rw-r--r--pkg/test/dockerutil/BUILD34
-rw-r--r--pkg/test/dockerutil/README.md86
-rw-r--r--pkg/test/dockerutil/container.go558
-rw-r--r--pkg/test/dockerutil/dockerutil.go487
-rw-r--r--pkg/test/dockerutil/exec.go193
-rw-r--r--pkg/test/dockerutil/network.go113
-rw-r--r--pkg/test/dockerutil/profile.go152
-rw-r--r--pkg/test/dockerutil/profile_test.go117
-rw-r--r--pkg/test/testutil/BUILD2
-rw-r--r--pkg/test/testutil/testutil.go20
-rw-r--r--pkg/tmutex/BUILD17
-rw-r--r--pkg/tmutex/tmutex.go81
-rw-r--r--pkg/tmutex/tmutex_test.go258
-rw-r--r--pkg/waiter/waiter.go18
527 files changed, 39736 insertions, 8706 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 114b516e2..05ca5342f 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -23,11 +23,13 @@ go_library(
"errors.go",
"eventfd.go",
"exec.go",
+ "fadvise.go",
"fcntl.go",
"file.go",
"file_amd64.go",
"file_arm64.go",
"fs.go",
+ "fuse.go",
"futex.go",
"inotify.go",
"ioctl.go",
@@ -71,6 +73,9 @@ go_library(
"//pkg/abi",
"//pkg/binary",
"//pkg/bits",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/abi/linux/aio.go b/pkg/abi/linux/aio.go
index 3c6e0079d..86ee3f8b5 100644
--- a/pkg/abi/linux/aio.go
+++ b/pkg/abi/linux/aio.go
@@ -14,7 +14,63 @@
package linux
+import "encoding/binary"
+
+// AIORingSize is sizeof(struct aio_ring).
+const AIORingSize = 32
+
+// I/O commands.
const (
- // AIORingSize is sizeof(struct aio_ring).
- AIORingSize = 32
+ IOCB_CMD_PREAD = 0
+ IOCB_CMD_PWRITE = 1
+ IOCB_CMD_FSYNC = 2
+ IOCB_CMD_FDSYNC = 3
+ // 4 was the experimental IOCB_CMD_PREADX.
+ IOCB_CMD_POLL = 5
+ IOCB_CMD_NOOP = 6
+ IOCB_CMD_PREADV = 7
+ IOCB_CMD_PWRITEV = 8
)
+
+// I/O flags.
+const (
+ IOCB_FLAG_RESFD = 1
+ IOCB_FLAG_IOPRIO = 2
+)
+
+// IOCallback describes an I/O request.
+//
+// The priority field is currently ignored in the implementation below. Also
+// note that the IOCB_FLAG_RESFD feature is not supported.
+type IOCallback struct {
+ Data uint64
+ Key uint32
+ _ uint32
+
+ OpCode uint16
+ ReqPrio int16
+ FD int32
+
+ Buf uint64
+ Bytes uint64
+ Offset int64
+
+ Reserved2 uint64
+ Flags uint32
+
+ // eventfd to signal if IOCB_FLAG_RESFD is set in flags.
+ ResFD int32
+}
+
+// IOEvent describes an I/O result.
+//
+// +stateify savable
+type IOEvent struct {
+ Data uint64
+ Obj uint64
+ Result int64
+ Result2 int64
+}
+
+// IOEventSize is the size of an encoded IOEvent.
+var IOEventSize = binary.Size(IOEvent{})
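For context, a minimal sketch (not part of this change; the token and buffer address are made up) of how a userspace struct iocb passed to io_submit(2) maps onto IOCallback, whose wire format is simply the struct's little-endian layout:

package main

import (
    "bytes"
    "encoding/binary"
    "fmt"

    "gvisor.dev/gvisor/pkg/abi/linux"
)

func main() {
    // Describe a 4 KiB read at offset 0 from fd 3, the way userspace
    // fills in a struct iocb before calling io_submit(2).
    cb := linux.IOCallback{
        Data:   0x1234, // opaque token, echoed back in IOEvent.Data
        OpCode: linux.IOCB_CMD_PREAD,
        FD:     3,
        Buf:    0x7f00dead0000, // hypothetical user buffer address
        Bytes:  4096,
    }
    // encoding/binary writes zeros for blank (_) padding fields.
    var buf bytes.Buffer
    if err := binary.Write(&buf, binary.LittleEndian, &cb); err != nil {
        panic(err)
    }
    fmt.Printf("encoded iocb: %d bytes\n", buf.Len()) // 64, matching struct iocb
}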
diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go
index fa3ae5f18..192e2093b 100644
--- a/pkg/abi/linux/dev.go
+++ b/pkg/abi/linux/dev.go
@@ -46,6 +46,10 @@ const (
// TTYAUX_MAJOR is the major device number for alternate TTY devices.
TTYAUX_MAJOR = 5
+ // MISC_MAJOR is the major device number for non-serial mice, misc feature
+ // devices.
+ MISC_MAJOR = 10
+
// UNIX98_PTY_MASTER_MAJOR is the initial major device number for
// Unix98 PTY masters.
UNIX98_PTY_MASTER_MAJOR = 128
diff --git a/pkg/abi/linux/fadvise.go b/pkg/abi/linux/fadvise.go
new file mode 100644
index 000000000..b06ff9964
--- /dev/null
+++ b/pkg/abi/linux/fadvise.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+const (
+ POSIX_FADV_NORMAL = 0
+ POSIX_FADV_RANDOM = 1
+ POSIX_FADV_SEQUENTIAL = 2
+ POSIX_FADV_WILLNEED = 3
+ POSIX_FADV_DONTNEED = 4
+ POSIX_FADV_NOREUSE = 5
+)
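As a sketch of how these constants are consumed, a fadvise64(2) handler mainly validates the advice argument (the hint itself can be a no-op); the helper below is illustrative, not the sentry's actual handler:

package main

import (
    "fmt"

    "gvisor.dev/gvisor/pkg/abi/linux"
)

// validAdvice reports whether advice is a recognized POSIX_FADV_* value;
// fadvise64(2) fails with EINVAL for anything else.
func validAdvice(advice int32) bool {
    switch advice {
    case linux.POSIX_FADV_NORMAL, linux.POSIX_FADV_RANDOM,
        linux.POSIX_FADV_SEQUENTIAL, linux.POSIX_FADV_WILLNEED,
        linux.POSIX_FADV_DONTNEED, linux.POSIX_FADV_NOREUSE:
        return true
    default:
        return false
    }
}

func main() {
    fmt.Println(validAdvice(linux.POSIX_FADV_WILLNEED)) // true
    fmt.Println(validAdvice(42))                        // false
}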
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
index 6663a199c..9242e80a5 100644
--- a/pkg/abi/linux/fcntl.go
+++ b/pkg/abi/linux/fcntl.go
@@ -55,7 +55,7 @@ type Flock struct {
_ [4]byte
}
-// Flags for F_SETOWN_EX and F_GETOWN_EX.
+// Owner types for F_SETOWN_EX and F_GETOWN_EX.
const (
F_OWNER_TID = 0
F_OWNER_PID = 1
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index 055ac1d7c..e11ca2d62 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -191,8 +191,9 @@ var DirentType = abi.ValueSet{
// Values for preadv2/pwritev2.
const (
- // Note: gVisor does not implement the RWF_HIPRI feature, but the flag is
- // accepted as a valid flag argument for preadv2/pwritev2.
+ // NOTE(b/120162627): gVisor does not implement the RWF_HIPRI feature, but
+ // the flag is accepted as a valid flag argument for preadv2/pwritev2 and
+ // silently ignored.
RWF_HIPRI = 0x00000001
RWF_DSYNC = 0x00000002
RWF_SYNC = 0x00000004
diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go
new file mode 100644
index 000000000..5c6ffe4a3
--- /dev/null
+++ b/pkg/abi/linux/fuse.go
@@ -0,0 +1,248 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// +marshal
+type FUSEOpcode uint32
+
+// +marshal
+type FUSEOpID uint64
+
+// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h.
+const (
+ FUSE_LOOKUP FUSEOpcode = 1
+ FUSE_FORGET = 2 /* no reply */
+ FUSE_GETATTR = 3
+ FUSE_SETATTR = 4
+ FUSE_READLINK = 5
+ FUSE_SYMLINK = 6
+ _
+ FUSE_MKNOD = 8
+ FUSE_MKDIR = 9
+ FUSE_UNLINK = 10
+ FUSE_RMDIR = 11
+ FUSE_RENAME = 12
+ FUSE_LINK = 13
+ FUSE_OPEN = 14
+ FUSE_READ = 15
+ FUSE_WRITE = 16
+ FUSE_STATFS = 17
+ FUSE_RELEASE = 18
+ _
+ FUSE_FSYNC = 20
+ FUSE_SETXATTR = 21
+ FUSE_GETXATTR = 22
+ FUSE_LISTXATTR = 23
+ FUSE_REMOVEXATTR = 24
+ FUSE_FLUSH = 25
+ FUSE_INIT = 26
+ FUSE_OPENDIR = 27
+ FUSE_READDIR = 28
+ FUSE_RELEASEDIR = 29
+ FUSE_FSYNCDIR = 30
+ FUSE_GETLK = 31
+ FUSE_SETLK = 32
+ FUSE_SETLKW = 33
+ FUSE_ACCESS = 34
+ FUSE_CREATE = 35
+ FUSE_INTERRUPT = 36
+ FUSE_BMAP = 37
+ FUSE_DESTROY = 38
+ FUSE_IOCTL = 39
+ FUSE_POLL = 40
+ FUSE_NOTIFY_REPLY = 41
+ FUSE_BATCH_FORGET = 42
+)
+
+const (
+ // FUSE_MIN_READ_BUFFER is the minimum size a read buffer can be for any
+ // FUSE filesystem. This is the minimum size Linux supports; see include/uapi/linux/fuse.h.
+ FUSE_MIN_READ_BUFFER uint32 = 8192
+)
+
+// FUSEHeaderIn is the header read by the daemon with each request.
+//
+// +marshal
+type FUSEHeaderIn struct {
+ // Len specifies the total length of the data, including this header.
+ Len uint32
+
+ // Opcode specifies the kind of operation of the request.
+ Opcode FUSEOpcode
+
+ // Unique specifies the unique identifier for this request.
+ Unique FUSEOpID
+
+ // NodeID is the ID of the filesystem object being operated on.
+ NodeID uint64
+
+ // UID is the UID of the requesting process.
+ UID uint32
+
+ // GID is the GID of the requesting process.
+ GID uint32
+
+ // PID is the PID of the requesting process.
+ PID uint32
+
+ _ uint32
+}
+
+// FUSEHeaderOut is the header written by the daemon when it processes
+// a request and wants to send a reply (almost all operations require a
+// reply; if they do not, this will be explicitly documented).
+//
+// +marshal
+type FUSEHeaderOut struct {
+ // Len specifies the total length of the data, including this header.
+ Len uint32
+
+ // Error specifies the error that occurred (0 if none).
+ Error int32
+
+ // Unique specifies the unique identifier of the corresponding request.
+ Unique FUSEOpID
+}
+
+// FUSEWriteIn is the header sent by the kernel to the daemon as part of
+// a FUSE_WRITE request.
+//
+// +marshal
+type FUSEWriteIn struct {
+ // Fh specifies the file handle that is being written to.
+ Fh uint64
+
+ // Offset is the offset of the write.
+ Offset uint64
+
+ // Size is the size of data being written.
+ Size uint32
+
+ // WriteFlags is the flags used during the write.
+ WriteFlags uint32
+
+ // LockOwner is the ID of the lock owner.
+ LockOwner uint64
+
+ // Flags is the flags for the request.
+ Flags uint32
+
+ _ uint32
+}
+
+// FUSE_INIT flags, consistent with the ones in include/uapi/linux/fuse.h.
+const (
+ FUSE_ASYNC_READ = 1 << 0
+ FUSE_POSIX_LOCKS = 1 << 1
+ FUSE_FILE_OPS = 1 << 2
+ FUSE_ATOMIC_O_TRUNC = 1 << 3
+ FUSE_EXPORT_SUPPORT = 1 << 4
+ FUSE_BIG_WRITES = 1 << 5
+ FUSE_DONT_MASK = 1 << 6
+ FUSE_SPLICE_WRITE = 1 << 7
+ FUSE_SPLICE_MOVE = 1 << 8
+ FUSE_SPLICE_READ = 1 << 9
+ FUSE_FLOCK_LOCKS = 1 << 10
+ FUSE_HAS_IOCTL_DIR = 1 << 11
+ FUSE_AUTO_INVAL_DATA = 1 << 12
+ FUSE_DO_READDIRPLUS = 1 << 13
+ FUSE_READDIRPLUS_AUTO = 1 << 14
+ FUSE_ASYNC_DIO = 1 << 15
+ FUSE_WRITEBACK_CACHE = 1 << 16
+ FUSE_NO_OPEN_SUPPORT = 1 << 17
+ FUSE_PARALLEL_DIROPS = 1 << 18
+ FUSE_HANDLE_KILLPRIV = 1 << 19
+ FUSE_POSIX_ACL = 1 << 20
+ FUSE_ABORT_ERROR = 1 << 21
+ FUSE_MAX_PAGES = 1 << 22
+ FUSE_CACHE_SYMLINKS = 1 << 23
+ FUSE_NO_OPENDIR_SUPPORT = 1 << 24
+ FUSE_EXPLICIT_INVAL_DATA = 1 << 25
+ FUSE_MAP_ALIGNMENT = 1 << 26
+)
+
+// Currently supported FUSE protocol version numbers.
+const (
+ FUSE_KERNEL_VERSION = 7
+ FUSE_KERNEL_MINOR_VERSION = 31
+)
+
+// FUSEInitIn is the request sent by the kernel to the daemon,
+// to negotiate the version and flags.
+//
+// +marshal
+type FUSEInitIn struct {
+ // Major version supported by kernel.
+ Major uint32
+
+ // Minor version supported by the kernel.
+ Minor uint32
+
+ // MaxReadahead is the maximum number of bytes to read-ahead
+ // decided by the kernel.
+ MaxReadahead uint32
+
+ // Flags of this init request.
+ Flags uint32
+}
+
+// FUSEInitOut is the reply sent by the daemon to the kernel
+// for FUSEInitIn.
+//
+// +marshal
+type FUSEInitOut struct {
+ // Major version supported by daemon.
+ Major uint32
+
+ // Minor version supported by daemon.
+ Minor uint32
+
+ // MaxReadahead is the maximum number of bytes to read-ahead.
+ // Decided by the daemon, after receiving the value from kernel.
+ MaxReadahead uint32
+
+ // Flags of this init reply.
+ Flags uint32
+
+ // MaxBackground is the maximum number of pending background requests
+ // that the daemon wants.
+ MaxBackground uint16
+
+ // CongestionThreshold is the daemon-decided threshold for
+ // the number of the pending background requests.
+ CongestionThreshold uint16
+
+ // MaxWrite is the daemon's maximum size of a write buffer.
+ // Kernel adjusts it to the minimum (fuse/init.go:fuseMinMaxWrite)
+ // if the value from the daemon is too small.
+ MaxWrite uint32
+
+ // TimeGran is the daemon's time granularity for mtime and ctime metadata.
+ // The unit is nanoseconds and the value should be a power of 10.
+ // 1 indicates full nanosecond granularity support.
+ TimeGran uint32
+
+ // MaxPages is the daemon's maximum number of pages for one write operation.
+ // Kernel adjusts it to the maximum (fuse/init.go:FUSE_MAX_MAX_PAGES)
+ // if the value from the daemon is too large.
+ MaxPages uint16
+
+ // MapAlignment is not used by this package at the moment; it is kept as a
+ // placeholder for consistency with the FUSE protocol.
+ MapAlignment uint16
+
+ _ [8]uint32
+}
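To make the framing concrete, here is a sketch of assembling a FUSE_INIT request as it would appear on /dev/fuse. It assumes the SizeBytes/MarshalBytes methods that go_marshal generates for the +marshal types above; the Unique and MaxReadahead values are arbitrary:

package main

import (
    "fmt"

    "gvisor.dev/gvisor/pkg/abi/linux"
)

func main() {
    payload := linux.FUSEInitIn{
        Major:        linux.FUSE_KERNEL_VERSION,
        Minor:        linux.FUSE_KERNEL_MINOR_VERSION,
        MaxReadahead: 1 << 16, // arbitrary
        Flags:        linux.FUSE_ASYNC_READ,
    }
    hdr := linux.FUSEHeaderIn{
        Opcode: linux.FUSE_INIT,
        Unique: 1, // arbitrary request ID
    }
    // Len covers the header itself plus the opcode-specific payload.
    hdr.Len = uint32(hdr.SizeBytes() + payload.SizeBytes())

    buf := make([]byte, hdr.Len)
    hdr.MarshalBytes(buf)
    payload.MarshalBytes(buf[hdr.SizeBytes():])
    fmt.Printf("FUSE_INIT request: %d bytes\n", len(buf)) // 40-byte header + 16-byte payload
}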
diff --git a/pkg/abi/linux/futex.go b/pkg/abi/linux/futex.go
index 08bfde3b5..8138088a6 100644
--- a/pkg/abi/linux/futex.go
+++ b/pkg/abi/linux/futex.go
@@ -60,3 +60,21 @@ const (
FUTEX_WAITERS = 0x80000000
FUTEX_OWNER_DIED = 0x40000000
)
+
+// FUTEX_BITSET_MATCH_ANY has all bits set.
+const FUTEX_BITSET_MATCH_ANY = 0xffffffff
+
+// ROBUST_LIST_LIMIT protects against a deliberately circular list.
+const ROBUST_LIST_LIMIT = 2048
+
+// RobustListHead corresponds to Linux's struct robust_list_head.
+//
+// +marshal
+type RobustListHead struct {
+ List uint64
+ FutexOffset uint64
+ ListOpPending uint64
+}
+
+// SizeOfRobustListHead is the size of a RobustListHead struct.
+var SizeOfRobustListHead = (*RobustListHead)(nil).SizeBytes()
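ROBUST_LIST_LIMIT matters on the task-exit path, which walks a robust list that lives entirely in user memory and may be maliciously circular. A bounded-walk sketch, where fetch and process are hypothetical stand-ins for copying a RobustListHead out of user memory and waking one futex:

package robustlist

import (
    "errors"

    "gvisor.dev/gvisor/pkg/abi/linux"
)

// walkRobustList follows List pointers from first until the sentinel
// address end, visiting at most ROBUST_LIST_LIMIT entries so that a
// circular list cannot wedge the exiting task.
func walkRobustList(first, end uint64, fetch func(uint64) (linux.RobustListHead, error), process func(uint64)) error {
    next := first
    for i := 0; i < linux.ROBUST_LIST_LIMIT; i++ {
        if next == end {
            return nil // reached the list head again; done.
        }
        entry, err := fetch(next)
        if err != nil {
            return err
        }
        process(next)
        next = entry.List
    }
    return errors.New("robust list too long or circular")
}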
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 2062e6a4b..2c5e56ae5 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -67,10 +67,29 @@ const (
// ioctl(2) requests provided by uapi/linux/sockios.h
const (
- SIOCGIFMEM = 0x891f
- SIOCGIFPFLAGS = 0x8935
- SIOCGMIIPHY = 0x8947
- SIOCGMIIREG = 0x8948
+ SIOCGIFNAME = 0x8910
+ SIOCGIFCONF = 0x8912
+ SIOCGIFFLAGS = 0x8913
+ SIOCGIFADDR = 0x8915
+ SIOCGIFDSTADDR = 0x8917
+ SIOCGIFBRDADDR = 0x8919
+ SIOCGIFNETMASK = 0x891b
+ SIOCGIFMETRIC = 0x891d
+ SIOCGIFMTU = 0x8921
+ SIOCGIFMEM = 0x891f
+ SIOCGIFHWADDR = 0x8927
+ SIOCGIFINDEX = 0x8933
+ SIOCGIFPFLAGS = 0x8935
+ SIOCGIFTXQLEN = 0x8942
+ SIOCETHTOOL = 0x8946
+ SIOCGMIIPHY = 0x8947
+ SIOCGMIIREG = 0x8948
+ SIOCGIFMAP = 0x8970
+)
+
+// ioctl(2) requests provided by uapi/asm-generic/sockios.h
+const (
+ SIOCGSTAMP = 0x8906
)
// ioctl(2) directions. Used to calculate requests number.
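For a sense of what the newly handled requests look like from userspace, here is a host-side sketch issuing SIOCGIFMTU against the loopback interface with a hand-rolled ifreq (an illustration using golang.org/x/sys/unix, not gVisor code):

package main

import (
    "fmt"
    "unsafe"

    "golang.org/x/sys/unix"
)

// ifreq mirrors struct ifreq (40 bytes on 64-bit Linux); only the
// fields SIOCGIFMTU touches are named.
type ifreq struct {
    Name [16]byte
    MTU  int32
    _    [20]byte // pad the union to sizeof(struct ifreq).
}

func main() {
    fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
    if err != nil {
        panic(err)
    }
    defer unix.Close(fd)

    var req ifreq
    copy(req.Name[:], "lo")
    if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&req))); errno != 0 {
        panic(errno)
    }
    fmt.Println("lo MTU:", req.MTU)
}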
diff --git a/pkg/abi/linux/ip.go b/pkg/abi/linux/ip.go
index 31e56ffa6..ef6d1093e 100644
--- a/pkg/abi/linux/ip.go
+++ b/pkg/abi/linux/ip.go
@@ -92,6 +92,16 @@ const (
IP_UNICAST_IF = 50
)
+// IP_MTU_DISCOVER values from uapi/linux/in.h
+const (
+ IP_PMTUDISC_DONT = 0
+ IP_PMTUDISC_WANT = 1
+ IP_PMTUDISC_DO = 2
+ IP_PMTUDISC_PROBE = 3
+ IP_PMTUDISC_INTERFACE = 4
+ IP_PMTUDISC_OMIT = 5
+)
+
// Socket options from uapi/linux/in6.h
const (
IPV6_ADDRFORM = 1
diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go
index 7866352b4..0faf015c7 100644
--- a/pkg/abi/linux/netdevice.go
+++ b/pkg/abi/linux/netdevice.go
@@ -22,6 +22,8 @@ const (
)
// IFReq is an interface request.
+//
+// +marshal
type IFReq struct {
// IFName is an encoded name, normally null-terminated. This should be
// accessed via the Name and SetName functions.
@@ -79,6 +81,8 @@ type IFMap struct {
// IFConf is used to return a list of interfaces and their addresses. See
// netdevice(7) and struct ifconf for more detail on its use.
+//
+// +marshal
type IFConf struct {
Len int32
_ [4]byte // Pad to sizeof(struct ifconf).
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 46d8b0b42..9c27f7bb2 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -14,6 +14,14 @@
package linux
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
+)
+
// This file contains structures required to support netfilter, specifically
// the iptables tool.
@@ -51,7 +59,7 @@ var VerdictStrings = map[int32]string{
NF_RETURN: "RETURN",
}
-// Socket options. These correspond to values in
+// Socket options for SOL_SOCKET. These correspond to values in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
const (
IPT_BASE_CTL = 64
@@ -66,6 +74,12 @@ const (
IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET
)
+// Socket option for SOL_IP. This corresponds to the value in
+// include/uapi/linux/netfilter_ipv4.h.
+const (
+ SO_ORIGINAL_DST = 80
+)
+
// Name lengths. These correspond to values in
// include/uapi/linux/netfilter/x_tables.h.
const (
@@ -76,6 +90,8 @@ const (
// IPTEntry is an iptable rule. It corresponds to struct ipt_entry in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTEntry struct {
// IP is used to filter packets based on the IP header.
IP IPTIP
@@ -112,21 +128,41 @@ type IPTEntry struct {
// SizeOfIPTEntry is the size of an IPTEntry.
const SizeOfIPTEntry = 112
-// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This
-// struct marshaled via the binary package to write an IPTEntry to userspace.
+// KernelIPTEntry is identical to IPTEntry, but includes the Elems field.
+// KernelIPTEntry itself is not marshal.Marshallable, but it implements some of
+// the interface's methods so that other Marshallable implementations can use it.
type KernelIPTEntry struct {
- IPTEntry
+ Entry IPTEntry
// Elems holds the data for all this rule's matches followed by the
// target. It is variable length -- users have to iterate over any
// matches and use TargetOffset and NextOffset to make sense of the
// data.
- Elems []byte
+ Elems primitive.ByteSlice
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTEntry) SizeBytes() int {
+ return ke.Entry.SizeBytes() + ke.Elems.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTEntry) MarshalBytes(dst []byte) {
+ ke.Entry.MarshalBytes(dst)
+ ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) {
+ ke.Entry.UnmarshalBytes(src)
+ ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
}
// IPTIP contains information for matching a packet's IP header.
// It corresponds to struct ipt_ip in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTIP struct {
// Src is the source IP address.
Src InetAddr
@@ -189,6 +225,8 @@ const SizeOfIPTIP = 84
// XTCounters holds packet and byte counts for a rule. It corresponds to struct
// xt_counters in include/uapi/linux/netfilter/x_tables.h.
+//
+// +marshal
type XTCounters struct {
// Pcnt is the packet count.
Pcnt uint64
@@ -321,6 +359,8 @@ const SizeOfXTRedirectTarget = 56
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetinfo struct {
Name TableName
ValidHooks uint32
@@ -336,6 +376,8 @@ const SizeOfIPTGetinfo = 84
// IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It
// corresponds to struct ipt_get_entries in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetEntries struct {
Name TableName
Size uint32
@@ -350,13 +392,103 @@ type IPTGetEntries struct {
const SizeOfIPTGetEntries = 40
// KernelIPTGetEntries is identical to IPTGetEntries, but includes the
-// Entrytable field. This struct marshaled via the binary package to write an
-// KernelIPTGetEntries to userspace.
+// Entrytable field. This has been manually made marshal.Marshallable since it
+// is dynamically sized.
type KernelIPTGetEntries struct {
IPTGetEntries
Entrytable []KernelIPTEntry
}
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTGetEntries) SizeBytes() int {
+ res := ke.IPTGetEntries.SizeBytes()
+ for _, entry := range ke.Entrytable {
+ res += entry.SizeBytes()
+ }
+ return res
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) {
+ ke.IPTGetEntries.MarshalBytes(dst)
+ marshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := 0; i < len(ke.Entrytable); i++ {
+ ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:])
+ marshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) {
+ ke.IPTGetEntries.UnmarshalBytes(src)
+ unmarshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := 0; i < len(ke.Entrytable); i++ {
+ ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:])
+ unmarshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (ke *KernelIPTGetEntries) Packed() bool {
+ // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an
+ // indirection to the actual data we want to marshal (the slice data
+ // pointer), and the memory for KernelIPTGetEntries contains the slice
+ // header which we don't want to marshal.
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) {
+ // Fall back to safe Marshal because the type is not packed.
+ ke.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) {
+ // Fall back to safe Unmarshal because the type is not packed.
+ ke.UnmarshalBytes(src)
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
+ length, err := task.CopyInBytes(addr, buf) // escapes: okay.
+ // Unmarshal unconditionally. If we had a short copy-in, this results in a
+ // partially unmarshalled struct.
+ ke.UnmarshalBytes(buf) // escapes: fallback.
+ return length, err
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task))
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit])
+}
+
+func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte {
+ buf := task.CopyScratchBuffer(ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ return buf
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) {
+ buf := make([]byte, ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ length, err := w.Write(buf)
+ return int64(length), err
+}
+
+var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil)
+
// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
// corresponds to struct ipt_replace in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
@@ -374,12 +506,6 @@ type IPTReplace struct {
// Entries [0]IPTEntry
}
-// KernelIPTReplace is identical to IPTReplace, but includes the Entries field.
-type KernelIPTReplace struct {
- IPTReplace
- Entries [0]IPTEntry
-}
-
// SizeOfIPTReplace is the size of an IPTReplace.
const SizeOfIPTReplace = 96
@@ -392,6 +518,8 @@ func (en ExtensionName) String() string {
}
// TableName holds the name of a netfilter table.
+//
+// +marshal
type TableName [XT_TABLE_MAXNAMELEN]byte
// String implements fmt.Stringer.
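The hand-written Marshallable methods above exist because Entrytable makes the type dynamically sized. A sketch of how a getsockopt(IPT_SO_GET_ENTRIES) path might lean on them (the helper is hypothetical; the real handler lives in pkg/sentry/socket/netfilter):

package netfilterexample

import (
    "gvisor.dev/gvisor/pkg/abi/linux"
    "gvisor.dev/gvisor/pkg/usermem"
    "gvisor.dev/gvisor/tools/go_marshal/marshal"
)

// copyOutEntries writes a variable-length entry table to user memory.
// SizeBytes walks Entrytable, so it accounts for every variable-length
// KernelIPTEntry, not just the fixed SizeOfIPTGetEntries header; and
// because Packed() is false, CopyOut serializes via MarshalBytes into a
// scratch buffer and performs a single copy-out.
func copyOutEntries(task marshal.Task, addr usermem.Addr, ke *linux.KernelIPTGetEntries) (int, error) {
    return ke.CopyOut(task, addr)
}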
diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go
index 40bec566c..ceda0a8d3 100644
--- a/pkg/abi/linux/netlink_route.go
+++ b/pkg/abi/linux/netlink_route.go
@@ -187,6 +187,8 @@ const (
// Device types, from uapi/linux/if_arp.h.
const (
+ ARPHRD_NONE = 65534
+ ARPHRD_ETHER = 1
ARPHRD_LOOPBACK = 772
)
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 4a14ef691..d6946bb82 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -134,6 +134,15 @@ const (
SHUT_RDWR = 2
)
+// Packet types from <linux/if_packet.h>
+const (
+ PACKET_HOST = 0 // To us
+ PACKET_BROADCAST = 1 // To all
+ PACKET_MULTICAST = 2 // To group
+ PACKET_OTHERHOST = 3 // To someone else
+ PACKET_OUTGOING = 4 // Outgoing of any type
+)
+
// Socket options from socket.h.
const (
SO_DEBUG = 1
@@ -225,14 +234,18 @@ const (
const SockAddrMax = 128
// InetAddr is struct in_addr, from uapi/linux/in.h.
+//
+// +marshal
type InetAddr [4]byte
// SockAddrInet is struct sockaddr_in, from uapi/linux/in.h.
+//
+// +marshal
type SockAddrInet struct {
Family uint16
Port uint16
Addr InetAddr
- Zero [8]uint8 // pad to sizeof(struct sockaddr).
+ _ [8]uint8 // pad to sizeof(struct sockaddr).
}
// InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h.
@@ -294,6 +307,8 @@ func (s *SockAddrUnix) implementsSockAddr() {}
func (s *SockAddrNetlink) implementsSockAddr() {}
// Linger is struct linger, from include/linux/socket.h.
+//
+// +marshal
type Linger struct {
OnOff int32
Linger int32
@@ -308,6 +323,8 @@ const SizeOfLinger = 8
// the end of this struct or within existing unusued space, so its size grows
// over time. The current iteration is based on linux v4.17. New versions are
// always backwards compatible.
+//
+// +marshal
type TCPInfo struct {
State uint8
CaState uint8
@@ -405,6 +422,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{}))
// A ControlMessageCredentials is an SCM_CREDENTIALS socket control message.
//
// ControlMessageCredentials represents struct ucred from linux/socket.h.
+//
+// +marshal
type ControlMessageCredentials struct {
PID int32
UID uint32
diff --git a/pkg/abi/linux/tcp.go b/pkg/abi/linux/tcp.go
index 174d470e2..2a8d4708b 100644
--- a/pkg/abi/linux/tcp.go
+++ b/pkg/abi/linux/tcp.go
@@ -57,4 +57,5 @@ const (
const (
MAX_TCP_KEEPIDLE = 32767
MAX_TCP_KEEPINTVL = 32767
+ MAX_TCP_KEEPCNT = 127
)
diff --git a/pkg/bpf/interpreter_test.go b/pkg/bpf/interpreter_test.go
index 547921d0a..c85d786b9 100644
--- a/pkg/bpf/interpreter_test.go
+++ b/pkg/bpf/interpreter_test.go
@@ -767,7 +767,7 @@ func TestSimpleFilter(t *testing.T) {
expectedRet: 0,
},
{
- desc: "Whitelisted syscall is allowed",
+ desc: "Allowed syscall is indeed allowed",
seccompData: seccompData{nr: 231 /* __NR_exit_group */, arch: 0xc000003e},
expectedRet: 0x7fff0000,
},
diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go
index 0e5b86344..b789e56e9 100644
--- a/pkg/buffer/safemem.go
+++ b/pkg/buffer/safemem.go
@@ -28,12 +28,11 @@ func (b *buffer) ReadBlock() safemem.Block {
return safemem.BlockFromSafeSlice(b.ReadSlice())
}
-// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
-//
-// This will advance the write index.
-func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- need := int(srcs.NumBytes())
- if need == 0 {
+// WriteFromSafememReader writes up to count bytes from r to v and advances the
+// write index by the number of bytes written. It calls r.ReadToBlocks() at
+// most once.
+func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, error) {
+ if count == 0 {
return 0, nil
}
@@ -50,32 +49,33 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
}
// Does the last block have sufficient capacity alone?
- if l := firstBuf.WriteSize(); l >= need {
- dst = safemem.BlockSeqOf(firstBuf.WriteBlock())
+ if l := uint64(firstBuf.WriteSize()); l >= count {
+ dst = safemem.BlockSeqOf(firstBuf.WriteBlock().TakeFirst64(count))
} else {
// Append blocks until sufficient.
- need -= l
+ count -= l
blocks = append(blocks, firstBuf.WriteBlock())
- for need > 0 {
+ for count > 0 {
emptyBuf := bufferPool.Get().(*buffer)
v.data.PushBack(emptyBuf)
- need -= emptyBuf.WriteSize()
- blocks = append(blocks, emptyBuf.WriteBlock())
+ block := emptyBuf.WriteBlock().TakeFirst64(count)
+ count -= uint64(block.Len())
+ blocks = append(blocks, block)
}
dst = safemem.BlockSeqFromSlice(blocks)
}
- // Perform the copy.
- n, err := safemem.CopySeq(dst, srcs)
+ // Perform I/O.
+ n, err := r.ReadToBlocks(dst)
v.size += int64(n)
// Update all indices.
- for left := int(n); left > 0; firstBuf = firstBuf.Next() {
- if l := firstBuf.WriteSize(); left >= l {
+ for left := n; left > 0; firstBuf = firstBuf.Next() {
+ if l := firstBuf.WriteSize(); left >= uint64(l) {
firstBuf.WriteMove(l) // Whole block.
- left -= l
+ left -= uint64(l)
} else {
- firstBuf.WriteMove(left) // Partial block.
+ firstBuf.WriteMove(int(left)) // Partial block.
left = 0
}
}
@@ -83,14 +83,16 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
return n, err
}
-// ReadToBlocks implements safemem.Reader.ReadToBlocks.
-//
-// This will not advance the read index; the caller should follow
-// this call with a call to TrimFront in order to remove the read
-// data from the buffer. This is done to support pipe sematics.
-func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- need := int(dsts.NumBytes())
- if need == 0 {
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. It advances the
+// write index by the number of bytes written.
+func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ return v.WriteFromSafememReader(&safemem.BlockSeqReader{srcs}, srcs.NumBytes())
+}
+
+// ReadToSafememWriter reads up to count bytes from v to w. It does not advance
+// the read index. It calls w.WriteFromBlocks() at most once.
+func (v *View) ReadToSafememWriter(w safemem.Writer, count uint64) (uint64, error) {
+ if count == 0 {
return 0, nil
}
@@ -105,25 +107,27 @@ func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
}
// Is all the data in a single block?
- if l := firstBuf.ReadSize(); l >= need {
- src = safemem.BlockSeqOf(firstBuf.ReadBlock())
+ if l := uint64(firstBuf.ReadSize()); l >= count {
+ src = safemem.BlockSeqOf(firstBuf.ReadBlock().TakeFirst64(count))
} else {
// Build a list of all the buffers.
- need -= l
+ count -= l
blocks = append(blocks, firstBuf.ReadBlock())
- for buf := firstBuf.Next(); buf != nil && need > 0; buf = buf.Next() {
- need -= buf.ReadSize()
- blocks = append(blocks, buf.ReadBlock())
+ for buf := firstBuf.Next(); buf != nil && count > 0; buf = buf.Next() {
+ block := buf.ReadBlock().TakeFirst64(count)
+ count -= uint64(block.Len())
+ blocks = append(blocks, block)
}
src = safemem.BlockSeqFromSlice(blocks)
}
- // Perform the copy.
- n, err := safemem.CopySeq(dsts, src)
-
- // See above: we would normally advance the read index here, but we
- // don't do that in order to support pipe semantics. We rely on a
- // separate call to TrimFront() in this case.
+ // Perform I/O. As documented, we don't advance the read index.
+ return w.WriteFromBlocks(src)
+}
- return n, err
+// ReadToBlocks implements safemem.Reader.ReadToBlocks. It does not advance the
+// read index by the number of bytes read, such that it's only safe to call if
+// the caller guarantees that ReadToBlocks will only be called once.
+func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ return v.ReadToSafememWriter(&safemem.BlockSeqWriter{dsts}, dsts.NumBytes())
}
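Since ReadToBlocks deliberately leaves the read index alone, pipe-style callers consume data with an explicit TrimFront after a successful copy. A sketch of that convention (assuming View's TrimFront(int64) from this package):

package bufferexample

import (
    "gvisor.dev/gvisor/pkg/buffer"
    "gvisor.dev/gvisor/pkg/safemem"
)

// pipeRead copies up to dsts.NumBytes() out of v, then consumes exactly
// the bytes that were copied, which stays correct even on short copies.
func pipeRead(v *buffer.View, dsts safemem.BlockSeq) (uint64, error) {
    n, err := v.ReadToBlocks(dsts)
    v.TrimFront(int64(n)) // advance the read index ourselves.
    return n, err
}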
diff --git a/pkg/cleanup/BUILD b/pkg/cleanup/BUILD
new file mode 100644
index 000000000..5c34b9872
--- /dev/null
+++ b/pkg/cleanup/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "cleanup",
+ srcs = ["cleanup.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ ],
+)
+
+go_test(
+ name = "cleanup_test",
+ srcs = ["cleanup_test.go"],
+ library = ":cleanup",
+)
diff --git a/pkg/cleanup/cleanup.go b/pkg/cleanup/cleanup.go
new file mode 100644
index 000000000..14a05f076
--- /dev/null
+++ b/pkg/cleanup/cleanup.go
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cleanup provides utilities to clean "stuff" on defers.
+package cleanup
+
+// Cleanup allows defers to be aborted when cleanup needs to happen
+// conditionally. Usage:
+// cu := cleanup.Make(func() { f.Close() })
+// defer cu.Clean() // A failure before Release is called will close the file.
+// ...
+// cu.Add(func() { f2.Close() }) // Adds another cleanup function
+// ...
+// cu.Release() // on success, aborts closing the file.
+// return f
+type Cleanup struct {
+ cleaners []func()
+}
+
+// Make creates a new Cleanup object.
+func Make(f func()) Cleanup {
+ return Cleanup{cleaners: []func(){f}}
+}
+
+// Add adds a new function to be called on Clean().
+func (c *Cleanup) Add(f func()) {
+ c.cleaners = append(c.cleaners, f)
+}
+
+// Clean calls all cleanup functions in reverse order.
+func (c *Cleanup) Clean() {
+ clean(c.cleaners)
+ c.cleaners = nil
+}
+
+// Release releases the cleanup from its duties, i.e. cleanup functions are not
+// called after this point. Returns a function that calls all registered
+// functions in case the caller has use for them.
+func (c *Cleanup) Release() func() {
+ old := c.cleaners
+ c.cleaners = nil
+ return func() { clean(old) }
+}
+
+func clean(cleaners []func()) {
+ for i := len(cleaners) - 1; i >= 0; i-- {
+ cleaners[i]()
+ }
+}
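A sketch of the ownership-transfer pattern this enables (file names hypothetical): every early return runs the accumulated cleanups through the deferred Clean, and the success path disarms them with Release:

package main

import (
    "os"

    "gvisor.dev/gvisor/pkg/cleanup"
)

func openPair() (*os.File, *os.File, error) {
    f1, err := os.Open("first")
    if err != nil {
        return nil, nil, err
    }
    cu := cleanup.Make(func() { f1.Close() })
    defer cu.Clean() // runs on any early return below.

    f2, err := os.Open("second")
    if err != nil {
        return nil, nil, err // cu.Clean closes f1.
    }
    cu.Add(func() { f2.Close() })

    cu.Release() // success: the caller now owns both files.
    return f1, f2, nil
}

func main() {
    if f1, f2, err := openPair(); err == nil {
        f1.Close()
        f2.Close()
    }
}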
diff --git a/pkg/cleanup/cleanup_test.go b/pkg/cleanup/cleanup_test.go
new file mode 100644
index 000000000..ab3d9ed95
--- /dev/null
+++ b/pkg/cleanup/cleanup_test.go
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cleanup
+
+import "testing"
+
+func testCleanupHelper(clean, cleanAdd *bool, release bool) func() {
+ cu := Make(func() {
+ *clean = true
+ })
+ cu.Add(func() {
+ *cleanAdd = true
+ })
+ defer cu.Clean()
+ if release {
+ return cu.Release()
+ }
+ return nil
+}
+
+func TestCleanup(t *testing.T) {
+ clean := false
+ cleanAdd := false
+ testCleanupHelper(&clean, &cleanAdd, false)
+ if !clean {
+ t.Fatalf("cleanup function was not called.")
+ }
+ if !cleanAdd {
+ t.Fatalf("added cleanup function was not called.")
+ }
+}
+
+func TestRelease(t *testing.T) {
+ clean := false
+ cleanAdd := false
+ cleaner := testCleanupHelper(&clean, &cleanAdd, true)
+
+ // Check that clean was not called after release.
+ if clean {
+ t.Fatalf("cleanup function was called.")
+ }
+ if cleanAdd {
+ t.Fatalf("added cleanup function was called.")
+ }
+
+ // Call the cleaner function and check that both cleanup functions are called.
+ cleaner()
+ if !clean {
+ t.Fatalf("cleanup function was not called.")
+ }
+ if !cleanAdd {
+ t.Fatalf("added cleanup function was not called.")
+ }
+}
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
index 5f52cbe74..b094c5662 100644
--- a/pkg/compressio/compressio.go
+++ b/pkg/compressio/compressio.go
@@ -346,20 +346,22 @@ func (p *pool) schedule(c *chunk, callback func(*chunk) error) error {
}
}
-// reader chunks reads and decompresses.
-type reader struct {
+// Reader is a compressed reader.
+type Reader struct {
pool
// in is the source.
in io.Reader
}
+var _ io.Reader = (*Reader)(nil)
+
// NewReader returns a new compressed reader. If key is non-nil, the data stream
// is assumed to contain expected hash values, which will be compared against
// hash values computed from the compressed bytes. See package comments for
// details.
-func NewReader(in io.Reader, key []byte) (io.Reader, error) {
- r := &reader{
+func NewReader(in io.Reader, key []byte) (*Reader, error) {
+ r := &Reader{
in: in,
}
@@ -394,8 +396,19 @@ var errNewBuffer = errors.New("buffer ready")
// ErrHashMismatch is returned if the hash does not match.
var ErrHashMismatch = errors.New("hash mismatch")
+// ReadByte implements wire.Reader.ReadByte.
+func (r *Reader) ReadByte() (byte, error) {
+ var p [1]byte
+ n, err := r.Read(p[:])
+ if n != 1 {
+ return p[0], err
+ }
+ // Suppress EOF.
+ return p[0], nil
+}
+
// Read implements io.Reader.Read.
-func (r *reader) Read(p []byte) (int, error) {
+func (r *Reader) Read(p []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
@@ -551,8 +564,8 @@ func (r *reader) Read(p []byte) (int, error) {
return done, nil
}
-// writer chunks and schedules writes.
-type writer struct {
+// Writer is a compressed writer.
+type Writer struct {
pool
// out is the underlying writer.
@@ -562,6 +575,8 @@ type writer struct {
closed bool
}
+var _ io.Writer = (*Writer)(nil)
+
// NewWriter returns a new compressed writer. If key is non-nil, hash values are
// generated and written out for compressed bytes. See package comments for
// details.
@@ -569,8 +584,8 @@ type writer struct {
// The recommended chunkSize is on the order of 1M. Extra memory may be
// buffered (in the form of read-ahead, or buffered writes), and is limited to
// O(chunkSize * [1+GOMAXPROCS]).
-func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) {
- w := &writer{
+func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) {
+ w := &Writer{
pool: pool{
chunkSize: chunkSize,
buf: bufPool.Get().(*bytes.Buffer),
@@ -597,7 +612,7 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.Write
}
// flush writes a single buffer.
-func (w *writer) flush(c *chunk) error {
+func (w *Writer) flush(c *chunk) error {
// Prefix each chunk with a length; this allows the reader to safely
// limit reads while buffering.
l := uint32(c.compressed.Len())
@@ -624,8 +639,23 @@ func (w *writer) flush(c *chunk) error {
return nil
}
+// WriteByte implements wire.Writer.WriteByte.
+//
+// Note that this implementation is necessary on the object itself, as an
+// interface-based dispatch cannot tell whether the array backing the slice
+// escapes; therefore all bytes written would generate an escape.
+func (w *Writer) WriteByte(b byte) error {
+ var p [1]byte
+ p[0] = b
+ n, err := w.Write(p[:])
+ if n != 1 {
+ return err
+ }
+ return nil
+}
+
// Write implements io.Writer.Write.
-func (w *writer) Write(p []byte) (int, error) {
+func (w *Writer) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
@@ -710,7 +740,7 @@ func (w *writer) Write(p []byte) (int, error) {
}
// Close implements io.Closer.Close.
-func (w *writer) Close() error {
+func (w *Writer) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
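
A hedged round-trip sketch using the now-exported Reader/Writer; the chunk size and compression level below are illustrative assumptions, not values required by the package:

package example

import (
	"bytes"
	"compress/flate"
	"io/ioutil"

	"gvisor.dev/gvisor/pkg/compressio"
)

func roundTrip(data []byte) ([]byte, error) {
	var buf bytes.Buffer
	// nil key: no integrity hashing; 1M chunks per the recommendation above.
	w, err := compressio.NewWriter(&buf, nil, 1024*1024, flate.BestSpeed)
	if err != nil {
		return nil, err
	}
	if _, err := w.Write(data); err != nil {
		return nil, err
	}
	if err := w.Close(); err != nil {
		return nil, err
	}
	r, err := compressio.NewReader(&buf, nil)
	if err != nil {
		return nil, err
	}
	return ioutil.ReadAll(r)
}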
diff --git a/pkg/cpuid/cpuid_arm64.go b/pkg/cpuid/cpuid_arm64.go
index 08381c1c0..ac7bb6774 100644
--- a/pkg/cpuid/cpuid_arm64.go
+++ b/pkg/cpuid/cpuid_arm64.go
@@ -312,8 +312,9 @@ func HostFeatureSet() *FeatureSet {
}
}
-// Reads bogomips from host /proc/cpuinfo. Must run before whitelisting.
-// This value is used to create the fake /proc/cpuinfo from a FeatureSet.
+// Reads bogomips from host /proc/cpuinfo. Must run before syscall filter
+// installation. This value is used to create the fake /proc/cpuinfo from a
+// FeatureSet.
func initCPUInfo() {
cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
if err != nil {
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
index 562f8f405..17a89c00d 100644
--- a/pkg/cpuid/cpuid_x86.go
+++ b/pkg/cpuid/cpuid_x86.go
@@ -1057,9 +1057,9 @@ func HostFeatureSet() *FeatureSet {
}
}
-// Reads max cpu frequency from host /proc/cpuinfo. Must run before
-// whitelisting. This value is used to create the fake /proc/cpuinfo from a
-// FeatureSet.
+// Reads max cpu frequency from host /proc/cpuinfo. Must run before syscall
+// filter installation. This value is used to create the fake /proc/cpuinfo
+// from a FeatureSet.
func initCPUFreq() {
cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
if err != nil {
@@ -1106,7 +1106,6 @@ func initFeaturesFromString() {
}
func init() {
- // initCpuFreq must be run before whitelists are enabled.
initCPUFreq()
initFeaturesFromString()
}
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index 9c5ad500b..aa8e4e1f3 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -11,6 +11,7 @@ go_library(
"futex_linux.go",
"io.go",
"packet_window_allocator.go",
+ "packet_window_mmap.go",
],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 3cdb576e1..ec742c091 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -95,7 +95,7 @@ func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...
if pwd.Length > math.MaxUint32 {
return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32)
}
- m, _, e := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ m, e := packetWindowMmap(pwd)
if e != 0 {
return fmt.Errorf("failed to mmap packet window: %v", e)
}
diff --git a/pkg/flipcall/packet_window_mmap.go b/pkg/flipcall/packet_window_mmap.go
new file mode 100644
index 000000000..869183b11
--- /dev/null
+++ b/pkg/flipcall/packet_window_mmap.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "syscall"
+)
+
+// packetWindowMmap returns a memory mapping of the given packet window that can be shared outside the sandbox.
+func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, syscall.Errno) {
+ m, _, err := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ return m, err
+}
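
For illustration, a sketch of how the returned address might be viewed as a byte slice inside the package. This assumes Go 1.17+ unsafe.Slice; the real code predates it and performs its own conversion:

func mapWindow(pwd PacketWindowDescriptor) ([]byte, error) {
	m, errno := packetWindowMmap(pwd)
	if errno != 0 {
		return nil, fmt.Errorf("mmap failed: %v", errno)
	}
	// View the shared mapping as bytes (illustrative only).
	return unsafe.Slice((*byte)(unsafe.Pointer(m)), pwd.Length), nil
}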
diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD
index 798a65eca..35683fe98 100644
--- a/pkg/gohacks/BUILD
+++ b/pkg/gohacks/BUILD
@@ -7,5 +7,6 @@ go_library(
srcs = [
"gohacks_unsafe.go",
],
+ stateify = False,
visibility = ["//:sandbox"],
)
diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go
index 0d07da3b1..f4a4c33d3 100644
--- a/pkg/ilist/list.go
+++ b/pkg/ilist/list.go
@@ -90,7 +90,7 @@ func (l *List) Back() Element {
//
// NOTE: This is an O(n) operation.
func (l *List) Len() (count int) {
- for e := l.Front(); e != nil; e = e.Next() {
+ for e := l.Front(); e != nil; e = (ElementMapper{}.linkerFor(e)).Next() {
count++
}
return count
@@ -182,13 +182,13 @@ func (l *List) Remove(e Element) {
if prev != nil {
ElementMapper{}.linkerFor(prev).SetNext(next)
- } else {
+ } else if l.head == e {
l.head = next
}
if next != nil {
ElementMapper{}.linkerFor(next).SetPrev(prev)
- } else {
+ } else if l.tail == e {
l.tail = prev
}
diff --git a/pkg/iovec/BUILD b/pkg/iovec/BUILD
new file mode 100644
index 000000000..eda82cfc1
--- /dev/null
+++ b/pkg/iovec/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "iovec",
+ srcs = ["iovec.go"],
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/abi/linux"],
+)
+
+go_test(
+ name = "iovec_test",
+ size = "small",
+ srcs = ["iovec_test.go"],
+ library = ":iovec",
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
+)
diff --git a/pkg/iovec/iovec.go b/pkg/iovec/iovec.go
new file mode 100644
index 000000000..dd70fe80f
--- /dev/null
+++ b/pkg/iovec/iovec.go
@@ -0,0 +1,75 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package iovec provides helpers to interact with vectorized I/O on the
+// host system.
+package iovec
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// MaxIovs is the maximum number of iovecs the host platform can accept.
+var MaxIovs = linux.UIO_MAXIOV
+
+// Builder is a builder for a slice of syscall.Iovec.
+type Builder struct {
+ iovec []syscall.Iovec
+ storage [8]syscall.Iovec
+
+ // overflow tracks the last buffer when iovec length is at MaxIovs.
+ overflow []byte
+}
+
+// Add adds buf to b in preparation for writing. A zero-length buf is not added.
+func (b *Builder) Add(buf []byte) {
+ if len(buf) == 0 {
+ return
+ }
+ if b.iovec == nil {
+ b.iovec = b.storage[:0]
+ }
+ if len(b.iovec) >= MaxIovs {
+ b.addByAppend(buf)
+ return
+ }
+ b.iovec = append(b.iovec, syscall.Iovec{
+ Base: &buf[0],
+ Len: uint64(len(buf)),
+ })
+ // Keep the last buf if iovec is at max capacity. We will need to append to it
+ // for later bufs.
+ if len(b.iovec) == MaxIovs {
+ n := len(buf)
+ b.overflow = buf[:n:n]
+ }
+}
+
+func (b *Builder) addByAppend(buf []byte) {
+ b.overflow = append(b.overflow, buf...)
+ b.iovec[len(b.iovec)-1] = syscall.Iovec{
+ Base: &b.overflow[0],
+ Len: uint64(len(b.overflow)),
+ }
+}
+
+// Build returns the final Iovec slice. The length of the returned iovec will
+// not exceed MaxIovs.
+func (b *Builder) Build() []syscall.Iovec {
+ return b.iovec
+}
diff --git a/pkg/iovec/iovec_test.go b/pkg/iovec/iovec_test.go
new file mode 100644
index 000000000..a3900c299
--- /dev/null
+++ b/pkg/iovec/iovec_test.go
@@ -0,0 +1,121 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package iovec
+
+import (
+ "bytes"
+ "fmt"
+ "syscall"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func TestBuilderEmpty(t *testing.T) {
+ var builder Builder
+ iovecs := builder.Build()
+ if got, want := len(iovecs), 0; got != want {
+ t.Errorf("len(iovecs) = %d, want %d", got, want)
+ }
+}
+
+func TestBuilderBuild(t *testing.T) {
+ a := []byte{1, 2}
+ b := []byte{3, 4, 5}
+
+ var builder Builder
+ builder.Add(a)
+ builder.Add(b)
+ builder.Add(nil) // Nil slice won't be added.
+ builder.Add([]byte{}) // Empty slice won't be added.
+ iovecs := builder.Build()
+
+ if got, want := len(iovecs), 2; got != want {
+ t.Fatalf("len(iovecs) = %d, want %d", got, want)
+ }
+ for i, data := range [][]byte{a, b} {
+ if got, want := *iovecs[i].Base, data[0]; got != want {
+ t.Fatalf("*iovecs[%d].Base = %d, want %d", i, got, want)
+ }
+ if got, want := iovecs[i].Len, uint64(len(data)); got != want {
+ t.Fatalf("iovecs[%d].Len = %d, want %d", i, got, want)
+ }
+ }
+}
+
+func TestBuilderBuildMaxIov(t *testing.T) {
+ for _, test := range []struct {
+ numIov int
+ }{
+ {
+ numIov: MaxIovs - 1,
+ },
+ {
+ numIov: MaxIovs,
+ },
+ {
+ numIov: MaxIovs + 1,
+ },
+ {
+ numIov: MaxIovs + 10,
+ },
+ } {
+ name := fmt.Sprintf("numIov=%v", test.numIov)
+ t.Run(name, func(t *testing.T) {
+ var data []byte
+ var builder Builder
+ for i := 0; i < test.numIov; i++ {
+ buf := []byte{byte(i)}
+ builder.Add(buf)
+ data = append(data, buf...)
+ }
+ iovec := builder.Build()
+
+ // Check the expected length of iovec.
+ wantNum := test.numIov
+ if wantNum > MaxIovs {
+ wantNum = MaxIovs
+ }
+ if got, want := len(iovec), wantNum; got != want {
+ t.Errorf("len(iovec) = %d, want %d", got, want)
+ }
+
+ // Test a real read-write.
+ var fds [2]int
+ if err := unix.Pipe(fds[:]); err != nil {
+ t.Fatalf("Pipe: %v", err)
+ }
+ defer syscall.Close(fds[0])
+ defer syscall.Close(fds[1])
+
+ wrote, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fds[1]), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec)))
+ if int(wrote) != len(data) || e != 0 {
+ t.Fatalf("writev: %v, %v; want %v, 0", wrote, e, len(data))
+ }
+
+ got := make([]byte, len(data))
+ if n, err := syscall.Read(fds[0], got); n != len(got) || err != nil {
+ t.Fatalf("read: %v, %v; want %v, nil", n, err, len(got))
+ }
+
+ if !bytes.Equal(got, data) {
+ t.Errorf("read: got data %v, want %v", got, data)
+ }
+ })
+ }
+}
diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD
new file mode 100644
index 000000000..5b0e4143a
--- /dev/null
+++ b/pkg/merkletree/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "merkletree",
+ srcs = ["merkletree.go"],
+ deps = ["//pkg/usermem"],
+)
+
+go_test(
+ name = "merkletree_test",
+ srcs = ["merkletree_test.go"],
+ library = ":merkletree",
+ deps = ["//pkg/usermem"],
+)
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
new file mode 100644
index 000000000..906f67943
--- /dev/null
+++ b/pkg/merkletree/merkletree.go
@@ -0,0 +1,135 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package merkletree implements Merkle tree generation and verification.
+package merkletree
+
+import (
+ "crypto/sha256"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // sha256DigestSize specifies the digest size of a SHA256 hash.
+ sha256DigestSize = 32
+)
+
+// Size defines the scale of a Merkle tree.
+type Size struct {
+ // blockSize is the size of a data block to be hashed.
+ blockSize int64
+ // digestSize is the size of a generated hash.
+ digestSize int64
+ // hashesPerBlock is the number of hashes in a block. For example, if
+ // blockSize is 4096 bytes, and digestSize is 32 bytes, there will be 128
+ // hashes per block. Therefore 128 hashes from a lower level are grouped
+ // into one block to generate a single hash in the level above.
+ hashesPerBlock int64
+ // levelStart is the start block index of each level. The number of levels in
+ // the tree is the length of the slice. The leaves (level 0) are hashes of
+ // blocks in the input data. The levels above are hashes of lower level
+ // hashes. The highest level is the root hash.
+ levelStart []int64
+}
+
+// MakeSize initializes and returns a new Size object describing the structure
+// of a tree. dataSize specifies the size of the data to be hashed, in bytes.
+func MakeSize(dataSize int64) Size {
+ size := Size{
+ blockSize: usermem.PageSize,
+ // TODO(b/156980949): Allow configuring other hash methods (SHA384/SHA512).
+ digestSize: sha256DigestSize,
+ hashesPerBlock: usermem.PageSize / sha256DigestSize,
+ }
+ numBlocks := (dataSize + size.blockSize - 1) / size.blockSize
+ level := int64(0)
+ offset := int64(0)
+
+ // Calculate the number of levels in the Merkle tree and the beginning offset
+ // of each level. Level 0 is the level directly above the data blocks, while
+ // level NumLevels - 1 is the root.
+ for numBlocks > 1 {
+ size.levelStart = append(size.levelStart, offset)
+ // Round numBlocks up to fill up a block.
+ numBlocks += (size.hashesPerBlock - numBlocks%size.hashesPerBlock) % size.hashesPerBlock
+ offset += numBlocks / size.hashesPerBlock
+ numBlocks = numBlocks / size.hashesPerBlock
+ level++
+ }
+ size.levelStart = append(size.levelStart, offset)
+ return size
+}
+
+// Generate constructs a Merkle tree for the contents of data. The output is
+// written to treeWriter. The treeReader should be able to read the tree after
+// it has been written. That is, treeWriter and treeReader should point to the
+// same underlying data but have separate cursors.
+func Generate(data io.Reader, dataSize int64, treeReader io.Reader, treeWriter io.Writer) ([]byte, error) {
+ size := MakeSize(dataSize)
+
+ numBlocks := (dataSize + size.blockSize - 1) / size.blockSize
+
+ var root []byte
+ for level := 0; level < len(size.levelStart); level++ {
+ for i := int64(0); i < numBlocks; i++ {
+ buf := make([]byte, size.blockSize)
+ var (
+ n int
+ err error
+ )
+ if level == 0 {
+ // Read data block from the target file since level 0 is directly above
+ // the raw data block.
+ n, err = data.Read(buf)
+ } else {
+ // Read data block from the tree file since levels higher than 0 are
+ // hashing the lower level hashes.
+ n, err = treeReader.Read(buf)
+ }
+
+ // err is populated when the bytes read are fewer than the buffer
+ // size. This can happen when reading the last block, in which case we
+ // break. Since buf is zero-initialized, the end of the last block is
+ // zero-padded.
+ if n == 0 && err == io.EOF {
+ break
+ } else if err != nil && err != io.EOF {
+ return nil, err
+ }
+ // Hash the bytes in buf.
+ digest := sha256.Sum256(buf)
+
+ if level == len(size.levelStart)-1 {
+ root = digest[:]
+ }
+
+ // Write the generated hash to the end of the tree file.
+ if _, err = treeWriter.Write(digest[:]); err != nil {
+ return nil, err
+ }
+ }
+ // If the generated digests do not fill up a block, zero-pad the
+ // remainder of the last block. This is not needed for the root.
+ if level != len(size.levelStart)-1 && numBlocks%size.hashesPerBlock != 0 {
+ zeroBuf := make([]byte, size.blockSize-(numBlocks%size.hashesPerBlock)*size.digestSize)
+ if _, err := treeWriter.Write(zeroBuf[:]); err != nil {
+ return nil, err
+ }
+ }
+ numBlocks = (numBlocks + size.hashesPerBlock - 1) / size.hashesPerBlock
+ }
+ return root, nil
+}
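
As a worked example of the level computation: with 4 KiB blocks and 32-byte digests (so 128 hashes per block), dataSize = 1,000,000 yields ceil(1000000/4096) = 245 data blocks. Level 0 then holds 245 hashes, rounded up to 256, i.e. 2 tree blocks; level 1 holds those 2 hashes in 1 block; the level above is the single root hash. Hence levelStart = [0, 2, 3], matching the test case below.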
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
new file mode 100644
index 000000000..7344db0b6
--- /dev/null
+++ b/pkg/merkletree/merkletree_test.go
@@ -0,0 +1,122 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package merkletree
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestSize(t *testing.T) {
+ testCases := []struct {
+ dataSize int64
+ expectedLevelStart []int64
+ }{
+ {
+ dataSize: 100,
+ expectedLevelStart: []int64{0},
+ },
+ {
+ dataSize: 1000000,
+ expectedLevelStart: []int64{0, 2, 3},
+ },
+ {
+ dataSize: 4096 * int64(usermem.PageSize),
+ expectedLevelStart: []int64{0, 32, 33},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
+ s := MakeSize(tc.dataSize)
+ if s.blockSize != int64(usermem.PageSize) {
+ t.Errorf("got blockSize %d, want %d", s.blockSize, usermem.PageSize)
+ }
+ if s.digestSize != sha256DigestSize {
+ t.Errorf("got digestSize %d, want %d", s.digestSize, sha256DigestSize)
+ }
+ if len(s.levelStart) != len(tc.expectedLevelStart) {
+ t.Errorf("got levels %d, want %d", len(s.levelStart), len(tc.expectedLevelStart))
+ }
+ for i := 0; i < len(s.levelStart) && i < len(tc.expectedLevelStart); i++ {
+ if s.levelStart[i] != tc.expectedLevelStart[i] {
+ t.Errorf("got levelStart[%d] %d, want %d", i, s.levelStart[i], tc.expectedLevelStart[i])
+ }
+ }
+ })
+ }
+}
+
+func TestGenerate(t *testing.T) {
+ // The input data has size dataSize. It starts with the data in startWith,
+ // and all other bytes are zeroes.
+ testCases := []struct {
+ dataSize int
+ startWith []byte
+ expectedRoot []byte
+ }{
+ {
+ dataSize: usermem.PageSize,
+ startWith: nil,
+ expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167},
+ },
+ {
+ dataSize: 128*usermem.PageSize + 1,
+ startWith: nil,
+ expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132},
+ },
+ {
+ dataSize: 1,
+ startWith: []byte{'a'},
+ expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210},
+ },
+ {
+ dataSize: 1,
+ startWith: []byte{'1'},
+ expectedRoot: []byte{74, 35, 103, 179, 176, 149, 254, 112, 42, 65, 104, 66, 119, 56, 133, 124, 228, 15, 65, 161, 150, 0, 117, 174, 242, 34, 115, 115, 218, 37, 3, 105},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
+ var (
+ data bytes.Buffer
+ tree bytes.Buffer
+ )
+
+ startSize := len(tc.startWith)
+ _, err := data.Write(tc.startWith)
+ if err != nil {
+ t.Fatalf("Failed to write to data: %v", err)
+ }
+ _, err = data.Write(make([]byte, tc.dataSize-startSize))
+ if err != nil {
+ t.Fatalf("Failed to write to data: %v", err)
+ }
+
+ root, err := Generate(&data, int64(tc.dataSize), &tree, &tree)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
+ }
+
+ if !bytes.Equal(root, tc.expectedRoot) {
+ t.Errorf("Unexpected root")
+ }
+ })
+ }
+}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index 57b89ad7d..2cb59f934 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -2506,7 +2506,7 @@ type msgFactory struct {
var msgRegistry registry
type registry struct {
- factories [math.MaxUint8]msgFactory
+ factories [math.MaxUint8 + 1]msgFactory
// largestFixedSize is computed so that given some message size M, you can
// compute the maximum payload size (e.g. for Twrite, Rread) with
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index 28d851ff5..122c457d2 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -1091,6 +1091,19 @@ type AllocateMode struct {
Unshare bool
}
+// ToAllocateMode returns an AllocateMode from a fallocate(2) mode.
+func ToAllocateMode(mode uint64) AllocateMode {
+ return AllocateMode{
+ KeepSize: mode&unix.FALLOC_FL_KEEP_SIZE != 0,
+ PunchHole: mode&unix.FALLOC_FL_PUNCH_HOLE != 0,
+ NoHideStale: mode&unix.FALLOC_FL_NO_HIDE_STALE != 0,
+ CollapseRange: mode&unix.FALLOC_FL_COLLAPSE_RANGE != 0,
+ ZeroRange: mode&unix.FALLOC_FL_ZERO_RANGE != 0,
+ InsertRange: mode&unix.FALLOC_FL_INSERT_RANGE != 0,
+ Unshare: mode&unix.FALLOC_FL_UNSHARE_RANGE != 0,
+ }
+}
+
// ToLinux converts to a value compatible with fallocate(2)'s mode.
func (a *AllocateMode) ToLinux() uint32 {
rv := uint32(0)
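
A small hedged sketch of the new helper; the flag values come from golang.org/x/sys/unix, and the example function is hypothetical:

package example

import (
	"golang.org/x/sys/unix"

	"gvisor.dev/gvisor/pkg/p9"
)

func example() {
	// Punching a hole without changing the file size.
	mode := uint64(unix.FALLOC_FL_KEEP_SIZE | unix.FALLOC_FL_PUNCH_HOLE)
	a := p9.ToAllocateMode(mode)
	_ = a // a.KeepSize and a.PunchHole are true; all other fields are false.
}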
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index fdfa83648..60cf94fa1 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -482,10 +482,10 @@ func (cs *connState) handle(m message) (r message) {
defer func() {
if r == nil {
// Don't allow a panic to propagate.
- recover()
+ err := recover()
// Include a useful log message.
- log.Warningf("panic in handler: %s", debug.Stack())
+ log.Warningf("panic in handler: %v\n%s", err, debug.Stack())
// Wrap in an EFAULT error; we don't really have a
// better way to describe this kind of error. It will
diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s
index 38cea9be3..7c622e5d7 100644
--- a/pkg/procid/procid_amd64.s
+++ b/pkg/procid/procid_amd64.s
@@ -14,7 +14,7 @@
// +build amd64
// +build go1.8
-// +build !go1.15
+// +build !go1.16
#include "textflag.h"
diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s
index 4f4b70fef..48ebb5fd1 100644
--- a/pkg/procid/procid_arm64.s
+++ b/pkg/procid/procid_arm64.s
@@ -14,7 +14,7 @@
// +build arm64
// +build go1.8
-// +build !go1.15
+// +build !go1.16
#include "textflag.h"
diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go
index 06308cd29..a52dc1b4e 100644
--- a/pkg/seccomp/seccomp_rules.go
+++ b/pkg/seccomp/seccomp_rules.go
@@ -56,7 +56,7 @@ func (a AllowValue) String() (s string) {
return fmt.Sprintf("%#x ", uintptr(a))
}
-// Rule stores the whitelist of syscall arguments.
+// Rule stores the allowed syscall arguments.
//
// For example:
// rule := Rule {
@@ -82,7 +82,7 @@ func (r Rule) String() (s string) {
return
}
-// SyscallRules stores a map of OR'ed whitelist rules indexed by the syscall number.
+// SyscallRules stores a map of OR'ed argument rules indexed by the syscall number.
// If the 'Rules' is empty, we treat it as any argument is allowed.
//
// For example:
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 343f81f59..fd95eb2d2 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -17,7 +17,6 @@
package arch
import (
- "encoding/binary"
"fmt"
"io"
@@ -29,7 +28,14 @@ import (
)
// Registers represents the CPU registers for this architecture.
-type Registers = linux.PtraceRegs
+//
+// +stateify savable
+type Registers struct {
+ linux.PtraceRegs
+
+ // TPIDR_EL0 is the EL0 Read/Write Software Thread ID Register.
+ TPIDR_EL0 uint64
+}
const (
// SyscallWidth is the width of instructions.
@@ -49,9 +55,14 @@ const ARMTrapFlag = uint64(1) << 21
type aarch64FPState []byte
// initAarch64FPState sets up initial state.
+//
+// Related code in Linux kernel: fpsimd_flush_thread().
+// FPCR = FPCR_RM_RN (0x0 << 22).
+//
+// Currently, aarch64FPState is just a 0x210-byte buffer for fpstate.
+// The FP header is not used in sentry/ptrace/kvm.
+//
func initAarch64FPState(data aarch64FPState) {
- binary.LittleEndian.PutUint32(data, fpsimdMagic)
- binary.LittleEndian.PutUint32(data[4:], fpsimdContextSize)
}
func newAarch64FPStateSlice() []byte {
@@ -97,9 +108,6 @@ type State struct {
// Our floating point state.
aarch64FPState `state:"wait"`
- // TLS pointer
- TPValue uint64
-
// FeatureSet is a pointer to the currently active feature set.
FeatureSet *cpuid.FeatureSet
@@ -153,7 +161,6 @@ func (s *State) Fork() State {
return State{
Regs: s.Regs,
aarch64FPState: s.aarch64FPState.fork(),
- TPValue: s.TPValue,
FeatureSet: s.FeatureSet,
OrigR0: s.OrigR0,
}
@@ -237,18 +244,18 @@ func (s *State) ptraceGetRegs() Registers {
return s.Regs
}
-var registersSize = (*Registers)(nil).SizeBytes()
+var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes()
// PtraceSetRegs implements Context.PtraceSetRegs.
func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
var regs Registers
- buf := make([]byte, registersSize)
+ buf := make([]byte, ptraceRegistersSize)
if _, err := io.ReadFull(src, buf); err != nil {
return 0, err
}
regs.UnmarshalUnsafe(buf)
s.Regs = regs
- return registersSize, nil
+ return ptraceRegistersSize, nil
}
// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
@@ -274,7 +281,7 @@ const (
func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceGetRegs(dst)
@@ -287,7 +294,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceSetRegs(src)
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 3b3a0a272..1c3e3c14c 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -300,7 +300,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
// PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and
// u_debugreg, returning 0 or silently no-oping for other fields
// respectively.
- if addr < uintptr(registersSize) {
+ if addr < uintptr(ptraceRegistersSize) {
regs := c.ptraceGetRegs()
buf := make([]byte, regs.SizeBytes())
regs.MarshalUnsafe(buf)
@@ -315,7 +315,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error {
if addr&7 != 0 || addr >= userStructSize {
return syscall.EIO
}
- if addr < uintptr(registersSize) {
+ if addr < uintptr(ptraceRegistersSize) {
regs := c.ptraceGetRegs()
buf := make([]byte, regs.SizeBytes())
regs.MarshalUnsafe(buf)
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index ada7ac7b8..cabbf60e0 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -142,7 +142,7 @@ func (c *context64) SetStack(value uintptr) {
// TLS returns the current TLS pointer.
func (c *context64) TLS() uintptr {
- return uintptr(c.TPValue)
+ return uintptr(c.Regs.TPIDR_EL0)
}
// SetTLS sets the current TLS pointer. Returns false if value is invalid.
@@ -151,7 +151,7 @@ func (c *context64) SetTLS(value uintptr) bool {
return false
}
- c.TPValue = uint64(value)
+ c.Regs.TPIDR_EL0 = uint64(value)
return true
}
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index dc458b37f..b9405b320 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -31,7 +31,11 @@ import (
)
// Registers represents the CPU registers for this architecture.
-type Registers = linux.PtraceRegs
+//
+// +stateify savable
+type Registers struct {
+ linux.PtraceRegs
+}
// System-related constants for x86.
const (
@@ -311,12 +315,12 @@ func (s *State) ptraceGetRegs() Registers {
return regs
}
-var registersSize = (*Registers)(nil).SizeBytes()
+var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes()
// PtraceSetRegs implements Context.PtraceSetRegs.
func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
var regs Registers
- buf := make([]byte, registersSize)
+ buf := make([]byte, ptraceRegistersSize)
if _, err := io.ReadFull(src, buf); err != nil {
return 0, err
}
@@ -374,7 +378,7 @@ func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
}
regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable)
s.Regs = regs
- return registersSize, nil
+ return ptraceRegistersSize, nil
}
// isUserSegmentSelector returns true if the given segment selector specifies a
@@ -543,7 +547,7 @@ const (
func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceGetRegs(dst)
@@ -563,7 +567,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceSetRegs(src)
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index e74275d2d..2c5d14be5 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -16,14 +16,12 @@ go_library(
],
deps = [
"//pkg/abi/linux",
- "//pkg/context",
"//pkg/fd",
- "//pkg/fspath",
"//pkg/log",
"//pkg/sentry/fdimport",
"//pkg/sentry/fs",
"//pkg/sentry/fs/host",
- "//pkg/sentry/fsbridge",
+ "//pkg/sentry/fs/user",
"//pkg/sentry/fsimpl/host",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
@@ -35,7 +33,6 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/tcpip/link/sniffer",
"//pkg/urpc",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/control/logging.go b/pkg/sentry/control/logging.go
index 811f24324..8a500a515 100644
--- a/pkg/sentry/control/logging.go
+++ b/pkg/sentry/control/logging.go
@@ -70,8 +70,8 @@ type LoggingArgs struct {
type Logging struct{}
// Change will change the log level and strace arguments. Although
-// this functions signature requires an error it never acctually
-// return san error. It's required by the URPC interface.
+// this function's signature requires an error, it never actually
+// returns an error. It's required by the URPC interface.
// Additionally, it may look odd that this is the only method
// attached to an empty struct but this is also part of how
// URPC dispatches.
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index 2ed17ee09..1bae7cfaf 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -18,7 +18,6 @@ import (
"bytes"
"encoding/json"
"fmt"
- "path"
"sort"
"strings"
"text/tabwriter"
@@ -26,13 +25,10 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/fspath"
- "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fdimport"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
- "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/fs/user"
hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -40,7 +36,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/urpc"
)
@@ -108,6 +103,9 @@ type ExecArgs struct {
// String prints the arguments as a string.
func (args ExecArgs) String() string {
+ if len(args.Argv) == 0 {
+ return args.Filename
+ }
a := make([]string, len(args.Argv))
copy(a, args.Argv)
if args.Filename != "" {
@@ -180,42 +178,30 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
}
ctx := initArgs.NewContext(proc.Kernel)
- if initArgs.Filename == "" {
- if kernel.VFS2Enabled {
- // Get the full path to the filename from the PATH env variable.
- if initArgs.MountNamespaceVFS2 == nil {
- // Set initArgs so that 'ctx' returns the namespace.
- //
- // MountNamespaceVFS2 adds a reference to the namespace, which is
- // transferred to the new process.
- initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2()
- }
+ if kernel.VFS2Enabled {
+ // Get the full path to the filename from the PATH env variable.
+ if initArgs.MountNamespaceVFS2 == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ //
+ // MountNamespaceVFS2 adds a reference to the namespace, which is
+ // transferred to the new process.
+ initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2()
+ }
+ } else {
+ if initArgs.MountNamespace == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ initArgs.MountNamespace = proc.Kernel.GlobalInit().Leader().MountNamespace()
- paths := fs.GetPath(initArgs.Envv)
- vfsObj := proc.Kernel.VFS()
- file, err := ResolveExecutablePath(ctx, vfsObj, initArgs.WorkingDirectory, initArgs.Argv[0], paths)
- if err != nil {
- return nil, 0, nil, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err)
- }
- initArgs.File = fsbridge.NewVFSFile(file)
- } else {
- // Get the full path to the filename from the PATH env variable.
- paths := fs.GetPath(initArgs.Envv)
- if initArgs.MountNamespace == nil {
- // Set initArgs so that 'ctx' returns the namespace.
- initArgs.MountNamespace = proc.Kernel.GlobalInit().Leader().MountNamespace()
-
- // initArgs must hold a reference on MountNamespace, which will
- // be donated to the new process in CreateProcess.
- initArgs.MountNamespace.IncRef()
- }
- f, err := initArgs.MountNamespace.ResolveExecutablePath(ctx, initArgs.WorkingDirectory, initArgs.Argv[0], paths)
- if err != nil {
- return nil, 0, nil, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err)
- }
- initArgs.Filename = f
+ // initArgs must hold a reference on MountNamespace, which will
+ // be donated to the new process in CreateProcess.
+ initArgs.MountNamespace.IncRef()
}
}
+ resolved, err := user.ResolveExecutablePath(ctx, &initArgs)
+ if err != nil {
+ return nil, 0, nil, nil, err
+ }
+ initArgs.Filename = resolved
fds := make([]int, len(args.FilePayload.Files))
for i, file := range args.FilePayload.Files {
@@ -428,67 +414,3 @@ func ttyName(tty *kernel.TTY) string {
}
return fmt.Sprintf("pts/%d", tty.Index)
}
-
-// ResolveExecutablePath resolves the given executable name given a set of
-// paths that might contain it.
-func ResolveExecutablePath(ctx context.Context, vfsObj *vfs.VirtualFilesystem, wd, name string, paths []string) (*vfs.FileDescription, error) {
- root := vfs.RootFromContext(ctx)
- defer root.DecRef()
- creds := auth.CredentialsFromContext(ctx)
-
- // Absolute paths can be used directly.
- if path.IsAbs(name) {
- return openExecutable(ctx, vfsObj, creds, root, name)
- }
-
- // Paths with '/' in them should be joined to the working directory, or
- // to the root if working directory is not set.
- if strings.IndexByte(name, '/') > 0 {
- if len(wd) == 0 {
- wd = "/"
- }
- if !path.IsAbs(wd) {
- return nil, fmt.Errorf("working directory %q must be absolute", wd)
- }
- return openExecutable(ctx, vfsObj, creds, root, path.Join(wd, name))
- }
-
- // Otherwise, we must lookup the name in the paths, starting from the
- // calling context's root directory.
- for _, p := range paths {
- if !path.IsAbs(p) {
- // Relative paths aren't safe, no one should be using them.
- log.Warningf("Skipping relative path %q in $PATH", p)
- continue
- }
-
- binPath := path.Join(p, name)
- f, err := openExecutable(ctx, vfsObj, creds, root, binPath)
- if err != nil {
- return nil, err
- }
- if f == nil {
- continue // Not found/no access.
- }
- return f, nil
- }
- return nil, syserror.ENOENT
-}
-
-func openExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry, path string) (*vfs.FileDescription, error) {
- pop := vfs.PathOperation{
- Root: root,
- Start: root, // binPath is absolute, Start can be anything.
- Path: fspath.Parse(path),
- FollowFinalSymlink: true,
- }
- opts := &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- FileExec: true,
- }
- f, err := vfsObj.OpenAt(ctx, creds, &pop, opts)
- if err == syserror.ENOENT || err == syserror.EACCES {
- return nil, nil
- }
- return f, err
-}
diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go
index 69e71e322..f45b2bd2b 100644
--- a/pkg/sentry/device/device.go
+++ b/pkg/sentry/device/device.go
@@ -188,6 +188,9 @@ type MultiDevice struct {
// String stringifies MultiDevice.
func (m *MultiDevice) String() string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
buf := bytes.NewBuffer(nil)
buf.WriteString("cache{")
for k, v := range m.cache {
diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go
index c7e197691..af66fe4dc 100644
--- a/pkg/sentry/devices/memdev/full.go
+++ b/pkg/sentry/devices/memdev/full.go
@@ -42,6 +42,7 @@ type fullFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
}
// Release implements vfs.FileDescriptionImpl.Release.
diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go
index 33d060d02..92d3d71be 100644
--- a/pkg/sentry/devices/memdev/null.go
+++ b/pkg/sentry/devices/memdev/null.go
@@ -43,6 +43,7 @@ type nullFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
}
// Release implements vfs.FileDescriptionImpl.Release.
diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go
index acfa23149..6b81da5ef 100644
--- a/pkg/sentry/devices/memdev/random.go
+++ b/pkg/sentry/devices/memdev/random.go
@@ -48,6 +48,7 @@ type randomFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
// off is the "file offset". off is accessed using atomic memory
// operations.
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
index 3b1372b9e..c6f15054d 100644
--- a/pkg/sentry/devices/memdev/zero.go
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -44,6 +44,7 @@ type zeroFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
}
// Release implements vfs.FileDescriptionImpl.Release.
diff --git a/pkg/sentry/devices/ttydev/BUILD b/pkg/sentry/devices/ttydev/BUILD
new file mode 100644
index 000000000..12e49b58a
--- /dev/null
+++ b/pkg/sentry/devices/ttydev/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "ttydev",
+ srcs = ["ttydev.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/vfs",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go
new file mode 100644
index 000000000..fbb7fd92c
--- /dev/null
+++ b/pkg/sentry/devices/ttydev/ttydev.go
@@ -0,0 +1,91 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ttydev implements devices for /dev/tty and (eventually)
+// /dev/console.
+//
+// TODO(b/159623826): Support /dev/console.
+package ttydev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // See drivers/tty/tty_io.c:tty_init().
+ ttyDevMinor = 0
+ consoleDevMinor = 1
+)
+
+// ttyDevice implements vfs.Device for /dev/tty.
+type ttyDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (ttyDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &ttyFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// ttyFD implements vfs.FileDescriptionImpl for /dev/tty.
+type ttyFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *ttyFD) Release() {}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *ttyFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, nil
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *ttyFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return 0, nil
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *ttyFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *ttyFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Register registers all devices implemented by this package in vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ return vfsObj.RegisterDevice(vfs.CharDevice, linux.TTYAUX_MAJOR, ttyDevMinor, ttyDevice{}, &vfs.RegisterDeviceOptions{
+ GroupName: "tty",
+ })
+}
+
+// CreateDevtmpfsFiles creates device special files in dev representing all
+// devices implemented by this package.
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error {
+ return dev.CreateDeviceFile(ctx, "tty", vfs.CharDevice, linux.TTYAUX_MAJOR, ttyDevMinor, 0666 /* mode */)
+}
diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD
new file mode 100644
index 000000000..71c59287c
--- /dev/null
+++ b/pkg/sentry/devices/tundev/BUILD
@@ -0,0 +1,23 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "tundev",
+ srcs = ["tundev.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/socket/netstack",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/tcpip/link/tun",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
new file mode 100644
index 000000000..dfbd069af
--- /dev/null
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -0,0 +1,178 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tundev implements the /dev/net/tun device.
+package tundev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ netTunDevMajor = 10
+ netTunDevMinor = 200
+)
+
+// tunDevice implements vfs.Device for /dev/net/tun.
+type tunDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &tunFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// tunFD implements vfs.FileDescriptionImpl for /dev/net/tun.
+type tunFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ device tun.Device
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ request := args[1].Uint()
+ data := args[2].Pointer()
+
+ switch request {
+ case linux.TUNSETIFF:
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("Ioctl should be called from a task context")
+ }
+ if !t.HasCapability(linux.CAP_NET_ADMIN) {
+ return 0, syserror.EPERM
+ }
+ stack, ok := t.NetworkContext().(*netstack.Stack)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ var req linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, uio, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ flags := usermem.ByteOrder.Uint16(req.Data[:])
+ return 0, fd.device.SetIff(stack.Stack, req.Name(), flags)
+
+ case linux.TUNGETIFF:
+ var req linux.IFReq
+
+ copy(req.IFName[:], fd.device.Name())
+
+ // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
+ // there is no sk_filter. See __tun_chr_ioctl() in drivers/net/tun.c.
+ flags := fd.device.Flags() | linux.IFF_NOFILTER
+ usermem.ByteOrder.PutUint16(req.Data[:], flags)
+
+ _, err := usermem.CopyObjectOut(ctx, uio, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *tunFD) Release() {
+ fd.device.Release()
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *tunFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return fd.Read(ctx, dst, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *tunFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ data, err := fd.device.Read()
+ if err != nil {
+ return 0, err
+ }
+ n, err := dst.CopyOut(ctx, data)
+ if n > 0 && n < len(data) {
+ // Not an error for partial copying. Packet truncated.
+ err = nil
+ }
+ return int64(n), err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *tunFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return fd.Write(ctx, src, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *tunFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ data := make([]byte, src.NumBytes())
+ if _, err := src.CopyIn(ctx, data); err != nil {
+ return 0, err
+ }
+ return fd.device.Write(data)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (fd *tunFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fd.device.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *tunFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.device.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *tunFD) EventUnregister(e *waiter.Entry) {
+ fd.device.EventUnregister(e)
+}
+
+// isNetTunSupported returns whether the /dev/net/tun device is supported for s.
+func isNetTunSupported(s inet.Stack) bool {
+ _, ok := s.(*netstack.Stack)
+ return ok
+}
+
+// Register registers all devices implemented by this package in vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ return vfsObj.RegisterDevice(vfs.CharDevice, netTunDevMajor, netTunDevMinor, tunDevice{}, &vfs.RegisterDeviceOptions{})
+}
+
+// CreateDevtmpfsFiles creates device special files in dev representing all
+// devices implemented by this package.
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error {
+ return dev.CreateDeviceFile(ctx, "net/tun", vfs.CharDevice, netTunDevMajor, netTunDevMinor, 0666 /* mode */)
+}
diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go
index a4199f9e9..b8686adb4 100644
--- a/pkg/sentry/fdimport/fdimport.go
+++ b/pkg/sentry/fdimport/fdimport.go
@@ -15,6 +15,8 @@
package fdimport
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
@@ -84,6 +86,9 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []
func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdioFDs []int) (*hostvfs2.TTYFileDescription, error) {
k := kernel.KernelFromContext(ctx)
+ if k == nil {
+ return nil, fmt.Errorf("cannot find kernel from context")
+ }
var ttyFile *vfs.FileDescription
for appFD, hostFD := range stdioFDs {
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index 846252c89..ca41520b4 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -146,7 +146,7 @@ func (f *File) DecRef() {
f.DecRefWithDestructor(func() {
// Drop BSD style locks.
lockRng := lock.LockRange{Start: 0, End: lock.LockEOF}
- f.Dirent.Inode.LockCtx.BSD.UnlockRegion(lock.UniqueID(f.UniqueID), lockRng)
+ f.Dirent.Inode.LockCtx.BSD.UnlockRegion(f, lockRng)
// Release resources held by the FileOperations.
f.FileOperations.Release()
@@ -310,7 +310,6 @@ func (f *File) Writev(ctx context.Context, src usermem.IOSequence) (int64, error
if !f.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
-
unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append)
// Handle append mode.
if f.Flags().Append {
@@ -355,7 +354,6 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64
// offset."
unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append)
defer unlockAppendMu()
-
if f.Flags().Append {
if err := f.offsetForAppend(ctx, &offset); err != nil {
return 0, err
@@ -374,9 +372,10 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64
return f.FileOperations.Write(ctx, f, src, offset)
}
-// offsetForAppend sets the given offset to the end of the file.
+// offsetForAppend atomically sets the given offset to the end of the file.
//
-// Precondition: the file.Dirent.Inode.appendMu mutex should be held for writing.
+// Precondition: the file.Dirent.Inode.appendMu mutex should be held for
+// writing.
func (f *File) offsetForAppend(ctx context.Context, offset *int64) error {
uattr, err := f.Dirent.Inode.UnstableAttr(ctx)
if err != nil {
@@ -386,7 +385,7 @@ func (f *File) offsetForAppend(ctx context.Context, offset *int64) error {
}
// Update the offset.
- *offset = uattr.Size
+ atomic.StoreInt64(offset, uattr.Size)
return nil
}
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
index beba0f771..f5537411e 100644
--- a/pkg/sentry/fs/file_operations.go
+++ b/pkg/sentry/fs/file_operations.go
@@ -160,6 +160,7 @@ type FileOperations interface {
// refer.
//
// Preconditions: The AddressSpace (if any) that io refers to is activated.
+ // Must only be called from a task goroutine.
Ioctl(ctx context.Context, file *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error)
}
diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go
index 084da2a8d..d41f30bbb 100644
--- a/pkg/sentry/fs/filesystems.go
+++ b/pkg/sentry/fs/filesystems.go
@@ -87,20 +87,6 @@ func RegisterFilesystem(f Filesystem) {
filesystems.registered[f.Name()] = f
}
-// UnregisterFilesystem removes a file system from the global set. To keep the
-// file system set compatible with save/restore, UnregisterFilesystem must be
-// called before save/restore methods.
-//
-// For instance, packages may unregister their file system after it is mounted.
-// This makes sense for pseudo file systems that should not be visible or
-// mountable. See whitelistfs in fs/host/fs.go for one example.
-func UnregisterFilesystem(name string) {
- filesystems.mu.Lock()
- defer filesystems.mu.Unlock()
-
- delete(filesystems.registered, name)
-}
-
// FindFilesystem returns a Filesystem registered at name or (nil, false) if name
// is not a file system type that can be found in /proc/filesystems.
func FindFilesystem(name string) (Filesystem, bool) {
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
index bdba6efe5..d2dbff268 100644
--- a/pkg/sentry/fs/fs.go
+++ b/pkg/sentry/fs/fs.go
@@ -42,9 +42,10 @@
// Dirent.dirMu
// Dirent.mu
// DirentCache.mu
-// Locks in InodeOperations implementations or overlayEntry
// Inode.Watches.mu (see `Inotify` for other lock ordering)
// MountSource.mu
+// Inode.appendMu
+// Locks in InodeOperations implementations or overlayEntry
//
// If multiple Dirent or MountSource locks must be taken, locks in the parent must be
// taken before locks in their children.
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index 789369220..5fb419bcd 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -8,7 +8,6 @@ go_template_instance(
out = "dirty_set_impl.go",
imports = {
"memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
},
package = "fsutil",
prefix = "Dirty",
@@ -25,14 +24,14 @@ go_template_instance(
name = "frame_ref_set_impl",
out = "frame_ref_set_impl.go",
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "fsutil",
prefix = "FrameRef",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "uint64",
"Functions": "FrameRefSetFunctions",
},
@@ -43,7 +42,6 @@ go_template_instance(
out = "file_range_set_impl.go",
imports = {
"memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
},
package = "fsutil",
prefix = "FileRange",
@@ -86,7 +84,6 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
- "//pkg/sentry/platform",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/state",
diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go
index c6cd45087..2c9446c1d 100644
--- a/pkg/sentry/fs/fsutil/dirty_set.go
+++ b/pkg/sentry/fs/fsutil/dirty_set.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -159,7 +158,7 @@ func (ds *DirtySet) AllowClean(mr memmap.MappableRange) {
// repeatedly until all bytes have been written. max is the true size of the
// cached object; offsets beyond max will not be passed to writeAt, even if
// they are marked dirty.
-func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
var changedDirty bool
defer func() {
if changedDirty {
@@ -194,7 +193,7 @@ func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet
// successful partial write, SyncDirtyAll will call it repeatedly until all
// bytes have been written. max is the true size of the cached object; offsets
// beyond max will not be passed to writeAt, even if they are marked dirty.
-func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
dseg := dirty.FirstSegment()
for dseg.Ok() {
if err := syncDirtyRange(ctx, dseg.Range(), cache, max, mem, writeAt); err != nil {
@@ -210,7 +209,7 @@ func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max
}
// Preconditions: mr must be page-aligned.
-func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
for cseg := cache.LowerBoundSegment(mr.Start); cseg.Ok() && cseg.Start() < mr.End; cseg = cseg.NextSegment() {
wbr := cseg.Range().Intersect(mr)
if max < wbr.Start {
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
index 5643cdac9..bbafebf03 100644
--- a/pkg/sentry/fs/fsutil/file_range_set.go
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -23,13 +23,12 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/usermem"
)
// FileRangeSet maps offsets into a memmap.Mappable to offsets into a
-// platform.File. It is used to implement Mappables that store data in
+// memmap.File. It is used to implement Mappables that store data in
// sparsely-allocated memory.
//
// type FileRangeSet <generated by go_generics>
@@ -65,20 +64,20 @@ func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, spli
}
// FileRange returns the FileRange mapped by seg.
-func (seg FileRangeIterator) FileRange() platform.FileRange {
+func (seg FileRangeIterator) FileRange() memmap.FileRange {
return seg.FileRangeOf(seg.Range())
}
// FileRangeOf returns the FileRange mapped by mr.
//
// Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0.
-func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange {
+func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRange {
frstart := seg.Value() + (mr.Start - seg.Start())
- return platform.FileRange{frstart, frstart + mr.Length()}
+ return memmap.FileRange{frstart, frstart + mr.Length()}
}
// Fill attempts to ensure that all memmap.Mappable offsets in required are
-// mapped to a platform.File offset, by allocating from mf with the given
+// mapped to a memmap.File offset, by allocating from mf with the given
// memory usage kind and invoking readAt to store data into memory. (If readAt
// returns a successful partial read, Fill will call it repeatedly until all
// bytes have been read.) EOF is handled consistently with the requirements of
@@ -141,7 +140,7 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map
}
// Drop removes segments for memmap.Mappable offsets in mr, freeing the
-// corresponding platform.FileRanges.
+// corresponding memmap.FileRanges.
//
// Preconditions: mr must be page-aligned.
func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) {
@@ -154,7 +153,7 @@ func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) {
}
// DropAll removes all segments in mr, freeing the corresponding
-// platform.FileRanges.
+// memmap.FileRanges.
func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) {
for seg := frs.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
mf.DecRef(seg.FileRange())
diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go
index dd6f5aba6..a808894df 100644
--- a/pkg/sentry/fs/fsutil/frame_ref_set.go
+++ b/pkg/sentry/fs/fsutil/frame_ref_set.go
@@ -17,7 +17,7 @@ package fsutil
import (
"math"
- "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usage"
)
@@ -39,7 +39,7 @@ func (FrameRefSetFunctions) ClearValue(val *uint64) {
}
// Merge implements segment.Functions.Merge.
-func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) {
+func (FrameRefSetFunctions) Merge(_ memmap.FileRange, val1 uint64, _ memmap.FileRange, val2 uint64) (uint64, bool) {
if val1 != val2 {
return 0, false
}
@@ -47,13 +47,13 @@ func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.
}
// Split implements segment.Functions.Split.
-func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) {
+func (FrameRefSetFunctions) Split(_ memmap.FileRange, val uint64, _ uint64) (uint64, uint64) {
return val, val
}
// IncRefAndAccount adds a reference on the range fr. All newly inserted segments
// are accounted as host page cache memory mappings.
-func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) {
+func (refs *FrameRefSet) IncRefAndAccount(fr memmap.FileRange) {
seg, gap := refs.Find(fr.Start)
for {
switch {
@@ -74,7 +74,7 @@ func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) {
// DecRefAndAccount removes a reference on the range fr and untracks segments
// that are removed from memory accounting.
-func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) {
+func (refs *FrameRefSet) DecRefAndAccount(fr memmap.FileRange) {
seg := refs.FindSegment(fr.Start)
for seg.Ok() && seg.Start() < fr.End {
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index e82afd112..ef0113b52 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -126,7 +125,7 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) {
// offsets in fr or until the next call to UnmapAll.
//
// Preconditions: The caller must hold a reference on all offsets in fr.
-func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) (safemem.BlockSeq, error) {
+func (f *HostFileMapper) MapInternal(fr memmap.FileRange, fd int, write bool) (safemem.BlockSeq, error) {
chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift)
f.mapsMu.Lock()
defer f.mapsMu.Unlock()
@@ -146,7 +145,7 @@ func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool)
}
// Preconditions: f.mapsMu must be locked.
-func (f *HostFileMapper) forEachMappingBlockLocked(fr platform.FileRange, fd int, write bool, fn func(safemem.Block)) error {
+func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, write bool, fn func(safemem.Block)) error {
prot := syscall.PROT_READ
if write {
prot |= syscall.PROT_WRITE
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index 78fec553e..c15d8a946 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -21,18 +21,17 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
-// HostMappable implements memmap.Mappable and platform.File over a
+// HostMappable implements memmap.Mappable and memmap.File over a
// CachedFileObject.
//
// Lock order (compare the lock order model in mm/mm.go):
// truncateMu ("fs locks")
// mu ("memmap.Mappable locks not taken by Translate")
-// ("platform.File locks")
+// ("memmap.File locks")
// backingFile ("CachedFileObject locks")
//
// +stateify savable
@@ -124,24 +123,24 @@ func (h *HostMappable) NotifyChangeFD() error {
return nil
}
-// MapInternal implements platform.File.MapInternal.
-func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// MapInternal implements memmap.File.MapInternal.
+func (h *HostMappable) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write)
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (h *HostMappable) FD() int {
return h.backingFile.FD()
}
-// IncRef implements platform.File.IncRef.
-func (h *HostMappable) IncRef(fr platform.FileRange) {
+// IncRef implements memmap.File.IncRef.
+func (h *HostMappable) IncRef(fr memmap.FileRange) {
mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
h.hostFileMapper.IncRefOn(mr)
}
-// DecRef implements platform.File.DecRef.
-func (h *HostMappable) DecRef(fr platform.FileRange) {
+// DecRef implements memmap.File.DecRef.
+func (h *HostMappable) DecRef(fr memmap.FileRange) {
mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
h.hostFileMapper.DecRefOn(mr)
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 800c8b4e1..fe8b0b6ac 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -26,7 +26,6 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -934,7 +933,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange
// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error {
- // Whether we have a host fd (and consequently what platform.File is
+ // Whether we have a host fd (and consequently what memmap.File is
// mapped) can change across save/restore, so invalidate all translations
// unconditionally.
c.mapsMu.Lock()
@@ -999,10 +998,10 @@ func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.Evictable
}
}
-// IncRef implements platform.File.IncRef. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// IncRef implements memmap.File.IncRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
-func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
+func (c *CachingInodeOperations) IncRef(fr memmap.FileRange) {
// Hot path. Avoid defers.
c.dataMu.Lock()
seg, gap := c.refs.Find(fr.Start)
@@ -1024,10 +1023,10 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
}
}
-// DecRef implements platform.File.DecRef. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// DecRef implements memmap.File.DecRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
-func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
+func (c *CachingInodeOperations) DecRef(fr memmap.FileRange) {
// Hot path. Avoid defers.
c.dataMu.Lock()
seg := c.refs.FindSegment(fr.Start)
@@ -1046,15 +1045,15 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
c.dataMu.Unlock()
}
-// MapInternal implements platform.File.MapInternal. This is used when we
+// MapInternal implements memmap.File.MapInternal. This is used when we
// directly map an underlying host fd and CachingInodeOperations is used as the
-// platform.File during translation.
-func (c *CachingInodeOperations) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// memmap.File during translation.
+func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write)
}
-// FD implements platform.File.FD. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// FD implements memmap.File.FD. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
func (c *CachingInodeOperations) FD() int {
return c.backingFile.FD()
diff --git a/pkg/sentry/fs/g3doc/.gitignore b/pkg/sentry/fs/g3doc/.gitignore
new file mode 100644
index 000000000..2d19fc766
--- /dev/null
+++ b/pkg/sentry/fs/g3doc/.gitignore
@@ -0,0 +1 @@
+*.html
diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md
new file mode 100644
index 000000000..2ca84dd74
--- /dev/null
+++ b/pkg/sentry/fs/g3doc/fuse.md
@@ -0,0 +1,263 @@
+# Foreword
+
+This document describes an ongoing project to support FUSE filesystems within
+the sentry. This is intended to become the final documentation for this
+subsystem, and is therefore written in the past tense. However, FUSE support
+is currently incomplete, and the document will be updated as things progress.
+
+# FUSE: Filesystem in Userspace
+
+The sentry supports dispatching filesystem operations to a FUSE server,
+allowing FUSE filesystems to be used within a sandbox.
+
+## Overview
+
+FUSE has two main components:
+
+1. A client kernel driver (canonically `fuse.ko` in Linux), which forwards
+ filesystem operations (usually initiated by syscalls) to the server.
+
+2. A server, which is a userspace daemon that implements the actual filesystem.
+
+The sentry implements the client component, which allows a server daemon
+running within the sandbox to provide a filesystem to the rest of the sandbox.
+
+A FUSE filesystem is initialized with `mount(2)`, typically with the help of a
+utility like `fusermount(1)`. Various mount options exist for establishing
+ownership and access permissions on the filesystem, but the most important mount
+option is a file descriptor used to establish communication between the client
+and server.
+
+The FUSE device FD is obtained by opening `/dev/fuse`. During regular operation,
+the client and server use the FUSE protocol described in `fuse(4)` to service
+filesystem operations. See the "FUSE Protocol" section in the appendix for
+more information about this protocol. The core of the sentry's FUSE support
+is the client-side implementation of this protocol.
+
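+As a hedged illustration of this setup, a FUSE server might open the device
+and establish the mount roughly as follows. The mount data keys (`fd=`,
+`rootmode=`, `user_id=`, `group_id=`) are the kernel's required FUSE mount
+options; everything else in this sketch is an assumption for illustration,
+not sentry code.
+
+```go
+package fusedemo
+
+import (
+	"fmt"
+	"os"
+	"syscall"
+)
+
+// mountFUSE opens /dev/fuse and mounts a FUSE filesystem at target. The
+// returned file is the communication channel: the server reads requests
+// from it and writes responses to it.
+func mountFUSE(target string) (*os.File, error) {
+	dev, err := os.OpenFile("/dev/fuse", os.O_RDWR, 0)
+	if err != nil {
+		return nil, err
+	}
+	// rootmode is the mode of the mount root; 40000 (octal) is S_IFDIR.
+	// user_id/group_id identify the mounting user; root is assumed here.
+	data := fmt.Sprintf("fd=%d,rootmode=40000,user_id=0,group_id=0", dev.Fd())
+	if err := syscall.Mount("fuse", target, "fuse", 0, data); err != nil {
+		dev.Close()
+		return nil, err
+	}
+	return dev, nil
+}
+```
+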
+## FUSE in the Sentry
+
+The sentry's FUSE client targets VFS2 and has the following components:
+
+- An implementation of `/dev/fuse`.
+
+- A VFS2 filesystem for mapping syscalls to FUSE ops. Since we're targeting
+    VFS2, one point of contention may be the lack of inodes in VFS2. We can
+    tentatively implement a kernfs-based filesystem to bridge the gap in APIs.
+    The kernfs base functionality can serve the role of the Linux inode cache,
+    and the filesystem can map VFS2 syscalls to kernfs inode operations; see
+    the `kernfs.Inode` interface.
+
+The FUSE protocol lends itself well to marshaling with `go_marshal`. The various
+request and response packets can be defined in the ABI package and converted to
+and from the wire format using `go_marshal`.
+
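+As a rough sketch of what this looks like, the request header described in the
+appendix below could be declared in the ABI package as follows. The type and
+field names here are illustrative assumptions; see `pkg/abi/linux/fuse.go` for
+the actual definitions.
+
+```go
+package linux
+
+// FUSEHeaderIn is the Go equivalent of struct fuse_in_header. The
+// `+marshal` directive instructs go_marshal to generate serialization
+// code for this type.
+//
+// +marshal
+type FUSEHeaderIn struct {
+	Len     uint32 // Total length of the request, including this header.
+	Opcode  uint32 // Requested operation.
+	Unique  uint64 // Unique identifier for this request.
+	NodeID  uint64 // ID of the filesystem object being operated on.
+	UID     uint32 // UID of the requesting process.
+	GID     uint32 // GID of the requesting process.
+	PID     uint32 // PID of the requesting process.
+	Padding uint32
+}
+```
+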
+### Design Goals
+
+- While filesystem performance is always important, the sentry's FUSE support
+ is primarily concerned with compatibility, with performance as a secondary
+ concern.
+
+- Avoid deadlocks from a hung server daemon.
+
+- Consider the potential for denial of service from a malicious server daemon.
+ Protecting itself from userspace is already a design goal for the sentry,
+ but needs additional consideration for FUSE. Normally, an operating system
+ doesn't rely on userspace to make progress with filesystem operations. Since
+ this changes with FUSE, it opens up the possibility of creating a chain of
+ dependencies controlled by userspace, which could affect an entire sandbox.
+ For example: a FUSE op can block a syscall, which could be holding a
+ subsystem lock, which can then block another task goroutine.
+
+### Milestones
+
+Below are some broad goals to aim for while implementing FUSE in the sentry.
+Many FUSE ops can be grouped into broad categories of functionality, and most
+ops can be implemented in parallel.
+
+#### Minimal client that can mount a trivial FUSE filesystem
+
+- Implement `/dev/fuse` - a character device used to establish an FD for
+ communication between the sentry and the server daemon.
+
+- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`.
+
+#### Read-only mount with basic file operations
+
+- Implement the majority of file, directory, and file descriptor FUSE ops.
+    For this milestone, we can skip uncommon or complex operations like mmap,
+    mknod, file locking, poll, and extended attributes. We can stub these out
+    along with any ops that modify the filesystem. The exact list of required
+    ops is to be determined, but the goal is to mount a real filesystem as
+    read-only, and be able to read contents from the filesystem in the sentry.
+
+#### Full read-write support
+
+- Implement the remaining FUSE ops and decide if we can omit rarely used
+ operations like ioctl.
+
+# Appendix
+
+## FUSE Protocol
+
+The FUSE protocol is a request-response protocol. All requests are initiated by
+the client. The wire format for the protocol is raw C structs serialized to
+memory.
+
+All FUSE requests begin with the following request header:
+
+```c
+struct fuse_in_header {
+ uint32_t len; // Length of the request, including this header.
+ uint32_t opcode; // Requested operation.
+ uint64_t unique; // A unique identifier for this request.
+ uint64_t nodeid; // ID of the filesystem object being operated on.
+ uint32_t uid; // UID of the requesting process.
+ uint32_t gid; // GID of the requesting process.
+ uint32_t pid; // PID of the requesting process.
+ uint32_t padding;
+};
+```
+
+The request is then followed by a payload specific to the `opcode`.
+
+All responses begin with this response header:
+
+```c
+struct fuse_out_header {
+ uint32_t len; // Length of the response, including this header.
+ int32_t error; // Status of the request, 0 if success.
+ uint64_t unique; // The unique identifier from the corresponding request.
+};
+```
+
+The response payload also depends on the request `opcode`. If `error != 0`, the
+response payload must be empty.
+
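+To make the framing concrete, below is a minimal sketch of the server side of
+a single exchange, assuming the two header layouts above. This is not sentry
+code (the sentry implements the client, and uses `go_marshal` rather than
+`encoding/binary`); it only demonstrates the wire format.
+
+```go
+package fusedemo
+
+import (
+	"bytes"
+	"encoding/binary"
+	"os"
+)
+
+// inHeader and outHeader mirror the C structs above. FUSE uses host byte
+// order; little-endian is assumed below.
+type inHeader struct {
+	Len, Opcode            uint32
+	Unique, NodeID         uint64
+	UID, GID, PID, Padding uint32
+}
+
+type outHeader struct {
+	Len    uint32
+	Error  int32
+	Unique uint64
+}
+
+// serveOne reads one request from the FUSE device and replies with ENOSYS,
+// illustrating the framing rules: Len covers the entire message, Unique is
+// echoed back, and a non-zero Error implies an empty payload.
+func serveOne(dev *os.File) error {
+	// Each read(2) on the device returns exactly one request, so the
+	// buffer must be large enough for any message.
+	buf := make([]byte, 64*1024)
+	n, err := dev.Read(buf)
+	if err != nil {
+		return err
+	}
+	var req inHeader
+	if err := binary.Read(bytes.NewReader(buf[:n]), binary.LittleEndian, &req); err != nil {
+		return err
+	}
+	// Reply errors are negative errnos; -38 is -ENOSYS.
+	resp := outHeader{Error: -38, Unique: req.Unique}
+	resp.Len = uint32(binary.Size(resp))
+	out := new(bytes.Buffer)
+	if err := binary.Write(out, binary.LittleEndian, &resp); err != nil {
+		return err
+	}
+	_, err = dev.Write(out.Bytes())
+	return err
+}
+```
+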
+### Operations
+
+The following is a list of all FUSE operations used in `fuse_in_header.opcode`
+as of Linux v4.4, and a brief description of their purpose. These are defined in
+`uapi/linux/fuse.h`. Many of these have a corresponding request and response
+payload struct; `fuse(4)` has details for some of these. We also note how these
+operations map to the sentry virtual filesystem.
+
+#### FUSE meta-operations
+
+These operations are specific to FUSE and don't have a corresponding action in a
+generic filesystem.
+
+- `FUSE_INIT`: Initializes a new FUSE filesystem; this is the first message
+    sent by the client after mount and is used for version and feature
+    negotiation (related to `mount(2)`). A sketch of its request payload
+    follows this list.
+- `FUSE_DESTROY`: Teardown a FUSE filesystem, related to `unmount(2)`.
+- `FUSE_INTERRUPT`: Interrupts an in-flight operation, specified by the
+ `fuse_in_header.unique` value provided in the corresponding request header.
+ The client can send at most one of these per request, and will enter an
+ uninterruptible wait for a reply. The server is expected to reply promptly.
+- `FUSE_FORGET`: A hint to the server that it should evict the indicated
+    node from any caches. This is wired up to `(struct
+    super_operations).evict_inode` in Linux, which is in turn hooked into the
+    inode cache shrinker, typically triggered by system memory pressure.
+- `FUSE_BATCH_FORGET`: Batch version of `FUSE_FORGET`.
+
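+The `FUSE_INIT` request payload referenced above carries the client's version
+and capability information; the fields follow `uapi/linux/fuse.h`, while the
+Go spelling is an illustrative assumption. The server replies with a
+corresponding `fuse_init_out`, and in broad terms both sides proceed with the
+lower of the two advertised protocol versions.
+
+```go
+package linux
+
+// FUSEInitIn mirrors struct fuse_init_in, the payload of FUSE_INIT.
+//
+// +marshal
+type FUSEInitIn struct {
+	Major        uint32 // Client's FUSE_KERNEL_VERSION.
+	Minor        uint32 // Client's FUSE_KERNEL_MINOR_VERSION.
+	MaxReadahead uint32 // Largest readahead size the client will use.
+	Flags        uint32 // Capability flags, e.g. FUSE_ASYNC_READ.
+}
+```
+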
+#### Filesystem Syscalls
+
+These FUSE ops map directly to an equivalent filesystem syscall, or family of
+syscalls. The relevant syscalls have a similar name to the operation, unless
+otherwise noted.
+
+Node creation:
+
+- `FUSE_MKNOD`
+- `FUSE_MKDIR`
+- `FUSE_CREATE`: This is equivalent to `open(2)` and `creat(2)`, which
+ atomically creates and opens a node.
+
+Node attributes and extended attributes:
+
+- `FUSE_GETATTR`
+- `FUSE_SETATTR`
+- `FUSE_SETXATTR`
+- `FUSE_GETXATTR`
+- `FUSE_LISTXATTR`
+- `FUSE_REMOVEXATTR`
+
+Node link manipulation:
+
+- `FUSE_READLINK`
+- `FUSE_LINK`
+- `FUSE_SYMLINK`
+- `FUSE_UNLINK`
+
+Directory operations:
+
+- `FUSE_RMDIR`
+- `FUSE_RENAME`
+- `FUSE_RENAME2`
+- `FUSE_OPENDIR`: `open(2)` for directories.
+- `FUSE_RELEASEDIR`: `close(2)` for directories.
+- `FUSE_READDIR`
+- `FUSE_READDIRPLUS`
+- `FUSE_FSYNCDIR`: `fsync(2)` for directories.
+- `FUSE_LOOKUP`: Establishes a unique identifier for a FS node. This is
+    reminiscent of `VirtualFilesystem.GetDentryAt` in that it resolves a path
+    component to a node. However, the returned identifier is opaque to the
+    client. The server must remember this mapping, as this is how the client
+    will reference the node in the future.
+
+File operations:
+
+- `FUSE_OPEN`: `open(2)` for files.
+- `FUSE_RELEASE`: `close(2)` for files.
+- `FUSE_FSYNC`
+- `FUSE_FALLOCATE`
+- `FUSE_COPY_FILE_RANGE`
+- `FUSE_SETUPMAPPING`: Creates a memory map on a file for `mmap(2)`.
+- `FUSE_REMOVEMAPPING`: Removes a memory map for `munmap(2)`.
+
+File locking:
+
+- `FUSE_GETLK`
+- `FUSE_SETLK`
+- `FUSE_SETLKW`
+
+File descriptor operations:
+
+- `FUSE_IOCTL`
+- `FUSE_POLL`
+- `FUSE_LSEEK`
+
+Filesystem operations:
+
+- `FUSE_STATFS`
+
+#### Permissions
+
+- `FUSE_ACCESS` is used to check if a node is accessible, as part of many
+ syscall implementations. Maps to `vfs.FilesystemImpl.AccessAt` in the
+ sentry.
+
+#### I/O Operations
+
+These ops are used to read and write file pages. They're used to implement
+both ordinary I/O syscalls like `read(2)` and `write(2)`, and memory-mapped
+I/O via `mmap(2)`.
+
+- `FUSE_READ`
+- `FUSE_WRITE`
+
+#### Miscellaneous
+
+- `FUSE_FLUSH`: Used by the client to indicate when a file descriptor is
+ closed. Distinct from `FUSE_FSYNC`, which corresponds to an `fsync(2)`
+ syscall from the user. Maps to `vfs.FileDescriptorImpl.Release` in the
+ sentry.
+- `FUSE_BMAP`: Old address space API for block defrag. Probably not needed.
+- `FUSE_NOTIFY_REPLY`: [TODO: what does this do?]
+
+# References
+
+- [fuse(4) Linux manual page](https://www.man7.org/linux/man-pages/man4/fuse.4.html)
+- [Linux kernel FUSE documentation](https://www.kernel.org/doc/html/latest/filesystems/fuse.html)
+- [The reference implementation of the Linux FUSE (Filesystem in Userspace)
+ interface](https://github.com/libfuse/libfuse)
+- [The kernel interface of FUSE](https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/include/uapi/linux/fuse.h)
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index a016c896e..51d7368a1 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -640,7 +640,7 @@ func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset,
// WriteOut implements fs.InodeOperations.WriteOut.
func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
- if !i.session().cachePolicy.cacheUAttrs(inode) {
+ if inode.MountSource.Flags.ReadOnly || !i.session().cachePolicy.cacheUAttrs(inode) {
return nil
}
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index aabce6cc9..d41d23a43 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/context",
"//pkg/fd",
"//pkg/fdnotifier",
+ "//pkg/iovec",
"//pkg/log",
"//pkg/refs",
"//pkg/safemem",
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 62f1246aa..fbfba1b58 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -368,6 +368,9 @@ func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset,
// WriteOut implements fs.InodeOperations.WriteOut.
func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+ if inode.MountSource.Flags.ReadOnly {
+ return nil
+ }
// Have we been using host kernel metadata caches?
if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) {
// Then the metadata is already up to date on the host.
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index b6e94583e..cfb089e43 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/fdnotifier"
- "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/socket/control"
@@ -39,11 +38,6 @@ import (
// LINT.IfChange
-// maxSendBufferSize is the maximum host send buffer size allowed for endpoint.
-//
-// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max).
-const maxSendBufferSize = 8 << 20
-
// ConnectedEndpoint is a host FD backed implementation of
// transport.ConnectedEndpoint and transport.Receiver.
//
@@ -103,10 +97,6 @@ func (c *ConnectedEndpoint) init() *syserr.Error {
if err != nil {
return syserr.FromError(err)
}
- if sndbuf > maxSendBufferSize {
- log.Warningf("Socket send buffer too large: %d", sndbuf)
- return syserr.ErrInvalidEndpointState
- }
c.stype = linux.SockType(stype)
c.sndbuf = int64(sndbuf)
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
index 5c18dbd5e..905afb50d 100644
--- a/pkg/sentry/fs/host/socket_iovec.go
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -17,15 +17,12 @@ package host
import (
"syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/syserror"
)
// LINT.IfChange
-// maxIovs is the maximum number of iovecs to pass to the host.
-var maxIovs = linux.UIO_MAXIOV
-
// copyToMulti copies as many bytes from src to dst as possible.
func copyToMulti(dst [][]byte, src []byte) {
for _, d := range dst {
@@ -76,7 +73,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec
}
}
- if iovsRequired > maxIovs {
+ if iovsRequired > iovec.MaxIovs {
// The kernel will reject our call if we pass this many iovs.
// Use a single intermediate buffer instead.
b := make([]byte, stopLen)
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index cb91355ab..82a02fcb2 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -308,9 +308,9 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e
task := kernel.TaskFromContext(ctx)
if task == nil {
// No task? Linux does not have an analog for this case, but
- // tty_check_change is more of a blacklist of cases than a
- // whitelist, and is surprisingly permissive. Allowing the
- // change seems most appropriate.
+ // tty_check_change only blocks specific cases and is
+ // surprisingly permissive. Allowing the change seems
+ // appropriate.
return nil
}
diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go
index 926538d90..8a5d9c7eb 100644
--- a/pkg/sentry/fs/lock/lock.go
+++ b/pkg/sentry/fs/lock/lock.go
@@ -62,7 +62,7 @@ import (
type LockType int
// UniqueID is a unique identifier of the holder of a regional file lock.
-type UniqueID uint64
+type UniqueID interface{}
const (
// ReadLock describes a POSIX regional file lock to be taken
@@ -98,12 +98,7 @@ type Lock struct {
// If len(Readers) > 0 then HasWriter must be false.
Readers map[UniqueID]bool
- // HasWriter indicates that this is a write lock held by a single
- // UniqueID.
- HasWriter bool
-
- // Writer is only valid if HasWriter is true. It identifies a
- // single write lock holder.
+ // Writer holds the writer unique ID. It's nil if there are no writers.
Writer UniqueID
}
@@ -186,7 +181,6 @@ func makeLock(uid UniqueID, t LockType) Lock {
case ReadLock:
value.Readers[uid] = true
case WriteLock:
- value.HasWriter = true
value.Writer = uid
default:
panic(fmt.Sprintf("makeLock: invalid lock type %d", t))
@@ -196,10 +190,7 @@ func makeLock(uid UniqueID, t LockType) Lock {
// isHeld returns true if uid is a holder of Lock.
func (l Lock) isHeld(uid UniqueID) bool {
- if l.HasWriter && l.Writer == uid {
- return true
- }
- return l.Readers[uid]
+ return l.Writer == uid || l.Readers[uid]
}
// lock sets uid as a holder of a typed lock on Lock.
@@ -214,20 +205,20 @@ func (l *Lock) lock(uid UniqueID, t LockType) {
}
// We cannot downgrade a write lock to a read lock unless the
// uid is the same.
- if l.HasWriter {
+ if l.Writer != nil {
if l.Writer != uid {
panic(fmt.Sprintf("lock: cannot downgrade write lock to read lock for uid %d, writer is %d", uid, l.Writer))
}
// Ensure that there is only one reader if upgrading.
l.Readers = make(map[UniqueID]bool)
// Ensure that there is no longer a writer.
- l.HasWriter = false
+ l.Writer = nil
}
l.Readers[uid] = true
return
case WriteLock:
// If we are already the writer, then this is a no-op.
- if l.HasWriter && l.Writer == uid {
+ if l.Writer == uid {
return
}
// We can only upgrade a read lock to a write lock if there
@@ -243,7 +234,6 @@ func (l *Lock) lock(uid UniqueID, t LockType) {
}
// Ensure that there is only a writer.
l.Readers = make(map[UniqueID]bool)
- l.HasWriter = true
l.Writer = uid
default:
panic(fmt.Sprintf("lock: invalid lock type %d", t))
@@ -277,9 +267,8 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
switch t {
case ReadLock:
return l.lockable(r, func(value Lock) bool {
- // If there is no writer, there's no problem adding
- // another reader.
- if !value.HasWriter {
+ // If there is no writer, there's no problem adding another reader.
+ if value.Writer == nil {
return true
}
// If there is a writer, then it must be the same uid
@@ -289,10 +278,9 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
case WriteLock:
return l.lockable(r, func(value Lock) bool {
// If there are only readers.
- if !value.HasWriter {
- // Then this uid can only take a write lock if
- // this is a private upgrade, meaning that the
- // only reader is uid.
+ if value.Writer == nil {
+ // Then this uid can only take a write lock if this is a private
+ // upgrade, meaning that the only reader is uid.
return len(value.Readers) == 1 && value.Readers[uid]
}
// If the uid is already a writer on this region, then
@@ -304,7 +292,8 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
}
}
-// lock returns true if uid took a lock of type t on the entire range of LockRange.
+// lock returns true if uid took a lock of type t on the entire range of
+// LockRange.
//
// Preconditions: r.Start <= r.End (will panic otherwise).
func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
@@ -339,7 +328,7 @@ func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
seg, _ = l.SplitUnchecked(seg, r.End)
}
- // Set the lock on the segment. This is guaranteed to
+ // Set the lock on the segment. This is guaranteed to
// always be safe, given canLock above.
value := seg.ValuePtr()
value.lock(uid, t)
@@ -386,7 +375,7 @@ func (l *LockSet) unlock(uid UniqueID, r LockRange) {
value := seg.Value()
var remove bool
- if value.HasWriter && value.Writer == uid {
+ if value.Writer == uid {
// If we are unlocking a writer, then since there can
// only ever be one writer and no readers, then this
// lock should always be removed from the set.
diff --git a/pkg/sentry/fs/lock/lock_set_functions.go b/pkg/sentry/fs/lock/lock_set_functions.go
index 8a3ace0c1..50a16e662 100644
--- a/pkg/sentry/fs/lock/lock_set_functions.go
+++ b/pkg/sentry/fs/lock/lock_set_functions.go
@@ -44,14 +44,9 @@ func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock)
return Lock{}, false
}
}
- if val1.HasWriter != val2.HasWriter {
+ if val1.Writer != val2.Writer {
return Lock{}, false
}
- if val1.HasWriter {
- if val1.Writer != val2.Writer {
- return Lock{}, false
- }
- }
return val1, true
}
@@ -62,7 +57,6 @@ func (lockSetFunctions) Split(r LockRange, val Lock, split uint64) (Lock, Lock)
for k, v := range val.Readers {
val0.Readers[k] = v
}
- val0.HasWriter = val.HasWriter
val0.Writer = val.Writer
return val, val0
diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go
index ba002aeb7..fad90984b 100644
--- a/pkg/sentry/fs/lock/lock_test.go
+++ b/pkg/sentry/fs/lock/lock_test.go
@@ -42,9 +42,6 @@ func equals(e0, e1 []entry) bool {
if !reflect.DeepEqual(e0[i].LockRange, e1[i].LockRange) {
return false
}
- if e0[i].Lock.HasWriter != e1[i].Lock.HasWriter {
- return false
- }
if e0[i].Lock.Writer != e1[i].Lock.Writer {
return false
}
@@ -105,7 +102,7 @@ func TestCanLock(t *testing.T) {
LockRange: LockRange{2048, 3072},
},
{
- Lock: Lock{HasWriter: true, Writer: 1},
+ Lock: Lock{Writer: 1},
LockRange: LockRange{3072, 4096},
},
})
@@ -241,7 +238,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -254,7 +251,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -273,7 +270,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 4096},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -301,7 +298,7 @@ func TestSetLock(t *testing.T) {
// 0 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 4096},
},
{
@@ -318,7 +315,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -550,7 +547,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 4096},
},
{
@@ -594,7 +591,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 3072},
},
{
@@ -633,7 +630,7 @@ func TestSetLock(t *testing.T) {
// 0 1024 2048 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -663,11 +660,11 @@ func TestSetLock(t *testing.T) {
// 0 1024 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, LockEOF},
},
},
@@ -675,28 +672,30 @@ func TestSetLock(t *testing.T) {
}
for _, test := range tests {
- l := fill(test.before)
+ t.Run(test.name, func(t *testing.T) {
+ l := fill(test.before)
- r := LockRange{Start: test.start, End: test.end}
- success := l.lock(test.uid, test.lockType, r)
- var got []entry
- for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- got = append(got, entry{
- Lock: seg.Value(),
- LockRange: seg.Range(),
- })
- }
+ r := LockRange{Start: test.start, End: test.end}
+ success := l.lock(test.uid, test.lockType, r)
+ var got []entry
+ for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ got = append(got, entry{
+ Lock: seg.Value(),
+ LockRange: seg.Range(),
+ })
+ }
- if success != test.success {
- t.Errorf("%s: setlock(%v, %+v, %d, %d) got success %v, want %v", test.name, test.before, r, test.uid, test.lockType, success, test.success)
- continue
- }
+ if success != test.success {
+ t.Errorf("setlock(%v, %+v, %d, %d) got success %v, want %v", test.before, r, test.uid, test.lockType, success, test.success)
+ return
+ }
- if success {
- if !equals(got, test.after) {
- t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after)
+ if success {
+ if !equals(got, test.after) {
+ t.Errorf("got set %+v, want %+v", got, test.after)
+ }
}
- }
+ })
}
}
@@ -782,7 +781,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -824,7 +823,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -837,7 +836,7 @@ func TestUnlock(t *testing.T) {
// 0 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -876,7 +875,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -889,7 +888,7 @@ func TestUnlock(t *testing.T) {
// 0 4096
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 4096},
},
},
@@ -906,7 +905,7 @@ func TestUnlock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 4096},
},
{
@@ -974,7 +973,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -991,7 +990,7 @@ func TestUnlock(t *testing.T) {
// 0 8 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 8},
},
{
@@ -1008,7 +1007,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -1025,7 +1024,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 8192 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -1041,19 +1040,21 @@ func TestUnlock(t *testing.T) {
}
for _, test := range tests {
- l := fill(test.before)
+ t.Run(test.name, func(t *testing.T) {
+ l := fill(test.before)
- r := LockRange{Start: test.start, End: test.end}
- l.unlock(test.uid, r)
- var got []entry
- for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- got = append(got, entry{
- Lock: seg.Value(),
- LockRange: seg.Range(),
- })
- }
- if !equals(got, test.after) {
- t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after)
- }
+ r := LockRange{Start: test.start, End: test.end}
+ l.unlock(test.uid, r)
+ var got []entry
+ for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ got = append(got, entry{
+ Lock: seg.Value(),
+ LockRange: seg.Range(),
+ })
+ }
+ if !equals(got, test.after) {
+ t.Errorf("got set %+v, want %+v", got, test.after)
+ }
+ })
}
}
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index b414ddaee..3f2bd0e87 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -17,13 +17,9 @@ package fs
import (
"fmt"
"math"
- "path"
- "strings"
"syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sync"
@@ -625,71 +621,3 @@ func (mns *MountNamespace) SyncAll(ctx context.Context) {
defer mns.mu.Unlock()
mns.root.SyncAll(ctx)
}
-
-// ResolveExecutablePath resolves the given executable name given a set of
-// paths that might contain it.
-func (mns *MountNamespace) ResolveExecutablePath(ctx context.Context, wd, name string, paths []string) (string, error) {
- // Absolute paths can be used directly.
- if path.IsAbs(name) {
- return name, nil
- }
-
- // Paths with '/' in them should be joined to the working directory, or
- // to the root if working directory is not set.
- if strings.IndexByte(name, '/') > 0 {
- if wd == "" {
- wd = "/"
- }
- if !path.IsAbs(wd) {
- return "", fmt.Errorf("working directory %q must be absolute", wd)
- }
- return path.Join(wd, name), nil
- }
-
- // Otherwise, We must lookup the name in the paths, starting from the
- // calling context's root directory.
- root := RootFromContext(ctx)
- if root == nil {
- // Caller has no root. Don't bother traversing anything.
- return "", syserror.ENOENT
- }
- defer root.DecRef()
- for _, p := range paths {
- binPath := path.Join(p, name)
- traversals := uint(linux.MaxSymlinkTraversals)
- d, err := mns.FindInode(ctx, root, nil, binPath, &traversals)
- if err == syserror.ENOENT || err == syserror.EACCES {
- // Didn't find it here.
- continue
- }
- if err != nil {
- return "", err
- }
- defer d.DecRef()
-
- // Check that it is a regular file.
- if !IsRegular(d.Inode.StableAttr) {
- continue
- }
-
- // Check whether we can read and execute the found file.
- if err := d.Inode.CheckPermission(ctx, PermMask{Read: true, Execute: true}); err != nil {
- log.Infof("Found executable at %q, but user cannot execute it: %v", binPath, err)
- continue
- }
- return path.Join("/", p, name), nil
- }
- return "", syserror.ENOENT
-}
-
-// GetPath returns the PATH as a slice of strings given the environment
-// variables.
-func GetPath(env []string) []string {
- const prefix = "PATH="
- for _, e := range env {
- if strings.HasPrefix(e, prefix) {
- return strings.Split(strings.TrimPrefix(e, prefix), ":")
- }
- }
- return nil
-}
diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD
index f37f979f1..66e949c95 100644
--- a/pkg/sentry/fs/user/BUILD
+++ b/pkg/sentry/fs/user/BUILD
@@ -4,15 +4,21 @@ package(licenses = ["notice"])
go_library(
name = "user",
- srcs = ["user.go"],
+ srcs = [
+ "path.go",
+ "user.go",
+ ],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/log",
"//pkg/sentry/fs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
+ "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go
new file mode 100644
index 000000000..397e96045
--- /dev/null
+++ b/pkg/sentry/fs/user/path.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package user
+
+import (
+ "fmt"
+ "path"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// ResolveExecutablePath resolves the given executable name given the working
+// dir and environment.
+func ResolveExecutablePath(ctx context.Context, args *kernel.CreateProcessArgs) (string, error) {
+ name := args.Filename
+ if len(name) == 0 {
+ if len(args.Argv) == 0 {
+ return "", fmt.Errorf("no filename or command provided")
+ }
+ name = args.Argv[0]
+ }
+
+ // Absolute paths can be used directly.
+ if path.IsAbs(name) {
+ return name, nil
+ }
+
+ // Paths with '/' in them should be joined to the working directory, or
+ // to the root if working directory is not set.
+ if strings.IndexByte(name, '/') > 0 {
+ wd := args.WorkingDirectory
+ if wd == "" {
+ wd = "/"
+ }
+ if !path.IsAbs(wd) {
+ return "", fmt.Errorf("working directory %q must be absolute", wd)
+ }
+ return path.Join(wd, name), nil
+ }
+
+	// Otherwise, we must look up the name in the paths.
+ paths := getPath(args.Envv)
+ if kernel.VFS2Enabled {
+ f, err := resolveVFS2(ctx, args.Credentials, args.MountNamespaceVFS2, paths, name)
+ if err != nil {
+ return "", fmt.Errorf("error finding executable %q in PATH %v: %v", name, paths, err)
+ }
+ return f, nil
+ }
+
+ f, err := resolve(ctx, args.MountNamespace, paths, name)
+ if err != nil {
+ return "", fmt.Errorf("error finding executable %q in PATH %v: %v", name, paths, err)
+ }
+ return f, nil
+}
+
+func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name string) (string, error) {
+ root := fs.RootFromContext(ctx)
+ if root == nil {
+ // Caller has no root. Don't bother traversing anything.
+ return "", syserror.ENOENT
+ }
+ defer root.DecRef()
+ for _, p := range paths {
+ if !path.IsAbs(p) {
+ // Relative paths aren't safe, no one should be using them.
+ log.Warningf("Skipping relative path %q in $PATH", p)
+ continue
+ }
+
+ binPath := path.Join(p, name)
+ traversals := uint(linux.MaxSymlinkTraversals)
+ d, err := mns.FindInode(ctx, root, nil, binPath, &traversals)
+ if err == syserror.ENOENT || err == syserror.EACCES {
+ // Didn't find it here.
+ continue
+ }
+ if err != nil {
+ return "", err
+ }
+ defer d.DecRef()
+
+ // Check that it is a regular file.
+ if !fs.IsRegular(d.Inode.StableAttr) {
+ continue
+ }
+
+ // Check whether we can read and execute the found file.
+ if err := d.Inode.CheckPermission(ctx, fs.PermMask{Read: true, Execute: true}); err != nil {
+ log.Infof("Found executable at %q, but user cannot execute it: %v", binPath, err)
+ continue
+ }
+ return path.Join("/", p, name), nil
+ }
+
+ // Couldn't find it.
+ return "", syserror.ENOENT
+}
+
+func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) {
+ root := mns.Root()
+ defer root.DecRef()
+ for _, p := range paths {
+ if !path.IsAbs(p) {
+ // Relative paths aren't safe, no one should be using them.
+ log.Warningf("Skipping relative path %q in $PATH", p)
+ continue
+ }
+
+ binPath := path.Join(p, name)
+ pop := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(binPath),
+ FollowFinalSymlink: true,
+ }
+ opts := &vfs.OpenOptions{
+ FileExec: true,
+ Flags: linux.O_RDONLY,
+ }
+ dentry, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, pop, opts)
+ if err == syserror.ENOENT || err == syserror.EACCES {
+ // Didn't find it here.
+ continue
+ }
+ if err != nil {
+ return "", err
+ }
+ dentry.DecRef()
+
+ return binPath, nil
+ }
+
+ // Couldn't find it.
+ return "", syserror.ENOENT
+}
+
+// getPath returns the PATH as a slice of strings given the environment
+// variables.
+func getPath(env []string) []string {
+ const prefix = "PATH="
+ for _, e := range env {
+ if strings.HasPrefix(e, prefix) {
+ return strings.Split(strings.TrimPrefix(e, prefix), ":")
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/fs/user/user.go b/pkg/sentry/fs/user/user.go
index fe7f67c00..f4d525523 100644
--- a/pkg/sentry/fs/user/user.go
+++ b/pkg/sentry/fs/user/user.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package user contains methods for resolving filesystem paths based on the
+// user and their environment.
package user
import (
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
index 585764223..93512c9b6 100644
--- a/pkg/sentry/fsimpl/devpts/BUILD
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -18,6 +18,7 @@ go_library(
"//pkg/context",
"//pkg/safemem",
"//pkg/sentry/arch",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index c03c65445..e6fda2b4f 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -116,6 +116,8 @@ type rootInode struct {
kernfs.InodeNotSymlink
kernfs.OrderedChildren
+ locks vfs.FileLocks
+
// Keep a reference to this inode's dentry.
dentry kernfs.Dentry
@@ -183,7 +185,7 @@ func (i *rootInode) masterClose(t *Terminal) {
// Open implements kernfs.Inode.Open.
func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index 7a7ce5d81..1081fff52 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
@@ -34,6 +35,8 @@ type masterInode struct {
kernfs.InodeNotDirectory
kernfs.InodeNotSymlink
+ locks vfs.FileLocks
+
// Keep a reference to this inode's dentry.
dentry kernfs.Dentry
@@ -55,6 +58,7 @@ func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vf
inode: mi,
t: t,
}
+ fd.LockFD.Init(&mi.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
mi.DecRef()
return nil, err
@@ -63,8 +67,8 @@ func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vf
}
// Stat implements kernfs.Inode.Stat.
-func (mi *masterInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- statx, err := mi.InodeAttrs.Stat(vfsfs, opts)
+func (mi *masterInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ statx, err := mi.InodeAttrs.Stat(ctx, vfsfs, opts)
if err != nil {
return linux.Statx{}, err
}
@@ -85,6 +89,7 @@ func (mi *masterInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds
type masterFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
inode *masterInode
t *Terminal
@@ -181,7 +186,17 @@ func (mfd *masterFileDescription) SetStat(ctx context.Context, opts vfs.SetStatO
// Stat implements vfs.FileDescriptionImpl.Stat.
func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem()
- return mfd.inode.Stat(fs, opts)
+ return mfd.inode.Stat(ctx, fs, opts)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (mfd *masterFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return mfd.Locks().LockPOSIX(ctx, &mfd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (mfd *masterFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return mfd.Locks().UnlockPOSIX(ctx, &mfd.vfsfd, uid, start, length, whence)
}
// maybeEmitUnimplementedEvent emits unimplemented event if cmd is valid.
diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go
index 526cd406c..a91cae3ef 100644
--- a/pkg/sentry/fsimpl/devpts/slave.go
+++ b/pkg/sentry/fsimpl/devpts/slave.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -33,6 +34,8 @@ type slaveInode struct {
kernfs.InodeNotDirectory
kernfs.InodeNotSymlink
+ locks vfs.FileLocks
+
// Keep a reference to this inode's dentry.
dentry kernfs.Dentry
@@ -51,6 +54,7 @@ func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs
fd := &slaveFileDescription{
inode: si,
}
+ fd.LockFD.Init(&si.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
si.DecRef()
return nil, err
@@ -69,8 +73,8 @@ func (si *slaveInode) Valid(context.Context) bool {
}
// Stat implements kernfs.Inode.Stat.
-func (si *slaveInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- statx, err := si.InodeAttrs.Stat(vfsfs, opts)
+func (si *slaveInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ statx, err := si.InodeAttrs.Stat(ctx, vfsfs, opts)
if err != nil {
return linux.Statx{}, err
}
@@ -91,6 +95,7 @@ func (si *slaveInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds
type slaveFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
inode *slaveInode
}
@@ -127,7 +132,7 @@ func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequen
return sfd.inode.t.ld.outputQueueWrite(ctx, src)
}
-// Ioctl implements vfs.FileDescripionImpl.Ioctl.
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch cmd := args[1].Uint(); cmd {
case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
@@ -178,5 +183,15 @@ func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOp
// Stat implements vfs.FileDescriptionImpl.Stat.
func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem()
- return sfd.inode.Stat(fs, opts)
+ return sfd.inode.Stat(ctx, fs, opts)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (sfd *slaveFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return sfd.Locks().LockPOSIX(ctx, &sfd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (sfd *slaveFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return sfd.Locks().UnlockPOSIX(ctx, &sfd.vfsfd, uid, start, length, whence)
}
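
The same POSIX-lock plumbing recurs in the ext and fuse hunks below: the inode owns a single vfs.FileLocks, every file description embeds vfs.LockFD bound to that lock set at Open time, and LockPOSIX/UnlockPOSIX simply delegate. A minimal sketch of the pattern, with hypothetical names (exampleInode, exampleFD); the concrete types are the ones in the hunks above:

package example

import (
	"gvisor.dev/gvisor/pkg/context"
	fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

type exampleInode struct {
	locks vfs.FileLocks // one lock set shared by every FD opened on this inode
}

type exampleFD struct {
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl
	vfs.LockFD // supplies Locks()
}

func (i *exampleInode) newFD() *exampleFD {
	fd := &exampleFD{}
	fd.LockFD.Init(&i.locks) // bind the FD to the inode's shared lock set
	return fd
}

// LockPOSIX delegates to the shared lock set, exactly as the master and
// slave file descriptions do above.
func (fd *exampleFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
	return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
}

func (fd *exampleFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
	return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
}
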
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
index 142ee53b0..d0e06cdc0 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -136,6 +136,8 @@ func (a *Accessor) pathOperationAt(pathname string) *vfs.PathOperation {
// CreateDeviceFile creates a device special file at the given pathname in the
// devtmpfs instance accessed by the Accessor.
func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind vfs.DeviceKind, major, minor uint32, perms uint16) error {
+ actx := a.wrapContext(ctx)
+
mode := (linux.FileMode)(perms)
switch kind {
case vfs.BlockDevice:
@@ -145,12 +147,24 @@ func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind v
default:
panic(fmt.Sprintf("invalid vfs.DeviceKind: %v", kind))
}
+
+ // Create any parent directories. See
+ // devtmpfs.c:handle_create()=>path_create().
+ for it := fspath.Parse(pathname).Begin; it.NextOk(); it = it.Next() {
+ pop := a.pathOperationAt(it.String())
+ if err := a.vfsObj.MkdirAt(actx, a.creds, pop, &vfs.MkdirOptions{
+ Mode: 0755,
+ }); err != nil {
+ return fmt.Errorf("failed to create directory %q: %v", it.String(), err)
+ }
+ }
+
// NOTE: Linux's devtmpfs refuses to automatically delete files it didn't
// create, which it recognizes by storing a pointer to the kdevtmpfs struct
// thread in struct inode::i_private. Accessor doesn't yet support deletion
// of files at all, and probably won't as long as we don't need to support
// kernel modules, so this is moot for now.
- return a.vfsObj.MknodAt(a.wrapContext(ctx), a.creds, a.pathOperationAt(pathname), &vfs.MknodOptions{
+ return a.vfsObj.MknodAt(actx, a.creds, a.pathOperationAt(pathname), &vfs.MknodOptions{
Mode: mode,
DevMajor: major,
DevMinor: minor,
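
The loop added above mirrors Linux's devtmpfs handle_create()=>path_create(): every path component before the leaf is created as a 0755 directory before the final MknodAt. A standalone sketch of that walk, assuming plain string splitting is a fair stand-in for the fspath iterator (mkdirAt is a hypothetical callback; in the diff it is vfsObj.MkdirAt against a pathOperationAt):

package example

import (
	"fmt"
	"strings"
)

// mkdirParents creates every strict prefix of pathname as a directory, the
// way CreateDeviceFile above prepares "a/b" before mknod-ing "a/b/c".
func mkdirParents(pathname string, mkdirAt func(dir string) error) error {
	parts := strings.Split(pathname, "/")
	for i := 1; i < len(parts); i++ { // every component except the leaf
		dir := strings.Join(parts[:i], "/")
		if err := mkdirAt(dir); err != nil {
			return fmt.Errorf("failed to create directory %q: %v", dir, err)
		}
	}
	return nil
}
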
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go
index c573d7935..d12d78b84 100644
--- a/pkg/sentry/fsimpl/eventfd/eventfd.go
+++ b/pkg/sentry/fsimpl/eventfd/eventfd.go
@@ -37,6 +37,7 @@ type EventFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
// queue is used to notify interested parties when the event object
// becomes readable or writable.
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index ff861d0fe..abc610ef3 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -54,6 +54,7 @@ go_library(
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/ext/disklayout",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
@@ -95,7 +96,7 @@ go_test(
"//pkg/syserror",
"//pkg/test/testutil",
"//pkg/usermem",
- "@com_github_google_go-cmp//cmp:go_default_library",
- "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go
index a2d8c3ad6..8bb104ff0 100644
--- a/pkg/sentry/fsimpl/ext/block_map_file.go
+++ b/pkg/sentry/fsimpl/ext/block_map_file.go
@@ -58,15 +58,16 @@ var _ io.ReaderAt = (*blockMapFile)(nil)
// newBlockMapFile is the blockMapFile constructor. It initializes the file-to-
// physical-blocks map with (at most) the first 12 (direct) blocks.
-func newBlockMapFile(regFile regularFile) (*blockMapFile, error) {
- file := &blockMapFile{regFile: regFile}
+func newBlockMapFile(args inodeArgs) (*blockMapFile, error) {
+ file := &blockMapFile{}
file.regFile.impl = file
+ file.regFile.inode.init(args, &file.regFile)
for i := uint(0); i < 4; i++ {
- file.coverage[i] = getCoverage(regFile.inode.blkSize, i)
+ file.coverage[i] = getCoverage(file.regFile.inode.blkSize, i)
}
- blkMap := regFile.inode.diskInode.Data()
+ blkMap := file.regFile.inode.diskInode.Data()
binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks)
binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk)
binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk)
diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go
index 181727ef7..6fa84e7aa 100644
--- a/pkg/sentry/fsimpl/ext/block_map_test.go
+++ b/pkg/sentry/fsimpl/ext/block_map_test.go
@@ -85,20 +85,6 @@ func (n *blkNumGen) next() uint32 {
// the inode covers and that is written to disk.
func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
mockDisk := make([]byte, mockBMDiskSize)
- regFile := regularFile{
- inode: inode{
- fs: &filesystem{
- dev: bytes.NewReader(mockDisk),
- },
- diskInode: &disklayout.InodeNew{
- InodeOld: disklayout.InodeOld{
- SizeLo: getMockBMFileFize(),
- },
- },
- blkSize: uint64(mockBMBlkSize),
- },
- }
-
var fileData []byte
blkNums := newBlkNumGen()
var data []byte
@@ -125,9 +111,20 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk)
fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...)
- copy(regFile.inode.diskInode.Data(), data)
+ args := inodeArgs{
+ fs: &filesystem{
+ dev: bytes.NewReader(mockDisk),
+ },
+ diskInode: &disklayout.InodeNew{
+ InodeOld: disklayout.InodeOld{
+ SizeLo: getMockBMFileFize(),
+ },
+ },
+ blkSize: uint64(mockBMBlkSize),
+ }
+ copy(args.diskInode.Data(), data)
- mockFile, err := newBlockMapFile(regFile)
+ mockFile, err := newBlockMapFile(args)
if err != nil {
t.Fatalf("newBlockMapFile failed: %v", err)
}
diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go
index bfbd7c3d4..55902322a 100644
--- a/pkg/sentry/fsimpl/ext/dentry.go
+++ b/pkg/sentry/fsimpl/ext/dentry.go
@@ -60,3 +60,20 @@ func (d *dentry) DecRef() {
// inode.decRef().
d.inode.decRef()
}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+//
+// TODO(b/134676337): Implement inotify.
+func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {}
+
+// Watches implements vfs.DentryImpl.Watches.
+//
+// TODO(b/134676337): Implement inotify.
+func (d *dentry) Watches() *vfs.Watches {
+ return nil
+}
+
+// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
+//
+// TODO(b/134676337): Implement inotify.
+func (d *dentry) OnZeroWatches() {}
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 12b875c8f..357512c7e 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
@@ -54,16 +55,15 @@ type directory struct {
}
// newDirectory is the directory constructor.
-func newDirectory(inode inode, newDirent bool) (*directory, error) {
+func newDirectory(args inodeArgs, newDirent bool) (*directory, error) {
file := &directory{
- inode: inode,
childCache: make(map[string]*dentry),
childMap: make(map[string]*dirent),
}
- file.inode.impl = file
+ file.inode.init(args, file)
// Initialize childList by reading dirents from the underlying file.
- if inode.diskInode.Flags().Index {
+ if args.diskInode.Flags().Index {
// TODO(b/134676337): Support hash tree directories. Currently only the '.'
// and '..' entries are read in.
@@ -74,7 +74,7 @@ func newDirectory(inode inode, newDirent bool) (*directory, error) {
// The dirents are organized in a linear array in the file data.
// Extract the file data and decode the dirents.
- regFile, err := newRegularFile(inode)
+ regFile, err := newRegularFile(args)
if err != nil {
return nil, err
}
@@ -82,7 +82,7 @@ func newDirectory(inode inode, newDirent bool) (*directory, error) {
// buf is used as scratch space for reading in dirents from disk and
// unmarshalling them into dirent structs.
buf := make([]byte, disklayout.DirentSize)
- size := inode.diskInode.Size()
+ size := args.diskInode.Size()
for off, inc := uint64(0), uint64(0); off < size; off += inc {
toRead := size - off
if toRead > disklayout.DirentSize {
@@ -306,3 +306,13 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
fd.off = offset
return offset, nil
}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *directoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *directoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index 11dcc0346..c36225a7c 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -38,9 +38,10 @@ var _ io.ReaderAt = (*extentFile)(nil)
// newExtentFile is the extent file constructor. It reads the entire extent
// tree into memory.
// TODO(b/134676337): Build extent tree on demand to reduce memory usage.
-func newExtentFile(regFile regularFile) (*extentFile, error) {
- file := &extentFile{regFile: regFile}
+func newExtentFile(args inodeArgs) (*extentFile, error) {
+ file := &extentFile{}
file.regFile.impl = file
+ file.regFile.inode.init(args, &file.regFile)
err := file.buildExtTree()
if err != nil {
return nil, err
diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go
index a2382daa3..cd10d46ee 100644
--- a/pkg/sentry/fsimpl/ext/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/extent_test.go
@@ -177,21 +177,19 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []
t.Helper()
mockDisk := make([]byte, mockExtentBlkSize*10)
- mockExtentFile := &extentFile{
- regFile: regularFile{
- inode: inode{
- fs: &filesystem{
- dev: bytes.NewReader(mockDisk),
- },
- diskInode: &disklayout.InodeNew{
- InodeOld: disklayout.InodeOld{
- SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root),
- },
- },
- blkSize: mockExtentBlkSize,
+ mockExtentFile := &extentFile{}
+ args := inodeArgs{
+ fs: &filesystem{
+ dev: bytes.NewReader(mockDisk),
+ },
+ diskInode: &disklayout.InodeNew{
+ InodeOld: disklayout.InodeOld{
+ SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root),
},
},
+ blkSize: mockExtentBlkSize,
}
+ mockExtentFile.regFile.inode.init(args, &mockExtentFile.regFile)
fileData := writeTree(&mockExtentFile.regFile.inode, mockDisk, node0, mockExtentBlkSize)
diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go
index 92f7da40d..90b086468 100644
--- a/pkg/sentry/fsimpl/ext/file_description.go
+++ b/pkg/sentry/fsimpl/ext/file_description.go
@@ -26,6 +26,7 @@ import (
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
}
func (fd *fileDescription) filesystem() *filesystem {
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 485f86f4b..30636cf66 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -54,6 +54,8 @@ type inode struct {
// diskInode gives us access to the inode struct on disk. Immutable.
diskInode disklayout.Inode
+ locks vfs.FileLocks
+
// This is immutable. The first field of the implementations must have inode
// as the first field to ensure temporality.
impl interface{}
@@ -115,7 +117,7 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
}
// Build the inode based on its type.
- inode := inode{
+ args := inodeArgs{
fs: fs,
inodeNum: inodeNum,
blkSize: blkSize,
@@ -124,19 +126,19 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
switch diskInode.Mode().FileType() {
case linux.ModeSymlink:
- f, err := newSymlink(inode)
+ f, err := newSymlink(args)
if err != nil {
return nil, err
}
return &f.inode, nil
case linux.ModeRegular:
- f, err := newRegularFile(inode)
+ f, err := newRegularFile(args)
if err != nil {
return nil, err
}
return &f.inode, nil
case linux.ModeDirectory:
- f, err := newDirectory(inode, fs.sb.IncompatibleFeatures().DirentFileType)
+ f, err := newDirectory(args, fs.sb.IncompatibleFeatures().DirentFileType)
if err != nil {
return nil, err
}
@@ -147,6 +149,21 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
}
}
+type inodeArgs struct {
+ fs *filesystem
+ inodeNum uint32
+ blkSize uint64
+ diskInode disklayout.Inode
+}
+
+func (in *inode) init(args inodeArgs, impl interface{}) {
+ in.fs = args.fs
+ in.inodeNum = args.inodeNum
+ in.blkSize = args.blkSize
+ in.diskInode = args.diskInode
+ in.impl = impl
+}
+
// open creates and returns a file description for the dentry passed in.
func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(opts)
@@ -157,6 +174,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt
switch in.impl.(type) {
case *regularFile:
var fd regularFileFD
+ fd.LockFD.Init(&in.locks)
if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
@@ -168,6 +186,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt
return nil, syserror.EISDIR
}
var fd directoryFD
+ fd.LockFD.Init(&in.locks)
if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
@@ -178,6 +197,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt
return nil, syserror.ELOOP
}
var fd symlinkFD
+ fd.LockFD.Init(&in.locks)
fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
default:
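
The inodeArgs/init refactor running through these ext hunks threads a plain argument bundle into each constructor so the inode is initialized in place inside its owning structure instead of being built and copied; that matters now that inode embeds a vfs.FileLocks, since values containing sync primitives must not be copied. A minimal sketch of the shape, with hypothetical names:

package example

type args struct{ blkSize uint64 }

type inode struct {
	blkSize uint64
	impl    interface{}
}

func (in *inode) init(a args, impl interface{}) {
	in.blkSize = a.blkSize
	in.impl = impl
}

type regular struct{ in inode }

// newRegular allocates the embedding struct first and initializes the inner
// inode in place; no inode value is ever copied.
func newRegular(a args) *regular {
	f := &regular{}
	f.in.init(a, f)
	return f
}
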
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index 30135ddb0..66d14bb95 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
@@ -43,28 +44,19 @@ type regularFile struct {
// newRegularFile is the regularFile constructor. It figures out what kind of
// file this is and initializes the fileReader.
-func newRegularFile(inode inode) (*regularFile, error) {
- regFile := regularFile{
- inode: inode,
- }
-
- inodeFlags := inode.diskInode.Flags()
-
- if inodeFlags.Extents {
- file, err := newExtentFile(regFile)
+func newRegularFile(args inodeArgs) (*regularFile, error) {
+ if args.diskInode.Flags().Extents {
+ file, err := newExtentFile(args)
if err != nil {
return nil, err
}
-
- file.regFile.inode.impl = &file.regFile
return &file.regFile, nil
}
- file, err := newBlockMapFile(regFile)
+ file, err := newBlockMapFile(args)
if err != nil {
return nil, err
}
- file.regFile.inode.impl = &file.regFile
return &file.regFile, nil
}
@@ -77,6 +69,7 @@ func (in *inode) isRegular() bool {
// vfs.FileDescriptionImpl.
type regularFileFD struct {
fileDescription
+ vfs.LockFD
// off is the file offset. off is accessed using atomic memory operations.
off int64
@@ -157,3 +150,13 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
// TODO(b/134676337): Implement mmap(2).
return syserror.ENODEV
}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go
index 1447a4dc1..62efd4095 100644
--- a/pkg/sentry/fsimpl/ext/symlink.go
+++ b/pkg/sentry/fsimpl/ext/symlink.go
@@ -30,18 +30,17 @@ type symlink struct {
// newSymlink is the symlink constructor. It reads out the symlink target from
// the inode (however it might have been stored).
-func newSymlink(inode inode) (*symlink, error) {
- var file *symlink
+func newSymlink(args inodeArgs) (*symlink, error) {
var link []byte
// If the symlink target is less than 60 bytes, it is stored in inode.Data().
// Otherwise either extents or block maps will be used to store the link.
- size := inode.diskInode.Size()
+ size := args.diskInode.Size()
if size < 60 {
- link = inode.diskInode.Data()[:size]
+ link = args.diskInode.Data()[:size]
} else {
// Create a regular file out of this inode and read out the target.
- regFile, err := newRegularFile(inode)
+ regFile, err := newRegularFile(args)
if err != nil {
return nil, err
}
@@ -52,8 +51,8 @@ func newSymlink(inode inode) (*symlink, error) {
}
}
- file = &symlink{inode: inode, target: string(link)}
- file.inode.impl = file
+ file := &symlink{target: string(link)}
+ file.inode.init(args, file)
return file, nil
}
@@ -67,6 +66,7 @@ func (in *inode) isSymlink() bool {
// O_PATH. For this reason most of the functions return EBADF.
type symlinkFD struct {
fileDescription
+ vfs.NoLockFD
}
// Compiles only if symlinkFD implements vfs.FileDescriptionImpl.
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
new file mode 100644
index 000000000..999111deb
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -0,0 +1,63 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "request_list",
+ out = "request_list.go",
+ package = "fuse",
+ prefix = "request",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Request",
+ "Linker": "*Request",
+ },
+)
+
+go_library(
+ name = "fuse",
+ srcs = [
+ "connection.go",
+ "dev.go",
+ "fusefs.go",
+ "init.go",
+ "register.go",
+ "request_list.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "fuse_test",
+ size = "small",
+ srcs = ["dev_test.go"],
+ library = ":fuse",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go
new file mode 100644
index 000000000..6df2728ab
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/connection.go
@@ -0,0 +1,437 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// maxActiveRequestsDefault is the default setting controlling the upper bound
+// on the number of active requests at any given time.
+const maxActiveRequestsDefault = 10000
+
+// Ordinary requests have even IDs, while interrupt IDs are odd.
+// Used to increment the unique ID for each FUSE request.
+var reqIDStep uint64 = 2
+
+const (
+ // fuseDefaultMaxBackground is the default value for MaxBackground.
+ fuseDefaultMaxBackground = 12
+
+ // fuseDefaultCongestionThreshold is the default value for CongestionThreshold,
+ // and is 75% of the default maximum of MaxBackground.
+ fuseDefaultCongestionThreshold = (fuseDefaultMaxBackground * 3 / 4)
+
+ // fuseDefaultMaxPagesPerReq is the default value for MaxPagesPerReq.
+ fuseDefaultMaxPagesPerReq = 32
+)
+
+// Request represents a FUSE operation request that hasn't been sent to the
+// server yet.
+//
+// +stateify savable
+type Request struct {
+ requestEntry
+
+ id linux.FUSEOpID
+ hdr *linux.FUSEHeaderIn
+ data []byte
+}
+
+// Response represents an actual response from the server, including the
+// response payload.
+//
+// +stateify savable
+type Response struct {
+ opcode linux.FUSEOpcode
+ hdr linux.FUSEHeaderOut
+ data []byte
+}
+
+// connection is the struct by which the sentry communicates with the FUSE server daemon.
+type connection struct {
+ fd *DeviceFD
+
+ // The following FUSE_INIT flags are currently unsupported by this implementation:
+ // - FUSE_ATOMIC_O_TRUNC: requires open(..., O_TRUNC)
+ // - FUSE_EXPORT_SUPPORT
+ // - FUSE_HANDLE_KILLPRIV
+ // - FUSE_POSIX_LOCKS: requires POSIX locks
+ // - FUSE_FLOCK_LOCKS: requires POSIX locks
+ // - FUSE_AUTO_INVAL_DATA: requires page caching eviction
+ // - FUSE_EXPLICIT_INVAL_DATA: requires page caching eviction
+ // - FUSE_DO_READDIRPLUS/FUSE_READDIRPLUS_AUTO: requires FUSE_READDIRPLUS implementation
+ // - FUSE_ASYNC_DIO
+ // - FUSE_POSIX_ACL: affects defaultPermissions, posixACL, xattr handler
+
+ // initialized after receiving FUSE_INIT reply.
+ // Until it's set, suspend sending FUSE requests.
+ // Use SetInitialized() and Initialized() for atomic access.
+ initialized int32
+
+ // initializedChan is used to block requests before initialization.
+ initializedChan chan struct{}
+
+ // blocked when there are too many outstanding background requests (NumBackground == MaxBackground).
+ // TODO(gvisor.dev/issue/3185): update the numBackground accordingly; use a channel to block.
+ blocked bool
+
+ // connected (connection established) when a new FUSE file system is created.
+ // Set to false when:
+ // umount,
+ // connection abort,
+ // device release.
+ connected bool
+
+ // aborted via sysfs.
+ // TODO(gvisor.dev/issue/3185): abort all queued requests.
+ aborted bool
+
+ // connInitError if FUSE_INIT encountered error (major version mismatch).
+ // Only set in INIT.
+ connInitError bool
+
+ // connInitSuccess if FUSE_INIT is successful.
+ // Only set in INIT.
+ // Used for destroy.
+ connInitSuccess bool
+
+ // TODO(gvisor.dev/issue/3185): All the queue logic is work in progress.
+
+ // numBackground is the number of requests in the background.
+ numBackground uint16
+
+ // congestionThreshold for NumBackground.
+ // Negotiated in FUSE_INIT.
+ congestionThreshold uint16
+
+ // maxBackground is the maximum number of NumBackground.
+ // Block connection when it is reached.
+ // Negotiated in FUSE_INIT.
+ maxBackground uint16
+
+ // numActiveBackground is the number of requests in the background that have been marked as active.
+ numActiveBackground uint16
+
+ // numWaiting is the number of requests waiting for completion.
+ numWaiting uint32
+
+ // TODO(gvisor.dev/issue/3185): BgQueue
+ // some queue for background queued requests.
+
+ // bgLock protects:
+ // MaxBackground, CongestionThreshold, NumBackground,
+ // NumActiveBackground, BgQueue, Blocked.
+ bgLock sync.Mutex
+
+ // maxRead is the maximum size of a read buffer in bytes.
+ maxRead uint32
+
+ // maxWrite is the maximum size of a write buffer in bytes.
+ // Negotiated in FUSE_INIT.
+ maxWrite uint32
+
+ // maxPages is the maximum number of pages for a single request to use.
+ // Negotiated in FUSE_INIT.
+ maxPages uint16
+
+ // minor version of the FUSE protocol.
+ // Negotiated and only set in INIT.
+ minor uint32
+
+ // asyncRead if read pages asynchronously.
+ // Negotiated and only set in INIT.
+ asyncRead bool
+
+ // abortErr is true if the kernel needs to return a unique read error after abort.
+ // Negotiated and only set in INIT.
+ abortErr bool
+
+ // writebackCache is true for write-back cache policy,
+ // false for write-through policy.
+ // Negotiated and only set in INIT.
+ writebackCache bool
+
+ // cacheSymlinks if filesystem needs to cache READLINK responses in page cache.
+ // Negotiated and only set in INIT.
+ cacheSymlinks bool
+
+ // bigWrites if doing multi-page cached writes.
+ // Negotiated and only set in INIT.
+ bigWrites bool
+
+ // dontMask if filesystem does not apply umask to creation modes.
+ // Negotiated in INIT.
+ dontMask bool
+}
+
+// newFUSEConnection creates a FUSE connection to fd.
+func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, maxInFlightRequests uint64) (*connection, error) {
+ // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to
+ // mount a FUSE filesystem.
+ fuseFD := fd.Impl().(*DeviceFD)
+ fuseFD.mounted = true
+
+ // Create the writeBuf for the header to be stored in.
+ hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ fuseFD.writeBuf = make([]byte, hdrLen)
+ fuseFD.completions = make(map[linux.FUSEOpID]*futureResponse)
+ fuseFD.fullQueueCh = make(chan struct{}, maxInFlightRequests)
+ fuseFD.writeCursor = 0
+
+ return &connection{
+ fd: fuseFD,
+ maxBackground: fuseDefaultMaxBackground,
+ congestionThreshold: fuseDefaultCongestionThreshold,
+ maxPages: fuseDefaultMaxPagesPerReq,
+ initializedChan: make(chan struct{}),
+ connected: true,
+ }, nil
+}
+
+// SetInitialized atomically sets the connection as initialized.
+func (conn *connection) SetInitialized() {
+ // Unblock the requests sent before INIT.
+ close(conn.initializedChan)
+
+ // Close the channel first to avoid the non-atomic situation
+ // where conn.initialized is true but there are
+ // tasks being blocked on the channel.
+ // And it prevents the newer tasks from gaining
+ // unnecessary higher chance to be issued before the blocked one.
+
+ atomic.StoreInt32(&(conn.initialized), int32(1))
+}
+
+// Initialized atomically checks if the connection is initialized.
+// Pairs with SetInitialized().
+func (conn *connection) Initialized() bool {
+ return atomic.LoadInt32(&(conn.initialized)) != 0
+}
+
+// NewRequest creates a new request that can be sent to the FUSE server.
+func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) {
+ conn.fd.mu.Lock()
+ defer conn.fd.mu.Unlock()
+ conn.fd.nextOpID += linux.FUSEOpID(reqIDStep)
+
+ hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes()
+ hdr := linux.FUSEHeaderIn{
+ Len: uint32(hdrLen + payload.SizeBytes()),
+ Opcode: opcode,
+ Unique: conn.fd.nextOpID,
+ NodeID: ino,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ PID: pid,
+ }
+
+ buf := make([]byte, hdr.Len)
+ hdr.MarshalUnsafe(buf[:hdrLen])
+ payload.MarshalUnsafe(buf[hdrLen:])
+
+ return &Request{
+ id: hdr.Unique,
+ hdr: &hdr,
+ data: buf,
+ }, nil
+}
+
+// Call makes a request to the server and blocks the invoking task until the
+// server responds. Task should never be nil.
+// Requests will not be sent before the connection is initialized.
+// For async tasks, use CallAsync().
+func (conn *connection) Call(t *kernel.Task, r *Request) (*Response, error) {
+ // Block requests sent before connection is initialized.
+ if !conn.Initialized() {
+ if err := t.Block(conn.initializedChan); err != nil {
+ return nil, err
+ }
+ }
+
+ return conn.call(t, r)
+}
+
+// CallAsync makes an async (aka background) request.
+// Those requests either do not expect a response (e.g. release) or
+// the response should be handled by others (e.g. init).
+// Returns immediately unless the connection is blocked (before initialization).
+// Async call example: init, release, forget, aio, interrupt.
+// When the Request is FUSE_INIT, it will not be blocked before initialization.
+func (conn *connection) CallAsync(t *kernel.Task, r *Request) error {
+ // Block requests sent before connection is initialized.
+ if !conn.Initialized() && r.hdr.Opcode != linux.FUSE_INIT {
+ if err := t.Block(conn.initializedChan); err != nil {
+ return err
+ }
+ }
+
+ // This should be the only place that invokes call() with a nil task.
+ _, err := conn.call(nil, r)
+ return err
+}
+
+// call makes a call without blocking checks.
+func (conn *connection) call(t *kernel.Task, r *Request) (*Response, error) {
+ if !conn.connected {
+ return nil, syserror.ENOTCONN
+ }
+
+ if conn.connInitError {
+ return nil, syserror.ECONNREFUSED
+ }
+
+ fut, err := conn.callFuture(t, r)
+ if err != nil {
+ return nil, err
+ }
+
+ return fut.resolve(t)
+}
+
+// Error returns the error of the FUSE call.
+func (r *Response) Error() error {
+ errno := r.hdr.Error
+ if errno >= 0 {
+ return nil
+ }
+
+ sysErrNo := syscall.Errno(-errno)
+ return error(sysErrNo)
+}
+
+// UnmarshalPayload unmarshals the response data into m.
+func (r *Response) UnmarshalPayload(m marshal.Marshallable) error {
+ hdrLen := r.hdr.SizeBytes()
+ haveDataLen := r.hdr.Len - uint32(hdrLen)
+ wantDataLen := uint32(m.SizeBytes())
+
+ if haveDataLen < wantDataLen {
+ return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen)
+ }
+
+ m.UnmarshalUnsafe(r.data[hdrLen:])
+ return nil
+}
+
+// callFuture makes a request to the server and returns a future response.
+// Call resolve() when the response needs to be fulfilled.
+func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, error) {
+ conn.fd.mu.Lock()
+ defer conn.fd.mu.Unlock()
+
+ // Is the queue full?
+ //
+ // We must busy wait here until the request can be queued. We don't
+ // block on the fd.fullQueueCh with a lock - so after being signalled,
+ // before we acquire the lock, it is possible that a barging task enters
+ // and queues a request. As a result, upon acquiring the lock we must
+ // again check if the room is available.
+ //
+ // This can potentially starve a request forever but this can only happen
+ // if there are always too many ongoing requests all the time. The
+ // supported maxActiveRequests setting should be really high to avoid this.
+ for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
+ if t == nil {
+ // Since there is no task that is waiting, we must error out.
+ return nil, errors.New("FUSE request queue full")
+ }
+
+ log.Infof("Blocking request %v from being queued. Too many active requests: %v",
+ r.id, conn.fd.numActiveRequests)
+ conn.fd.mu.Unlock()
+ err := t.Block(conn.fd.fullQueueCh)
+ conn.fd.mu.Lock()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return conn.callFutureLocked(t, r)
+}
+
+// callFutureLocked makes a request to the server and returns a future response.
+func (conn *connection) callFutureLocked(t *kernel.Task, r *Request) (*futureResponse, error) {
+ conn.fd.queue.PushBack(r)
+ conn.fd.numActiveRequests += 1
+ fut := newFutureResponse(r.hdr.Opcode)
+ conn.fd.completions[r.id] = fut
+
+ // Signal the readers that there is something to read.
+ conn.fd.waitQueue.Notify(waiter.EventIn)
+
+ return fut, nil
+}
+
+// futureResponse represents an in-flight request, that may or may not have
+// completed yet. Convert it to a resolved Response by calling Resolve, but note
+// that this may block.
+//
+// +stateify savable
+type futureResponse struct {
+ opcode linux.FUSEOpcode
+ ch chan struct{}
+ hdr *linux.FUSEHeaderOut
+ data []byte
+}
+
+// newFutureResponse creates a future response to a FUSE request.
+func newFutureResponse(opcode linux.FUSEOpcode) *futureResponse {
+ return &futureResponse{
+ opcode: opcode,
+ ch: make(chan struct{}),
+ }
+}
+
+// resolve blocks the task until the server responds to its corresponding request,
+// then returns a resolved response.
+func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) {
+ // If there is no Task associated with this request, then we don't try to resolve
+ // the response. Instead, the task writing the response (proxy to the server) will
+ // process the response on our behalf.
+ if t == nil {
+ log.Infof("fuse.Response.resolve: Not waiting on a response from server.")
+ return nil, nil
+ }
+
+ if err := t.Block(f.ch); err != nil {
+ return nil, err
+ }
+
+ return f.getResponse(), nil
+}
+
+// getResponse creates a Response from the data the futureResponse has.
+func (f *futureResponse) getResponse() *Response {
+ return &Response{
+ opcode: f.opcode,
+ hdr: *f.hdr,
+ data: f.data,
+ }
+}
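
The request/response handoff above is a classic future: callFuture registers a futureResponse under the request's unique ID, the server's eventual write into /dev/fuse fills it in and closes its channel, and resolve blocks until that close. A minimal standard-library sketch of the same shape (the real code additionally holds fd.mu around the map and blocks via t.Block so the wait is interruptible):

package example

type future struct {
	ch   chan struct{}
	data []byte
}

type conn struct {
	completions map[uint64]*future
}

// send registers a future for a request about to be queued.
func (c *conn) send(id uint64) *future {
	f := &future{ch: make(chan struct{})}
	c.completions[id] = f
	return f
}

// deliver is what the device write path does once a full response arrives:
// look the future up by unique ID, populate it, and close the channel.
func (c *conn) deliver(id uint64, data []byte) {
	if f, ok := c.completions[id]; ok {
		delete(c.completions, id)
		f.data = data
		close(f.ch)
	}
}

// resolve blocks until deliver has run.
func (f *future) resolve() []byte {
	<-f.ch
	return f.data
}
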
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
new file mode 100644
index 000000000..2225076bc
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -0,0 +1,397 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const fuseDevMinor = 229
+
+// fuseDevice implements vfs.Device for /dev/fuse.
+type fuseDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ if !kernel.FUSEEnabled {
+ return nil, syserror.ENOENT
+ }
+
+ var fd DeviceFD
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// DeviceFD implements vfs.FileDescriptionImpl for /dev/fuse.
+type DeviceFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD.
+ mounted bool
+
+ // nextOpID is used to create new requests.
+ nextOpID linux.FUSEOpID
+
+ // queue is the list of requests that need to be processed by the FUSE server.
+ queue requestList
+
+ // numActiveRequests is the number of requests made by the Sentry that have
+ // yet to be responded to.
+ numActiveRequests uint64
+
+ // completions is used to map a request to its response. A Writer will use this
+ // to notify the caller of a completed response.
+ completions map[linux.FUSEOpID]*futureResponse
+
+ writeCursor uint32
+
+ // writeBuf is the memory buffer used to copy in the FUSE out header from
+ // userspace.
+ writeBuf []byte
+
+ // writeCursorFR current FR being copied from server.
+ writeCursorFR *futureResponse
+
+ // mu protects all the queues, maps, buffers and cursors and nextOpID.
+ mu sync.Mutex
+
+ // waitQueue is used to notify interested parties when the device becomes
+ // readable or writable.
+ waitQueue waiter.Queue
+
+ // fullQueueCh is a channel used to synchronize the readers with the writers.
+ // Writers (inbound requests to the filesystem) block if there are too many
+ // unprocessed in-flight requests.
+ fullQueueCh chan struct{}
+
+ // fs is the FUSE filesystem that this FD is being used for.
+ fs *filesystem
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *DeviceFD) Release() {
+ fd.fs.conn.connected = false
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ return 0, syserror.ENOSYS
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ // We require that any Read done on this filesystem have a sane minimum
+ // read buffer. It must have the capacity for the fixed parts of any request
+ // header (Linux uses the request header and the FUSEWriteIn header for this
+ // calculation) + the negotiated MaxWrite room for the data.
+ minBuffSize := linux.FUSE_MIN_READ_BUFFER
+ inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
+ writeHdrLen := uint32((*linux.FUSEWriteIn)(nil).SizeBytes())
+ negotiatedMinBuffSize := inHdrLen + writeHdrLen + fd.fs.conn.maxWrite
+ if minBuffSize < negotiatedMinBuffSize {
+ minBuffSize = negotiatedMinBuffSize
+ }
+
+ // If the read buffer is too small, error out.
+ if dst.NumBytes() < int64(minBuffSize) {
+ return 0, syserror.EINVAL
+ }
+
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ return fd.readLocked(ctx, dst, opts)
+}
+
+// readLocked implements the reading of the fuse device while locked with DeviceFD.mu.
+func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ if fd.queue.Empty() {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ var readCursor uint32
+ var bytesRead int64
+ for {
+ req := fd.queue.Front()
+ if dst.NumBytes() < int64(req.hdr.Len) {
+ // The request is too large. Cannot process it. All requests must be smaller than the
+ // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT
+ // handshake.
+ errno := -int32(syscall.EIO)
+ if req.hdr.Opcode == linux.FUSE_SETXATTR {
+ errno = -int32(syscall.E2BIG)
+ }
+
+ // Return the error to the calling task.
+ if err := fd.sendError(ctx, errno, req); err != nil {
+ return 0, err
+ }
+
+ // We're done with this request.
+ fd.queue.Remove(req)
+
+ // Restart the read as this request was invalid.
+ log.Warningf("fuse.DeviceFD.Read: request found was too large. Restarting read.")
+ return fd.readLocked(ctx, dst, opts)
+ }
+
+ n, err := dst.CopyOut(ctx, req.data[readCursor:])
+ if err != nil {
+ return 0, err
+ }
+ readCursor += uint32(n)
+ bytesRead += int64(n)
+
+ if readCursor >= req.hdr.Len {
+ // Fully done with this req, remove it from the queue.
+ fd.queue.Remove(req)
+ break
+ }
+ }
+
+ return bytesRead, nil
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ return 0, syserror.ENOSYS
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ return fd.writeLocked(ctx, src, opts)
+}
+
+// writeLocked implements writing to the fuse device while locked with DeviceFD.mu.
+func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ var cn, n int64
+ hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+
+ for src.NumBytes() > 0 {
+ if fd.writeCursorFR != nil {
+ // Already have common header, and we're now copying the payload.
+ wantBytes := fd.writeCursorFR.hdr.Len
+
+ // Note that the FR data doesn't have the header. Copy it over if it's necessary.
+ if fd.writeCursorFR.data == nil {
+ fd.writeCursorFR.data = make([]byte, wantBytes)
+ }
+
+ bytesCopied, err := src.CopyIn(ctx, fd.writeCursorFR.data[fd.writeCursor:wantBytes])
+ if err != nil {
+ return 0, err
+ }
+ src = src.DropFirst(bytesCopied)
+
+ cn = int64(bytesCopied)
+ n += cn
+ fd.writeCursor += uint32(cn)
+ if fd.writeCursor == wantBytes {
+ // Done reading this full response. Clean up and unblock the
+ // initiator.
+ break
+ }
+
+ // Check if we have more data in src.
+ continue
+ }
+
+ // Assert that the header isn't read into the writeBuf yet.
+ if fd.writeCursor >= hdrLen {
+ return 0, syserror.EINVAL
+ }
+
+ // We don't have the full common response header yet.
+ wantBytes := hdrLen - fd.writeCursor
+ bytesCopied, err := src.CopyIn(ctx, fd.writeBuf[fd.writeCursor:wantBytes])
+ if err != nil {
+ return 0, err
+ }
+ src = src.DropFirst(bytesCopied)
+
+ cn = int64(bytesCopied)
+ n += cn
+ fd.writeCursor += uint32(cn)
+ if fd.writeCursor == hdrLen {
+ // Have full header in the writeBuf. Use it to fetch the actual futureResponse
+ // from the device's completions map.
+ var hdr linux.FUSEHeaderOut
+ hdr.UnmarshalBytes(fd.writeBuf)
+
+ // We have the header now and so the writeBuf has served its purpose.
+ // We could reset it manually here but instead of doing that, at the
+ // end of the write, the writeCursor will be set to 0 thereby allowing
+ // the next request to overwrite what's in the buffer.
+
+ fut, ok := fd.completions[hdr.Unique]
+ if !ok {
+ // Server sent us a response for a request we never sent?
+ return 0, syserror.EINVAL
+ }
+
+ delete(fd.completions, hdr.Unique)
+
+ // Copy over the header into the future response. The rest of the payload
+ // will be copied over to the FR's data in the next iteration.
+ fut.hdr = &hdr
+ fd.writeCursorFR = fut
+
+ // Next iteration will now try to read the complete request, if src has
+ // any data remaining. Otherwise we're done.
+ }
+ }
+
+ if fd.writeCursorFR != nil {
+ if err := fd.sendResponse(ctx, fd.writeCursorFR); err != nil {
+ return 0, err
+ }
+
+ // Ready the device for the next request.
+ fd.writeCursorFR = nil
+ fd.writeCursor = 0
+ }
+
+ return n, nil
+}
+
+// Readiness implements vfs.FileDescriptionImpl.Readiness.
+func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ var ready waiter.EventMask
+ ready |= waiter.EventOut // FD is always writable
+ if !fd.queue.Empty() {
+ // Have reqs available, FD is readable.
+ ready |= waiter.EventIn
+ }
+
+ return ready & mask
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *DeviceFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.waitQueue.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *DeviceFD) EventUnregister(e *waiter.Entry) {
+ fd.waitQueue.EventUnregister(e)
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ return 0, syserror.ENOSYS
+}
+
+// sendResponse sends a response to the waiting task (if any).
+func (fd *DeviceFD) sendResponse(ctx context.Context, fut *futureResponse) error {
+ // See if the running task needs to perform some action before returning.
+ // Since we just finished writing the future, we can be sure that
+ // getResponse generates a populated response.
+ if err := fd.noReceiverAction(ctx, fut.getResponse()); err != nil {
+ return err
+ }
+
+ // Signal that the queue is no longer full.
+ select {
+ case fd.fullQueueCh <- struct{}{}:
+ default:
+ }
+ fd.numActiveRequests -= 1
+
+ // Signal the task waiting on a response.
+ close(fut.ch)
+ return nil
+}
+
+// sendError sends an error response to the waiting task (if any).
+func (fd *DeviceFD) sendError(ctx context.Context, errno int32, req *Request) error {
+ // Return the error to the calling task.
+ outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ respHdr := linux.FUSEHeaderOut{
+ Len: outHdrLen,
+ Error: errno,
+ Unique: req.hdr.Unique,
+ }
+
+ fut, ok := fd.completions[respHdr.Unique]
+ if !ok {
+ // Server sent us a response for a request we never sent?
+ return syserror.EINVAL
+ }
+ delete(fd.completions, respHdr.Unique)
+
+ fut.hdr = &respHdr
+ if err := fd.sendResponse(ctx, fut); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// noReceiverAction has the calling kernel.Task do some action if it's known that no
+// receiver is going to be waiting on the future channel. This is to be used by:
+// FUSE_INIT.
+func (fd *DeviceFD) noReceiverAction(ctx context.Context, r *Response) error {
+ if r.opcode == linux.FUSE_INIT {
+ creds := auth.CredentialsFromContext(ctx)
+ rootUserNs := kernel.KernelFromContext(ctx).RootUserNamespace()
+ return fd.fs.conn.InitRecv(r, creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, rootUserNs))
+ }
+
+ return nil
+}
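
writeLocked above reassembles responses from a byte stream in two phases: first the fixed-size out header into writeBuf, then, once hdr.Unique has selected a futureResponse, the remaining hdr.Len minus header-length payload bytes. A hedged standalone decoder for that framing (field layout per struct fuse_out_header; host byte order assumed little-endian here; this is an illustration, not the sentry's marshalling code):

package example

import "encoding/binary"

type outHeader struct {
	Len    uint32 // total message length, header included
	Error  int32
	Unique uint64
}

const outHeaderLen = 16

// parseResponse returns the header and payload once buf holds a complete
// message, or ok=false if more bytes are still needed.
func parseResponse(buf []byte) (hdr outHeader, payload []byte, ok bool) {
	if len(buf) < outHeaderLen {
		return outHeader{}, nil, false // still collecting the header
	}
	hdr.Len = binary.LittleEndian.Uint32(buf[0:4])
	hdr.Error = int32(binary.LittleEndian.Uint32(buf[4:8]))
	hdr.Unique = binary.LittleEndian.Uint64(buf[8:16])
	if uint32(len(buf)) < hdr.Len {
		return outHeader{}, nil, false // payload not fully arrived yet
	}
	return hdr, buf[outHeaderLen:hdr.Len], true
}
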
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
new file mode 100644
index 000000000..84c222ad6
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -0,0 +1,428 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "fmt"
+ "io"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// echoTestOpcode is the Opcode used during testing. The server used in tests
+// will simply echo the payload back with the appropriate headers.
+const echoTestOpcode linux.FUSEOpcode = 1000
+
+type testPayload struct {
+ data uint32
+}
+
+// TestFUSECommunication tests that the communication layer between the Sentry and the
+// FUSE server daemon works as expected.
+func TestFUSECommunication(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ creds := auth.CredentialsFromContext(s.Ctx)
+
+ // Create test cases with different number of concurrent clients and servers.
+ testCases := []struct {
+ Name string
+ NumClients int
+ NumServers int
+ MaxActiveRequests uint64
+ }{
+ {
+ Name: "SingleClientSingleServer",
+ NumClients: 1,
+ NumServers: 1,
+ MaxActiveRequests: maxActiveRequestsDefault,
+ },
+ {
+ Name: "SingleClientMultipleServers",
+ NumClients: 1,
+ NumServers: 10,
+ MaxActiveRequests: maxActiveRequestsDefault,
+ },
+ {
+ Name: "MultipleClientsSingleServer",
+ NumClients: 10,
+ NumServers: 1,
+ MaxActiveRequests: maxActiveRequestsDefault,
+ },
+ {
+ Name: "MultipleClientsMultipleServers",
+ NumClients: 10,
+ NumServers: 10,
+ MaxActiveRequests: maxActiveRequestsDefault,
+ },
+ {
+ Name: "RequestCapacityFull",
+ NumClients: 10,
+ NumServers: 1,
+ MaxActiveRequests: 1,
+ },
+ {
+ Name: "RequestCapacityContinuouslyFull",
+ NumClients: 100,
+ NumServers: 2,
+ MaxActiveRequests: 2,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ conn, fd, err := newTestConnection(s, k, testCase.MaxActiveRequests)
+ if err != nil {
+ t.Fatalf("newTestConnection: %v", err)
+ }
+
+ clientsDone := make([]chan struct{}, testCase.NumClients)
+ serversDone := make([]chan struct{}, testCase.NumServers)
+ serversKill := make([]chan struct{}, testCase.NumServers)
+
+ // FUSE clients.
+ for i := 0; i < testCase.NumClients; i++ {
+ clientsDone[i] = make(chan struct{})
+ go func(i int) {
+ fuseClientRun(t, s, k, conn, creds, uint32(i), uint64(i), clientsDone[i])
+ }(i)
+ }
+
+ // FUSE servers.
+ for j := 0; j < testCase.NumServers; j++ {
+ serversDone[j] = make(chan struct{})
+ serversKill[j] = make(chan struct{}, 1) // The kill command shouldn't block.
+ go func(j int) {
+ fuseServerRun(t, s, k, fd, serversDone[j], serversKill[j])
+ }(j)
+ }
+
+ // Tear down.
+ //
+ // Make sure all the clients are done.
+ for i := 0; i < testCase.NumClients; i++ {
+ <-clientsDone[i]
+ }
+
+ // Kill any server that is potentially waiting.
+ for j := 0; j < testCase.NumServers; j++ {
+ serversKill[j] <- struct{}{}
+ }
+
+ // Make sure all the servers are done.
+ for j := 0; j < testCase.NumServers; j++ {
+ <-serversDone[j]
+ }
+ })
+ }
+}
+
+// CallTest makes a request to the server and blocks the invoking
+// goroutine until the server responds. Doesn't block
+// a kernel.Task. Analogous to Connection.Call but used for testing.
+func CallTest(conn *connection, t *kernel.Task, r *Request, i uint32) (*Response, error) {
+ conn.fd.mu.Lock()
+
+ // Wait until we're certain that a new request can be processed.
+ for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
+ conn.fd.mu.Unlock()
+ select {
+ case <-conn.fd.fullQueueCh:
+ }
+ conn.fd.mu.Lock()
+ }
+
+ fut, err := conn.callFutureLocked(t, r) // No task given.
+ conn.fd.mu.Unlock()
+
+ if err != nil {
+ return nil, err
+ }
+
+ // Resolve the response.
+ //
+ // Block without a task.
+ select {
+ case <-fut.ch:
+ }
+
+ // A response is ready. Resolve and return it.
+ return fut.getResponse(), nil
+}
+
+// ReadTest is analogous to vfs.FileDescription.Read and reads from the FUSE
+// device. However, it does so without blocking the calling task; it
+// instead just waits on a channel. The behaviour is essentially the same as
+// DeviceFD.Read except it guarantees that the task is not blocked.
+func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.IOSequence, killServer chan struct{}) (int64, bool, error) {
+ var err error
+ var n, total int64
+
+ dev := fd.Impl().(*DeviceFD)
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ dev.EventRegister(&w, waiter.EventIn)
+ for {
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{})
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ // Emulate the blocking for when no requests are available
+ select {
+ case <-ch:
+ case <-killServer:
+ // Server killed by the main program.
+ return 0, true, nil
+ }
+ }
+
+ dev.EventUnregister(&w)
+ return total, false, err
+}
+
+// fuseClientRun emulates all the actions of a normal FUSE request. It creates
+// a header, a payload, calls the server, waits for the response, and processes
+// the response.
+func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *connection, creds *auth.Credentials, pid uint32, inode uint64, clientDone chan struct{}) {
+ defer func() { clientDone <- struct{}{} }()
+
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ clientTask, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("fuse-client-%v", pid), tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatal(err)
+ }
+ testObj := &testPayload{
+ data: rand.Uint32(),
+ }
+
+ req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
+ if err != nil {
+ t.Fatalf("NewRequest creation failed: %v", err)
+ }
+
+ // Queue up a request.
+ // Analogous to Call except it doesn't block on the task.
+ resp, err := CallTest(conn, clientTask, req, pid)
+ if err != nil {
+ t.Fatalf("CallTaskNonBlock failed: %v", err)
+ }
+
+ if err = resp.Error(); err != nil {
+ t.Fatalf("Server responded with an error: %v", err)
+ }
+
+ var respTestPayload testPayload
+ if err := resp.UnmarshalPayload(&respTestPayload); err != nil {
+ t.Fatalf("Unmarshalling payload error: %v", err)
+ }
+
+ if resp.hdr.Unique != req.hdr.Unique {
+ t.Fatalf("got response for another request. Expected response for req %v but got response for req %v",
+ req.hdr.Unique, resp.hdr.Unique)
+ }
+
+ if respTestPayload.data != testObj.data {
+ t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data)
+ }
+
+}
+
+// fuseServerRun creates a task and emulates all the actions of a simple FUSE server
+// that simply reads a request and echoes the same struct back as a response using the
+// appropriate headers.
+func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.FileDescription, serverDone, killServer chan struct{}) {
+ defer func() { serverDone <- struct{}{} }()
+
+ // Create the tasks that the server will be using.
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ var readPayload testPayload
+
+ serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Read the request.
+ for {
+ inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
+ payloadLen := uint32(readPayload.SizeBytes())
+
+ // The read buffer must meet a certain minimum size.
+ buffSize := inHdrLen + payloadLen
+ if buffSize < linux.FUSE_MIN_READ_BUFFER {
+ buffSize = linux.FUSE_MIN_READ_BUFFER
+ }
+ inBuf := make([]byte, buffSize)
+ inIOseq := usermem.BytesIOSequence(inBuf)
+
+ n, serverKilled, err := ReadTest(serverTask, fd, inIOseq, killServer)
+ if err != nil {
+ t.Fatalf("Read failed :%v", err)
+ }
+
+ // Server should shut down. No new requests are going to be made.
+ if serverKilled {
+ break
+ }
+
+ if n <= 0 {
+ t.Fatalf("Read read no bytes")
+ }
+
+ var readFUSEHeaderIn linux.FUSEHeaderIn
+ readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen])
+ readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen])
+
+ if readFUSEHeaderIn.Opcode != echoTestOpcode {
+ t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload)
+ }
+
+ // Write the response.
+ outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ outBuf := make([]byte, outHdrLen+payloadLen)
+ outHeader := linux.FUSEHeaderOut{
+ Len: outHdrLen + payloadLen,
+ Error: 0,
+ Unique: readFUSEHeaderIn.Unique,
+ }
+
+ // Echo the payload back.
+ outHeader.MarshalUnsafe(outBuf[:outHdrLen])
+ readPayload.MarshalUnsafe(outBuf[outHdrLen:])
+ outIOseq := usermem.BytesIOSequence(outBuf)
+
+ n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed :%v", err)
+ }
+ }
+}
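+
+// A minimal sketch of how these helpers are expected to pair up in a test
+// (a hypothetical test body; TestFUSEEcho and the pid/inode values are
+// illustrative, not part of this change):
+//
+//	func TestFUSEEcho(t *testing.T) {
+//		s := setup(t)
+//		defer s.Destroy()
+//		k := kernel.KernelFromContext(s.Ctx)
+//		conn, fd, err := newTestConnection(s, k, maxActiveRequestsDefault)
+//		if err != nil {
+//			t.Fatalf("newTestConnection: %v", err)
+//		}
+//		clientDone := make(chan struct{})
+//		serverDone := make(chan struct{})
+//		killServer := make(chan struct{})
+//		go fuseServerRun(t, s, k, fd, serverDone, killServer)
+//		go fuseClientRun(t, s, k, conn, auth.CredentialsFromContext(s.Ctx), 1 /* pid */, 2 /* inode */, clientDone)
+//		<-clientDone
+//		close(killServer)
+//		<-serverDone
+//	}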
+
+func setup(t *testing.T) *testutil.System {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("Error creating kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+
+ k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ AllowUserMount: true,
+ })
+
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("NewMountNamespace(): %v", err)
+ }
+
+ return testutil.NewSystem(ctx, t, k.VFS(), mntns)
+}
+
+// newTestConnection creates a FUSE connection for the sentry to communicate
+// over, along with the FD that the server uses to communicate.
+func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) {
+ vfsObj := &vfs.VirtualFilesystem{}
+ fuseDev := &DeviceFD{}
+
+ if err := vfsObj.Init(); err != nil {
+ return nil, nil, err
+ }
+
+ vd := vfsObj.NewAnonVirtualDentry("genCountFD")
+ defer vd.DecRef()
+ if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, nil, err
+ }
+
+ fsopts := filesystemOptions{
+ maxActiveRequests: maxActiveRequests,
+ }
+ fs, err := NewFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return fs.conn, &fuseDev.vfsfd, nil
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (t *testPayload) SizeBytes() int {
+ return 4
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (t *testPayload) MarshalBytes(dst []byte) {
+ usermem.ByteOrder.PutUint32(dst[:4], t.data)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (t *testPayload) UnmarshalBytes(src []byte) {
+ *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])}
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (t *testPayload) Packed() bool {
+ return true
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (t *testPayload) MarshalUnsafe(dst []byte) {
+ t.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (t *testPayload) UnmarshalUnsafe(src []byte) {
+ t.UnmarshalBytes(src)
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ panic("not implemented")
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ panic("not implemented")
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ panic("not implemented")
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (t *testPayload) WriteTo(w io.Writer) (int64, error) {
+ panic("not implemented")
+}
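+
+// Note: the Copy* and WriteTo methods above panic by design. These tests only
+// marshal testPayload to and from in-memory buffers, so the task-memory and
+// io.Writer paths are never exercised.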
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
new file mode 100644
index 000000000..200a93bbf
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -0,0 +1,228 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fuse implements fusefs.
+package fuse
+
+import (
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "fuse"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+type filesystemOptions struct {
+ // userID specifies the numeric uid of the mount owner.
+ // This option should not be specified by the filesystem owner.
+ // It is set by libfuse (or, if libfuse is not used, must be set
+ // by the filesystem itself). For more information, see the fuse(8)
+ // man page.
+ userID uint32
+
+ // groupID specifies the numeric gid of the mount owner.
+ // This option should not be specified by the filesystem owner.
+ // It is set by libfuse (or, if libfuse is not used, must be set
+ // by the filesystem itself). For more information, see the fuse(8)
+ // man page.
+ groupID uint32
+
+ // rootMode specifies the file mode of the filesystem's root.
+ rootMode linux.FileMode
+
+ // maxActiveRequests specifies the maximum number of active requests that can
+ // exist at any time. Any further requests will block when trying to
+ // Call the server.
+ maxActiveRequests uint64
+}
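+
+// For reference, a FUSE daemon (typically via libfuse) supplies the options
+// above in the mount data string; an illustrative invocation, not taken from
+// this change, looks like:
+//
+//	mount -t fuse -o fd=7,rootmode=40000,user_id=0,group_id=0 mydaemon /mnt
+//
+// where fd is the descriptor obtained by opening /dev/fuse and rootmode is
+// the octal file mode of the mount's root (040000 == S_IFDIR).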
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+ devMinor uint32
+
+ // conn is used for communication between the FUSE server
+ // daemon and the sentry fusefs.
+ conn *connection
+
+ // opts is the options the fusefs is initialized with.
+ opts *filesystemOptions
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ var fsopts filesystemOptions
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ deviceDescriptorStr, ok := mopts["fd"]
+ if !ok {
+ log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name())
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "fd")
+
+ deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ kernelTask := kernel.TaskFromContext(ctx)
+ if kernelTask == nil {
+ log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name())
+ return nil, nil, syserror.EINVAL
+ }
+ fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor))
+
+ // Parse and set all the other supported FUSE mount options.
+ // TODO(gvisor.dev/issue/3229): Expand the supported mount options.
+ if userIDStr, ok := mopts["user_id"]; ok {
+ delete(mopts, "user_id")
+ userID, err := strconv.ParseUint(userIDStr, 10, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.userID = uint32(userID)
+ }
+
+ if groupIDStr, ok := mopts["group_id"]; ok {
+ delete(mopts, "group_id")
+ groupID, err := strconv.ParseUint(groupIDStr, 10, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.groupID = uint32(groupID)
+ }
+
+ rootMode := linux.FileMode(0777)
+ modeStr, ok := mopts["rootmode"]
+ if ok {
+ delete(mopts, "rootmode")
+ mode, err := strconv.ParseUint(modeStr, 8, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr)
+ return nil, nil, syserror.EINVAL
+ }
+ rootMode = linux.FileMode(mode)
+ }
+ fsopts.rootMode = rootMode
+
+ // Set the maxActiveRequests option.
+ fsopts.maxActiveRequests = maxActiveRequestsDefault
+
+ // Check for unparsed options.
+ if len(mopts) != 0 {
+ log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Create a new FUSE filesystem.
+ fs, err := NewFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd)
+ if err != nil {
+ log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err)
+ return nil, nil, err
+ }
+
+ fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+ // Send a FUSE_INIT request to the FUSE daemon server before returning.
+ // This call is not blocking.
+ if err := fs.conn.InitSend(creds, uint32(kernelTask.ThreadID())); err != nil {
+ log.Warningf("%s.InitSend: failed with error: %v", fsType.Name(), err)
+ return nil, nil, err
+ }
+
+ // root is the fusefs root directory.
+ root := fs.newInode(creds, fsopts.rootMode)
+
+ return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+// NewFUSEFilesystem creates a new FUSE filesystem.
+func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) {
+ fs := &filesystem{
+ devMinor: devMinor,
+ opts: opts,
+ }
+
+ conn, err := newFUSEConnection(ctx, device, opts.maxActiveRequests)
+ if err != nil {
+ log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err)
+ return nil, syserror.EINVAL
+ }
+
+ fs.conn = conn
+ fuseFD := device.Impl().(*DeviceFD)
+ fuseFD.fs = fs
+
+ return fs, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release()
+}
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
+ i := &inode{}
+ i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
+ i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ i.dentry.Init(i)
+
+ return &i.dentry
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
diff --git a/pkg/sentry/fsimpl/fuse/init.go b/pkg/sentry/fsimpl/fuse/init.go
new file mode 100644
index 000000000..779c2bd3f
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/init.go
@@ -0,0 +1,166 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// Constants used in FUSE_INIT negotiation.
+const (
+ // fuseMaxMaxPages is the maximum value for MaxPages received in InitOut.
+ // This follows the behavior of the Linux FUSE implementation.
+ fuseMaxMaxPages = 256
+
+ // Maximum value for the time granularity for file time stamps, 1s.
+ // This follows the behavior of the Linux FUSE implementation.
+ fuseMaxTimeGranNs = 1000000000
+
+ // Minimum value for MaxWrite.
+ // This follows the behavior of the Linux FUSE implementation.
+ fuseMinMaxWrite = 4096
+
+ // Temporary default value for max readahead, 128 KB.
+ fuseDefaultMaxReadahead = 131072
+
+ // The FUSE_INIT_IN flags sent to the daemon.
+ // TODO(gvisor.dev/issue/3199): complete the flags.
+ fuseDefaultInitFlags = linux.FUSE_MAX_PAGES
+)
+
+// Adjustable maximums for the connection's congestion control parameters.
+// Used as the upper bound of the config values.
+// Adjusting them is currently not supported.
+var (
+ MaxUserBackgroundRequest uint16 = fuseDefaultMaxBackground
+ MaxUserCongestionThreshold uint16 = fuseDefaultCongestionThreshold
+)
+
+// InitSend sends a FUSE_INIT request.
+func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error {
+ in := linux.FUSEInitIn{
+ Major: linux.FUSE_KERNEL_VERSION,
+ Minor: linux.FUSE_KERNEL_MINOR_VERSION,
+ // TODO(gvisor.dev/issue/3196): find appropriate way to calculate this
+ MaxReadahead: fuseDefaultMaxReadahead,
+ Flags: fuseDefaultInitFlags,
+ }
+
+ req, err := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in)
+ if err != nil {
+ return err
+ }
+
+ // Since there is no task to block on and FUSE_INIT is the request
+ // to unblock other requests, use nil.
+ return conn.CallAsync(nil, req)
+}
+
+// InitRecv receives a FUSE_INIT reply and processes it.
+func (conn *connection) InitRecv(res *Response, hasSysAdminCap bool) error {
+ if err := res.Error(); err != nil {
+ return err
+ }
+
+ var out linux.FUSEInitOut
+ if err := res.UnmarshalPayload(&out); err != nil {
+ return err
+ }
+
+ return conn.initProcessReply(&out, hasSysAdminCap)
+}
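+
+// To summarize the flow (a recap, not new behavior): InitSend is fired off at
+// mount time from GetFilesystem, the FUSE daemon reads the FUSE_INIT request
+// from its /dev/fuse descriptor and writes back a FUSEInitOut, and the device
+// write path hands that reply to InitRecv, which applies it via
+// initProcessReply below.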
+
+// initProcessReply processes the FUSE_INIT reply from the FUSE server.
+func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap bool) error {
+ // No support for old major fuse versions.
+ if out.Major != linux.FUSE_KERNEL_VERSION {
+ conn.connInitError = true
+
+ // Set the connection as initialized and unblock the blocked requests
+ // (i.e. return error for them).
+ conn.SetInitialized()
+
+ return nil
+ }
+
+ // Start processing the reply.
+ conn.connInitSuccess = true
+ conn.minor = out.Minor
+
+ // No support for limits before minor version 13.
+ if out.Minor >= 13 {
+ conn.bgLock.Lock()
+
+ if out.MaxBackground > 0 {
+ conn.maxBackground = out.MaxBackground
+
+ if !hasSysAdminCap &&
+ conn.maxBackground > MaxUserBackgroundRequest {
+ conn.maxBackground = MaxUserBackgroundRequest
+ }
+ }
+
+ if out.CongestionThreshold > 0 {
+ conn.congestionThreshold = out.CongestionThreshold
+
+ if !hasSysAdminCap &&
+ conn.congestionThreshold > MaxUserCongestionThreshold {
+ conn.congestionThreshold = MaxUserCongestionThreshold
+ }
+ }
+
+ conn.bgLock.Unlock()
+ }
+
+ // No support for the following flags before minor version 6.
+ if out.Minor >= 6 {
+ conn.asyncRead = out.Flags&linux.FUSE_ASYNC_READ != 0
+ conn.bigWrites = out.Flags&linux.FUSE_BIG_WRITES != 0
+ conn.dontMask = out.Flags&linux.FUSE_DONT_MASK != 0
+ conn.writebackCache = out.Flags&linux.FUSE_WRITEBACK_CACHE != 0
+ conn.cacheSymlinks = out.Flags&linux.FUSE_CACHE_SYMLINKS != 0
+ conn.abortErr = out.Flags&linux.FUSE_ABORT_ERROR != 0
+
+ // TODO(gvisor.dev/issue/3195): figure out how to use TimeGran (0 < TimeGran <= fuseMaxTimeGranNs).
+
+ if out.Flags&linux.FUSE_MAX_PAGES != 0 {
+ maxPages := out.MaxPages
+ if maxPages < 1 {
+ maxPages = 1
+ }
+ if maxPages > fuseMaxMaxPages {
+ maxPages = fuseMaxMaxPages
+ }
+ conn.maxPages = maxPages
+ }
+ }
+
+ // No support for negotiating MaxWrite before minor version 5.
+ if out.Minor >= 5 {
+ conn.maxWrite = out.MaxWrite
+ } else {
+ conn.maxWrite = fuseMinMaxWrite
+ }
+ if conn.maxWrite < fuseMinMaxWrite {
+ conn.maxWrite = fuseMinMaxWrite
+ }
+
+ // Set connection as initialized and unblock the requests
+ // issued before init.
+ conn.SetInitialized()
+
+ return nil
+}
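+
+// As a worked example of the clamping above (the reply values are
+// hypothetical): MaxPages=512 yields conn.maxPages=256 (clamped to
+// fuseMaxMaxPages), MaxWrite=1024 yields conn.maxWrite=4096 (raised to
+// fuseMinMaxWrite), and for a mount owner without CAP_SYS_ADMIN,
+// MaxBackground=1000 is capped at MaxUserBackgroundRequest.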
diff --git a/pkg/sentry/fsimpl/fuse/register.go b/pkg/sentry/fsimpl/fuse/register.go
new file mode 100644
index 000000000..b5b581152
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/register.go
@@ -0,0 +1,42 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Register registers the FUSE device with vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{
+ GroupName: "misc",
+ }); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// CreateDevtmpfsFile creates a device special file in devtmpfs.
+func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error {
+ if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil {
+ return err
+ }
+
+ return nil
+}
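+
+// A minimal sketch of how a VFS2 boot path might wire these up (the
+// surrounding initialization and the devtmpfs accessor are assumed, not
+// shown in this change):
+//
+//	if err := fuse.Register(vfsObj); err != nil {
+//		return err
+//	}
+//	if err := fuse.CreateDevtmpfsFile(ctx, a); err != nil {
+//		return err
+//	}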
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 67e916525..4a800dcf9 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -35,6 +35,7 @@ go_library(
"fstree.go",
"gofer.go",
"handle.go",
+ "host_named_pipe.go",
"p9file.go",
"regular_file.go",
"socket.go",
@@ -47,11 +48,13 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fd",
+ "//pkg/fdnotifier",
"//pkg/fspath",
"//pkg/log",
"//pkg/p9",
"//pkg/safemem",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/host",
"//pkg/sentry/hostfd",
"//pkg/sentry/kernel",
@@ -71,6 +74,7 @@ go_library(
"//pkg/unet",
"//pkg/usermem",
"//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index b98218753..8c7c8e1b3 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -85,6 +85,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
d2 := &dentry{
refs: 1, // held by d
fs: d.fs,
+ ino: d.fs.nextSyntheticIno(),
mode: uint32(opts.mode),
uid: uint32(opts.kuid),
gid: uint32(opts.kgid),
@@ -138,6 +139,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
fd.dirents = ds
}
+ d.InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
if d.cachedMetadataAuthoritative() {
d.touchAtime(fd.vfsfd.Mount())
}
@@ -183,13 +185,13 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
{
Name: ".",
Type: linux.DT_DIR,
- Ino: d.ino,
+ Ino: uint64(d.ino),
NextOff: 1,
},
{
Name: "..",
Type: uint8(atomic.LoadUint32(&parent.mode) >> 12),
- Ino: parent.ino,
+ Ino: uint64(parent.ino),
NextOff: 2,
},
}
@@ -225,7 +227,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
}
dirent := vfs.Dirent{
Name: p9d.Name,
- Ino: p9d.QID.Path,
+ Ino: uint64(inoFromPath(p9d.QID.Path)),
NextOff: int64(len(dirents) + 1),
}
// p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
@@ -258,7 +260,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
dirents = append(dirents, vfs.Dirent{
Name: child.name,
Type: uint8(atomic.LoadUint32(&child.mode) >> 12),
- Ino: child.ino,
+ Ino: uint64(child.ino),
NextOff: int64(len(dirents) + 1),
})
}
@@ -299,3 +301,8 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
return 0, syserror.EINVAL
}
}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *directoryFD) Sync(ctx context.Context) error {
+ return fd.dentry().handle.sync(ctx)
+}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 7f2181216..00e3c99cd 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -16,6 +16,7 @@ package gofer
import (
"sync"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -118,7 +119,7 @@ func putDentrySlice(ds *[]*dentry) {
// must be up to date.
//
// Postconditions: The returned dentry's cached metadata is up to date.
-func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
if !d.isDir() {
return nil, syserror.ENOTDIR
}
@@ -149,11 +150,9 @@ afterSymlink:
return nil, err
}
if d != d.parent && !d.cachedMetadataAuthoritative() {
- _, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask())
- if err != nil {
+ if err := d.parent.updateFromGetattr(ctx); err != nil {
return nil, err
}
- d.parent.updateFromP9Attrs(attrMask, &attr)
}
rp.Advance()
return d.parent, nil
@@ -168,7 +167,7 @@ afterSymlink:
if err := rp.CheckMount(&child.vfsd); err != nil {
return nil, err
}
- if child.isSymlink() && rp.ShouldFollowSymlink() {
+ if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
target, err := child.readlink(ctx, rp.Mount())
if err != nil {
return nil, err
@@ -208,18 +207,28 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil
// Preconditions: As for getChildLocked. !parent.isSynthetic().
func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) {
+ if child != nil {
+ // Need to lock child.metadataMu because we might be updating child
+ // metadata. We need to hold the lock *before* getting metadata from the
+ // server and release it after updating local metadata.
+ child.metadataMu.Lock()
+ }
qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
if err != nil && err != syserror.ENOENT {
+ if child != nil {
+ child.metadataMu.Unlock()
+ }
return nil, err
}
if child != nil {
- if !file.isNil() && qid.Path == child.ino {
- // The file at this path hasn't changed. Just update cached
- // metadata.
+ if !file.isNil() && inoFromPath(qid.Path) == child.ino {
+ // The file at this path hasn't changed. Just update cached metadata.
file.close(ctx)
- child.updateFromP9Attrs(attrMask, &attr)
+ child.updateFromP9AttrsLocked(attrMask, &attr)
+ child.metadataMu.Unlock()
return child, nil
}
+ child.metadataMu.Unlock()
if file.isNil() && child.isSynthetic() {
// We have a synthetic file, and no remote file has arisen to
// replace it.
@@ -275,7 +284,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
for !rp.Final() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, ds)
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
@@ -301,7 +310,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
}
for !rp.Done() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, ds)
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
@@ -371,17 +380,33 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
}
parent.touchCMtime()
parent.dirents = nil
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
return nil
}
if fs.opts.interop == InteropModeShared {
- // The existence of a dentry at name would be inconclusive because the
- // file it represents may have been deleted from the remote filesystem,
- // so we would need to make an RPC to revalidate the dentry. Just
- // attempt the file creation RPC instead. If a file does exist, the RPC
- // will fail with EEXIST like we would have. If the RPC succeeds, and a
- // stale dentry exists, the dentry will fail revalidation next time
- // it's used.
- return createInRemoteDir(parent, name)
+ if child := parent.children[name]; child != nil && child.isSynthetic() {
+ return syserror.EEXIST
+ }
+ // The existence of a non-synthetic dentry at name would be inconclusive
+ // because the file it represents may have been deleted from the remote
+ // filesystem, so we would need to make an RPC to revalidate the dentry.
+ // Just attempt the file creation RPC instead. If a file does exist, the
+ // RPC will fail with EEXIST like we would have. If the RPC succeeds, and a
+ // stale dentry exists, the dentry will fail revalidation next time it's
+ // used.
+ if err := createInRemoteDir(parent, name); err != nil {
+ return err
+ }
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ return nil
}
if child := parent.children[name]; child != nil {
return syserror.EEXIST
@@ -397,6 +422,11 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
}
parent.touchCMtime()
parent.dirents = nil
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
return nil
}
@@ -443,21 +473,61 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
defer mntns.DecRef()
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
+
child, ok := parent.children[name]
if ok && child == nil {
return syserror.ENOENT
}
- // We only need a dentry representing the file at name if it can be a mount
- // point. If child is nil, then it can't be a mount point. If child is
- // non-nil but stale, the actual file can't be a mount point either; we
- // detect this case by just speculatively calling PrepareDeleteDentry and
- // only revalidating the dentry if that fails (indicating that the existing
- // dentry is a mount point).
+
+ sticky := atomic.LoadUint32(&parent.mode)&linux.ModeSticky != 0
+ if sticky {
+ if !ok {
+ // If the sticky bit is set, we need to retrieve the child to determine
+ // whether removing it is allowed.
+ child, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ if err != nil {
+ return err
+ }
+ } else if child != nil && !child.cachedMetadataAuthoritative() {
+ // Make sure the dentry representing the file at name is up to date
+ // before examining its metadata.
+ child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
+ if err != nil {
+ return err
+ }
+ }
+ if err := parent.mayDelete(rp.Credentials(), child); err != nil {
+ return err
+ }
+ }
+
+ // If a child dentry exists, prepare to delete it. This should fail if it is
+ // a mount point. We detect mount points by speculatively calling
+ // PrepareDeleteDentry, which fails if child is a mount point. However, we
+ // may need to revalidate the file in this case to make sure that it has not
+ // been deleted or replaced on the remote fs, in which case the mount point
+ // will have disappeared. If calling PrepareDeleteDentry fails again on the
+ // up-to-date dentry, we can be sure that it is a mount point.
+ //
+ // Also note that if child is nil, then it can't be a mount point.
if child != nil {
+ // Hold child.dirMu so we can check child.children and
+ // child.syntheticChildren. We don't access these fields until a bit later,
+ // but locking child.dirMu after calling vfs.PrepareDeleteDentry() would
+ // create an inconsistent lock ordering between dentry.dirMu and
+ // vfs.Dentry.mu (in the VFS lock order, it would make dentry.dirMu both "a
+ // FilesystemImpl lock" and "a lock acquired by a FilesystemImpl between
+ // PrepareDeleteDentry and CommitDeleteDentry). To avoid this, lock
+ // child.dirMu before calling PrepareDeleteDentry.
child.dirMu.Lock()
defer child.dirMu.Unlock()
if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
- if parent.cachedMetadataAuthoritative() {
+ // We can skip revalidation in several cases:
+ // - We are not in InteropModeShared
+ // - The parent directory is synthetic, in which case the child must also
+ // be synthetic
+ // - We already updated the child during the sticky bit check above
+ if parent.cachedMetadataAuthoritative() || sticky {
return err
}
child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
@@ -518,7 +588,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
if child == nil {
return syserror.ENOENT
}
- } else {
+ } else if child == nil || !child.isSynthetic() {
err = parent.file.unlinkAt(ctx, name, flags)
if err != nil {
if child != nil {
@@ -527,6 +597,18 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
return err
}
}
+
+ // Generate inotify events for rmdir or unlink.
+ if dir {
+ parent.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */)
+ } else {
+ var cw *vfs.Watches
+ if child != nil {
+ cw = &child.watches
+ }
+ vfs.InotifyRemoveChild(cw, &parent.watches, name)
+ }
+
if child != nil {
vfsObj.CommitDeleteDentry(&child.vfsd)
child.setDeleted()
@@ -754,25 +836,27 @@ afterTrailingSymlink:
}
// Determine whether or not we need to create a file.
parent.dirMu.Lock()
- child, err := fs.stepLocked(ctx, rp, parent, &ds)
+ child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
if err == syserror.ENOENT && mayCreate {
if parent.isSynthetic() {
parent.dirMu.Unlock()
return nil, syserror.EPERM
}
- fd, err := parent.createAndOpenChildLocked(ctx, rp, &opts)
+ fd, err := parent.createAndOpenChildLocked(ctx, rp, &opts, &ds)
parent.dirMu.Unlock()
return fd, err
}
+ parent.dirMu.Unlock()
if err != nil {
- parent.dirMu.Unlock()
return nil, err
}
- // Open existing child or follow symlink.
- parent.dirMu.Unlock()
if mustCreate {
return nil, syserror.EEXIST
}
+ if !child.isDir() && rp.MustBeDir() {
+ return nil, syserror.ENOTDIR
+ }
+ // Open existing child or follow symlink.
if child.isSymlink() && rp.ShouldFollowSymlink() {
target, err := child.readlink(ctx, rp.Mount())
if err != nil {
@@ -793,20 +877,32 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
return nil, err
}
+
+ trunc := opts.Flags&linux.O_TRUNC != 0 && d.fileType() == linux.S_IFREG
+ if trunc {
+ // Lock metadataMu *while* we open a regular file with O_TRUNC because
+ // open(2) will change the file size on server.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ }
+
+ var vfd *vfs.FileDescription
+ var err error
mnt := rp.Mount()
switch d.fileType() {
case linux.S_IFREG:
if !d.fs.opts.regularFilesUseSpecialFileFD {
- if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0); err != nil {
+ if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, trunc); err != nil {
return nil, err
}
fd := &regularFileFD{}
+ fd.LockFD.Init(&d.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
AllowDirectIO: true,
}); err != nil {
return nil, err
}
- return &fd.vfsfd, nil
+ vfd = &fd.vfsfd
}
case linux.S_IFDIR:
// Can't open directories with O_CREAT.
@@ -826,6 +922,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
}
}
fd := &directoryFD{}
+ fd.LockFD.Init(&d.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
@@ -842,10 +939,28 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
}
case linux.S_IFIFO:
if d.isSynthetic() {
- return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags)
+ return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags, &d.locks)
+ }
+ }
+
+ if vfd == nil {
+ if vfd, err = d.openSpecialFileLocked(ctx, mnt, opts); err != nil {
+ return nil, err
+ }
+ }
+
+ if trunc {
+ // If no errors occurred so far, update the file size in memory. This
+ // step is required even if !d.cachedMetadataAuthoritative() because
+ // d.mappings has to be updated.
+ // d.metadataMu has already been acquired if trunc == true.
+ d.updateFileSizeLocked(0)
+
+ if d.cachedMetadataAuthoritative() {
+ d.touchCMtimeLocked()
}
}
- return d.openSpecialFileLocked(ctx, mnt, opts)
+ return vfd, err
}
func (d *dentry) connectSocketLocked(ctx context.Context, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
@@ -873,19 +988,37 @@ func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts
if opts.Flags&linux.O_DIRECT != 0 {
return nil, syserror.EINVAL
}
- h, err := openHandle(ctx, d.file, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0)
+ // We assume that the server silently inserts O_NONBLOCK in the open flags
+ // for all named pipes (because all existing gofers do this).
+ //
+ // NOTE(b/133875563): This makes named pipe opens racy, because the
+ // mechanisms for translating nonblocking to blocking opens can only detect
+ // the instantaneous presence of a peer holding the other end of the pipe
+ // open, not whether the pipe was *previously* opened by a peer that has
+ // since closed its end.
+ isBlockingOpenOfNamedPipe := d.fileType() == linux.S_IFIFO && opts.Flags&linux.O_NONBLOCK == 0
+retry:
+ h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
if err != nil {
+ if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && err == syserror.ENXIO {
+ // An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails
+ // with ENXIO if opening the same named pipe with O_WRONLY would
+ // block because there are no readers of the pipe.
+ if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil {
+ return nil, err
+ }
+ goto retry
+ }
return nil, err
}
- seekable := d.fileType() == linux.S_IFREG
- fd := &specialFileFD{
- handle: h,
- seekable: seekable,
+ if isBlockingOpenOfNamedPipe && ats == vfs.MayRead && h.fd >= 0 {
+ if err := blockUntilNonblockingPipeHasWriter(ctx, h.fd); err != nil {
+ h.close(ctx)
+ return nil, err
+ }
}
- if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
- DenyPRead: !seekable,
- DenyPWrite: !seekable,
- }); err != nil {
+ fd, err := newSpecialFileFD(h, mnt, d, &d.locks, opts.Flags)
+ if err != nil {
h.close(ctx)
return nil, err
}
@@ -894,7 +1027,7 @@ func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts
// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked.
// !d.isSynthetic().
-func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) {
if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
}
@@ -919,7 +1052,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
// Filter file creation flags and O_LARGEFILE out; the create RPC already
// has the semantics of O_CREAT|O_EXCL, while some servers will choke on
// O_LARGEFILE.
- createFlags := p9.OpenFlags(opts.Flags &^ (linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_LARGEFILE))
+ createFlags := p9.OpenFlags(opts.Flags &^ (vfs.FileCreationFlags | linux.O_LARGEFILE))
fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
if err != nil {
dirfile.close(ctx)
@@ -947,6 +1080,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
return nil, err
}
+ *ds = appendDentry(*ds, child)
// Incorporate the fid that was opened by lcreate.
useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD
if useRegularFileFD {
@@ -959,10 +1093,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
child.handleWritable = vfs.MayWriteFileWithOpenFlags(opts.Flags)
child.handleMu.Unlock()
}
- // Take a reference on the new dentry to be held by the new file
- // description. (This reference also means that the new dentry is not
- // eligible for caching yet, so we don't need to append to a dentry slice.)
- child.refs = 1
// Insert the dentry into the tree.
d.cacheNewChildLocked(child, name)
if d.cachedMetadataAuthoritative() {
@@ -974,6 +1104,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
var childVFSFD *vfs.FileDescription
if useRegularFileFD {
fd := &regularFileFD{}
+ fd.LockFD.Init(&child.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{
AllowDirectIO: true,
}); err != nil {
@@ -981,26 +1112,21 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
childVFSFD = &fd.vfsfd
} else {
- seekable := child.fileType() == linux.S_IFREG
- fd := &specialFileFD{
- handle: handle{
- file: openFile,
- fd: -1,
- },
- seekable: seekable,
+ h := handle{
+ file: openFile,
+ fd: -1,
}
if fdobj != nil {
- fd.handle.fd = int32(fdobj.Release())
+ h.fd = int32(fdobj.Release())
}
- if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{
- DenyPRead: !seekable,
- DenyPWrite: !seekable,
- }); err != nil {
- fd.handle.close(ctx)
+ fd, err := newSpecialFileFD(h, mnt, child, &d.locks, opts.Flags)
+ if err != nil {
+ h.close(ctx)
return nil, err
}
childVFSFD = &fd.vfsfd
}
+ d.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */)
return childVFSFD, nil
}
@@ -1052,7 +1178,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
return err
}
}
- if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ creds := rp.Credentials()
+ if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
vfsObj := rp.VirtualFilesystem()
@@ -1067,12 +1194,15 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if renamed == nil {
return syserror.ENOENT
}
+ if err := oldParent.mayDelete(creds, renamed); err != nil {
+ return err
+ }
if renamed.isDir() {
if renamed == newParent || genericIsAncestorDentry(renamed, newParent) {
return syserror.EINVAL
}
if oldParent != newParent {
- if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil {
return err
}
}
@@ -1083,7 +1213,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
if oldParent != newParent {
- if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ if err := newParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
newParent.dirMu.Lock()
@@ -1181,10 +1311,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if newParent.cachedMetadataAuthoritative() {
newParent.dirents = nil
newParent.touchCMtime()
- if renamed.isDir() {
+ if renamed.isDir() && (replaced == nil || !replaced.isDir()) {
+ // Increase the link count if we did not replace another directory.
newParent.incLinks()
}
}
+ vfs.InotifyRename(ctx, &renamed.watches, &oldParent.watches, &newParent.watches, oldName, newName, renamed.isDir())
return nil
}
@@ -1197,12 +1329,21 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
return err
}
- return d.setStat(ctx, rp.Credentials(), &opts.Stat, rp.Mount())
+ if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+ return err
+ }
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ d.InotifyWithParent(ev, 0, vfs.InodeEvent)
+ }
+ return nil
}
// StatAt implements vfs.FilesystemImpl.StatAt.
@@ -1326,24 +1467,38 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
return err
}
- return d.setxattr(ctx, rp.Credentials(), &opts)
+ if err := d.setxattr(ctx, rp.Credentials(), &opts); err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+ return err
+ }
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
}
// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckCaching(&ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
return err
}
- return d.removexattr(ctx, rp.Credentials(), name)
+ if err := d.removexattr(ctx, rp.Credentials(), name); err != nil {
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+ return err
+ }
+ fs.renameMuRUnlockAndCheckCaching(&ds)
+
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
}
// PrependPath implements vfs.FilesystemImpl.PrependPath.
@@ -1352,3 +1507,7 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.renameMu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
}
+
+func (fs *filesystem) nextSyntheticIno() inodeNumber {
+ return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask)
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 6295f6b54..e20de84b5 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -45,6 +45,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -84,12 +85,6 @@ type filesystem struct {
// devMinor is the filesystem's minor device number. devMinor is immutable.
devMinor uint32
- // uid and gid are the effective KUID and KGID of the filesystem's creator,
- // and are used as the owner and group for files that don't specify one.
- // uid and gid are immutable.
- uid auth.KUID
- gid auth.KGID
-
// renameMu serves two purposes:
//
// - It synchronizes path resolution with renaming initiated by this
@@ -115,6 +110,26 @@ type filesystem struct {
syncMu sync.Mutex
syncableDentries map[*dentry]struct{}
specialFileFDs map[*specialFileFD]struct{}
+
+ // syntheticSeq stores a counter used to generate unique inodeNumbers for
+ // synthetic dentries.
+ syntheticSeq uint64
+}
+
+// inodeNumber represents the inode number reported in Dirent.Ino. For regular
+// dentries, it comes from QID.Path from the 9P server. Synthetic dentries
+// have their inodeNumber generated sequentially, with the MSB reserved to
+// prevent conflicts with regular dentries.
+type inodeNumber uint64
+
+// Reserve MSB for synthetic mounts.
+const syntheticInoMask = uint64(1) << 63
+
+func inoFromPath(path uint64) inodeNumber {
+ if path&syntheticInoMask != 0 {
+ log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask)
+ }
+ return inodeNumber(path &^ syntheticInoMask)
}
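+
+// For example (hypothetical values): the first synthetic dentry receives
+// inodeNumber 1|1<<63 = 0x8000000000000001, while a server-provided QID.Path
+// of 0x8000000000000005 has its MSB dropped and becomes inodeNumber 5, which
+// can collide with a QID.Path of 5 (hence the warning above).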
type filesystemOptions struct {
@@ -122,6 +137,8 @@ type filesystemOptions struct {
fd int
aname string
interop InteropMode // derived from the "cache" mount option
+ dfltuid auth.KUID
+ dfltgid auth.KGID
msize uint32
version string
@@ -230,6 +247,15 @@ type InternalFilesystemOptions struct {
OpenSocketsByConnecting bool
}
+// _V9FS_DEFUID and _V9FS_DEFGID (from Linux's fs/9p/v9fs.h) are the default
+// UIDs and GIDs used for files that do not provide a specific owner or group
+// respectively.
+const (
+ // uint32(-2) doesn't work in Go.
+ _V9FS_DEFUID = auth.KUID(4294967294)
+ _V9FS_DEFGID = auth.KGID(4294967294)
+)
+
// Name implements vfs.FilesystemType.Name.
func (FilesystemType) Name() string {
return Name
@@ -315,6 +341,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
}
+ // Parse the default UID and GID.
+ fsopts.dfltuid = _V9FS_DEFUID
+ if dfltuidstr, ok := mopts["dfltuid"]; ok {
+ delete(mopts, "dfltuid")
+ dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltuid=%s", dfltuidstr)
+ return nil, nil, syserror.EINVAL
+ }
+ // In Linux, dfltuid is interpreted as a UID and is converted to a KUID
+ // in the caller's user namespace, but goferfs isn't
+ // application-mountable.
+ fsopts.dfltuid = auth.KUID(dfltuid)
+ }
+ fsopts.dfltgid = _V9FS_DEFGID
+ if dfltgidstr, ok := mopts["dfltgid"]; ok {
+ delete(mopts, "dfltgid")
+ dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltgid=%s", dfltgidstr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.dfltgid = auth.KGID(dfltgid)
+ }
+
// Parse the 9P message size.
fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M
if msizestr, ok := mopts["msize"]; ok {
@@ -422,8 +473,6 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
client: client,
clock: ktime.RealtimeClockFromContext(ctx),
devMinor: devMinor,
- uid: creds.EffectiveKUID,
- gid: creds.EffectiveKGID,
syncableDentries: make(map[*dentry]struct{}),
specialFileFDs: make(map[*specialFileFD]struct{}),
}
@@ -553,21 +602,27 @@ type dentry struct {
// returned by the server. dirents is protected by dirMu.
dirents []vfs.Dirent
- // Cached metadata; protected by metadataMu and accessed using atomic
- // memory operations unless otherwise specified.
+ // Cached metadata; protected by metadataMu.
+ // To access:
+ // - In situations where consistency is not required (like stat), these
+ // can be accessed using atomic operations only (without locking).
+ // - Lock metadataMu and access without atomic operations.
+ // To mutate:
+ // - Lock metadataMu and use atomic operations to update because we might
+ // have atomic readers that don't hold the lock.
metadataMu sync.Mutex
- ino uint64 // immutable
- mode uint32 // type is immutable, perms are mutable
- uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
- gid uint32 // auth.KGID, but ...
- blockSize uint32 // 0 if unknown
+ ino inodeNumber // immutable
+ mode uint32 // type is immutable, perms are mutable
+ uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
+ gid uint32 // auth.KGID, but ...
+ blockSize uint32 // 0 if unknown
// Timestamps, all nsecs from the Unix epoch.
atime int64
mtime int64
ctime int64
btime int64
// File size, protected by both metadataMu and dataMu (i.e. both must be
- // locked to mutate it).
+ // locked to mutate it; locking either is sufficient to access it).
size uint64
// nlink counts the number of hard links to this dentry. It's updated and
@@ -634,6 +689,11 @@ type dentry struct {
// If this dentry represents a synthetic named pipe, pipe is the pipe
// endpoint bound to this file.
pipe *pipe.VFSPipe
+
+ locks vfs.FileLocks
+
+ // Inotify watches for this dentry.
+ watches vfs.Watches
}
// dentryAttrMask returns a p9.AttrMask enabling all attributes used by the
@@ -670,10 +730,10 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
d := &dentry{
fs: fs,
file: file,
- ino: qid.Path,
+ ino: inoFromPath(qid.Path),
mode: uint32(attr.Mode),
- uid: uint32(fs.uid),
- gid: uint32(fs.gid),
+ uid: uint32(fs.opts.dfltuid),
+ gid: uint32(fs.opts.dfltgid),
blockSize: usermem.PageSize,
handle: handle{
fd: -1,
@@ -725,8 +785,8 @@ func (d *dentry) cachedMetadataAuthoritative() bool {
// updateFromP9Attrs is called to update d's metadata after an update from the
// remote filesystem.
-func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
- d.metadataMu.Lock()
+// Precondition: d.metadataMu must be locked.
+func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
if mask.Mode {
if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want {
d.metadataMu.Unlock()
@@ -760,11 +820,8 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
atomic.StoreUint32(&d.nlink, uint32(attr.NLink))
}
if mask.Size {
- d.dataMu.Lock()
- atomic.StoreUint64(&d.size, attr.Size)
- d.dataMu.Unlock()
+ d.updateFileSizeLocked(attr.Size)
}
- d.metadataMu.Unlock()
}
// Preconditions: !d.isSynthetic()
@@ -776,6 +833,10 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
file p9file
handleMuRLocked bool
)
+ // d.metadataMu must be locked *before* we getAttr so that we do not end up
+ // updating stale attributes in d.updateFromP9AttrsLocked().
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
d.handleMu.RLock()
if !d.handle.file.isNil() {
file = d.handle.file
@@ -791,7 +852,7 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
if err != nil {
return err
}
- d.updateFromP9Attrs(attrMask, &attr)
+ d.updateFromP9AttrsLocked(attrMask, &attr)
return nil
}
@@ -803,10 +864,18 @@ func (d *dentry) statTo(stat *linux.Statx) {
stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME
stat.Blksize = atomic.LoadUint32(&d.blockSize)
stat.Nlink = atomic.LoadUint32(&d.nlink)
+ if stat.Nlink == 0 {
+ // The remote filesystem doesn't support link count; just make
+ // something up. This is consistent with Linux, where
+ // fs/inode.c:inode_init_always() initializes link count to 1, and
+ // fs/9p/vfs_inode_dotl.c:v9fs_stat2inode_dotl() doesn't touch it if
+ // it's not provided by the remote filesystem.
+ stat.Nlink = 1
+ }
stat.UID = atomic.LoadUint32(&d.uid)
stat.GID = atomic.LoadUint32(&d.gid)
stat.Mode = uint16(atomic.LoadUint32(&d.mode))
- stat.Ino = d.ino
+ stat.Ino = uint64(d.ino)
stat.Size = atomic.LoadUint64(&d.size)
// This is consistent with regularFileFD.Seek(), which treats regular files
// as having no holes.
@@ -819,7 +888,8 @@ func (d *dentry) statTo(stat *linux.Statx) {
stat.DevMinor = d.fs.devMinor
}
-func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error {
+func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions, mnt *vfs.Mount) error {
+ stat := &opts.Stat
if stat.Mask == 0 {
return nil
}
@@ -827,7 +897,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
return syserror.EPERM
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
- if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
if err := mnt.CheckBeginWrite(); err != nil {
@@ -844,14 +914,14 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
// Prepare for truncate.
if stat.Mask&linux.STATX_SIZE != 0 {
- switch d.mode & linux.S_IFMT {
- case linux.S_IFREG:
+ switch mode.FileType() {
+ case linux.ModeRegular:
if !setLocalMtime {
// Truncate updates mtime.
setLocalMtime = true
stat.Mtime.Nsec = linux.UTIME_NOW
}
- case linux.S_IFDIR:
+ case linux.ModeDirectory:
return syserror.EISDIR
default:
return syserror.EINVAL
@@ -860,8 +930,25 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
}
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ // The size needs to be changed even when
+ // !d.cachedMetadataAuthoritative() because d.mappings has to be
+ // updated.
+ d.updateFileSizeLocked(stat.Size)
+ }
if !d.isSynthetic() {
if stat.Mask != 0 {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ // Check whether to allow a truncate request to be made.
+ switch d.mode & linux.S_IFMT {
+ case linux.S_IFREG:
+ // Allow.
+ case linux.S_IFDIR:
+ return syserror.EISDIR
+ default:
+ return syserror.EINVAL
+ }
+ }
if err := d.file.setAttr(ctx, p9.SetAttrMask{
Permissions: stat.Mask&linux.STATX_MODE != 0,
UID: stat.Mask&linux.STATX_UID != 0,
@@ -908,6 +995,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
} else {
atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime))
}
+ // Restore mask bits that we cleared earlier.
+ stat.Mask |= linux.STATX_ATIME
}
if setLocalMtime {
if stat.Mtime.Nsec == linux.UTIME_NOW {
@@ -915,48 +1004,56 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
} else {
atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime))
}
+ // Restore mask bits that we cleared earlier.
+ stat.Mask |= linux.STATX_MTIME
}
atomic.StoreInt64(&d.ctime, now)
- if stat.Mask&linux.STATX_SIZE != 0 {
+ return nil
+}
+
+// Preconditions: d.metadataMu must be locked.
+func (d *dentry) updateFileSizeLocked(newSize uint64) {
+ d.dataMu.Lock()
+ oldSize := d.size
+ atomic.StoreUint64(&d.size, newSize)
+ // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings
+ // below. This allows concurrent calls to Read/Translate/etc. These
+ // functions synchronize with truncation by refusing to use cache
+ // contents beyond the new d.size. (We are still holding d.metadataMu,
+ // so we can't race with Write or another truncate.)
+ d.dataMu.Unlock()
+ if d.size < oldSize {
+ oldpgend, _ := usermem.PageRoundUp(oldSize)
+ newpgend, _ := usermem.PageRoundUp(d.size)
+ if oldpgend != newpgend {
+ d.mapsMu.Lock()
+ d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ d.mapsMu.Unlock()
+ }
+ // We are now guaranteed that there are no translations of
+ // truncated pages, and can remove them from the cache. Since
+ // truncated pages have been removed from the remote file, they
+ // should be dropped without being written back.
d.dataMu.Lock()
- oldSize := d.size
- d.size = stat.Size
- // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings
- // below. This allows concurrent calls to Read/Translate/etc. These
- // functions synchronize with truncation by refusing to use cache
- // contents beyond the new d.size. (We are still holding d.metadataMu,
- // so we can't race with Write or another truncate.)
+ d.cache.Truncate(d.size, d.fs.mfp.MemoryFile())
+ d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend})
d.dataMu.Unlock()
- if d.size < oldSize {
- oldpgend, _ := usermem.PageRoundUp(oldSize)
- newpgend, _ := usermem.PageRoundUp(d.size)
- if oldpgend != newpgend {
- d.mapsMu.Lock()
- d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
- // Compare Linux's mm/truncate.c:truncate_setsize() =>
- // truncate_pagecache() =>
- // mm/memory.c:unmap_mapping_range(evencows=1).
- InvalidatePrivate: true,
- })
- d.mapsMu.Unlock()
- }
- // We are now guaranteed that there are no translations of
- // truncated pages, and can remove them from the cache. Since
- // truncated pages have been removed from the remote file, they
- // should be dropped without being written back.
- d.dataMu.Lock()
- d.cache.Truncate(d.size, d.fs.mfp.MemoryFile())
- d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend})
- d.dataMu.Unlock()
- }
}
- return nil
}
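
As an aside, the invalidation above only fires when the truncated size crosses a page boundary, because mappings are page-granular. A minimal standalone sketch of that check, assuming a 4KiB page size in place of usermem.PageRoundUp (all names here are illustrative):

package main

import "fmt"

const pageSize = 4096 // assumed; the real code uses usermem.PageRoundUp

// pageRoundUp rounds x up to the next multiple of pageSize.
func pageRoundUp(x uint64) uint64 {
	return (x + pageSize - 1) &^ (pageSize - 1)
}

func main() {
	oldSize, newSize := uint64(10000), uint64(4000)
	oldpgend, newpgend := pageRoundUp(oldSize), pageRoundUp(newSize)
	if oldpgend != newpgend {
		// Mappings of the pages in [newpgend, oldpgend) are stale.
		fmt.Printf("invalidate [%d, %d)\n", newpgend, oldpgend) // [4096, 12288)
	}
}
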
func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
}
+func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error {
+ return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&child.uid)))
+}
+
func dentryUIDFromP9UID(uid p9.UID) uint32 {
if !uid.Ok() {
return uint32(auth.OverflowUID)
@@ -1011,6 +1108,37 @@ func (d *dentry) decRefLocked() {
}
}
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {
+ if d.isDir() {
+ events |= linux.IN_ISDIR
+ }
+
+ d.fs.renameMu.RLock()
+ // The ordering below is important: Linux always notifies the parent first.
+ if d.parent != nil {
+ d.parent.watches.Notify(d.name, events, cookie, et, d.isDeleted())
+ }
+ d.watches.Notify("", events, cookie, et, d.isDeleted())
+ d.fs.renameMu.RUnlock()
+}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *dentry) Watches() *vfs.Watches {
+ return &d.watches
+}
+
+// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
+//
+// If no watches are left on this dentry and it has no references, cache it.
+func (d *dentry) OnZeroWatches() {
+ if atomic.LoadInt64(&d.refs) == 0 {
+ d.fs.renameMu.Lock()
+ d.checkCachingLocked()
+ d.fs.renameMu.Unlock()
+ }
+}
+
// checkCachingLocked should be called after d's reference count becomes 0 or it
// becomes disowned.
//
@@ -1042,6 +1170,9 @@ func (d *dentry) checkCachingLocked() {
// Deleted and invalidated dentries with zero references are no longer
// reachable by path resolution and should be dropped immediately.
if d.vfsd.IsDead() {
+ if d.isDeleted() {
+ d.watches.HandleDeletion()
+ }
if d.cached {
d.fs.cachedDentries.Remove(d)
d.fs.cachedDentriesLen--
@@ -1050,6 +1181,14 @@ func (d *dentry) checkCachingLocked() {
d.destroyLocked()
return
}
+ // If d still has inotify watches and it is not deleted or invalidated, we
+ // cannot cache it and allow it to be evicted. Otherwise, we will lose its
+ // watches, even if a new dentry is created for the same file in the future.
+ // Note that the size of d.watches cannot concurrently transition from zero
+ // to non-zero, because adding a watch requires holding a reference on d.
+ if d.watches.Size() > 0 {
+ return
+ }
// If d is already cached, just move it to the front of the LRU.
if d.cached {
d.fs.cachedDentries.Remove(d)
@@ -1155,7 +1294,7 @@ func (d *dentry) setDeleted() {
// We only support xattrs prefixed with "user." (see b/148380782). Currently,
// there is no need to expose any other xattrs through a gofer.
func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
- if d.file.isNil() {
+ if d.file.isNil() || !d.userXattrSupported() {
return nil, nil
}
xattrMap, err := d.file.listXattr(ctx, size)
@@ -1181,6 +1320,9 @@ func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vf
if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
return "", syserror.EOPNOTSUPP
}
+ if !d.userXattrSupported() {
+ return "", syserror.ENODATA
+ }
return d.file.getXattr(ctx, opts.Name, opts.Size)
}
@@ -1194,6 +1336,9 @@ func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vf
if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
return syserror.EOPNOTSUPP
}
+ if !d.userXattrSupported() {
+ return syserror.EPERM
+ }
return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags)
}
@@ -1207,10 +1352,20 @@ func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name
if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
return syserror.EOPNOTSUPP
}
+ if !d.userXattrSupported() {
+ return syserror.EPERM
+ }
return d.file.removeXattr(ctx, name)
}
-// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDirectory().
+// Extended attributes in the user.* namespace are only supported for regular
+// files and directories.
+func (d *dentry) userXattrSupported() bool {
+ filetype := linux.FileMode(atomic.LoadUint32(&d.mode)).FileType()
+ return filetype == linux.ModeRegular || filetype == linux.ModeDirectory
+}
+
+// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDir().
func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error {
// O_TRUNC unconditionally requires us to obtain a new handle (opened with
// O_TRUNC).
@@ -1302,23 +1457,21 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
}
// incLinks increments link count.
-//
-// Preconditions: d.nlink != 0 && d.nlink < math.MaxUint32.
func (d *dentry) incLinks() {
- v := atomic.AddUint32(&d.nlink, 1)
- if v < 2 {
- panic(fmt.Sprintf("dentry.nlink is invalid (was 0 or overflowed): %d", v))
+ if atomic.LoadUint32(&d.nlink) == 0 {
+ // The remote filesystem doesn't support link count.
+ return
}
+ atomic.AddUint32(&d.nlink, 1)
}
// decLinks decrements link count.
-//
-// Preconditions: d.nlink > 1.
func (d *dentry) decLinks() {
- v := atomic.AddUint32(&d.nlink, ^uint32(0))
- if v == 0 {
- panic(fmt.Sprintf("dentry.nlink must be greater than 0: %d", v))
+ if atomic.LoadUint32(&d.nlink) == 0 {
+ // The remote filesystem doesn't support link count.
+ return
}
+ atomic.AddUint32(&d.nlink, ^uint32(0))
}
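
The decrement above leans on unsigned wraparound: atomic.AddUint32 with ^uint32(0) (all bits set, i.e. 2^32-1) is the standard Go idiom for an atomic subtract-by-one. A tiny runnable illustration:

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var nlink uint32 = 2
	// Adding 0xFFFFFFFF wraps around modulo 2^32: equivalent to nlink--.
	atomic.AddUint32(&nlink, ^uint32(0))
	fmt.Println(nlink) // 1
}
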
// fileDescription is embedded by gofer implementations of
@@ -1326,6 +1479,9 @@ func (d *dentry) decLinks() {
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
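+ // lockLogging ensures that the "handled internally" messages in LockBSD
+ // and LockPOSIX below are logged at most once per file description.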
+ lockLogging sync.Once
}
func (fd *fileDescription) filesystem() *filesystem {
@@ -1354,7 +1510,13 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- return fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, fd.vfsfd.Mount())
+ if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts, fd.vfsfd.Mount()); err != nil {
+ return err
+ }
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ fd.dentry().InotifyWithParent(ev, 0, vfs.InodeEvent)
+ }
+ return nil
}
// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
@@ -1369,10 +1531,41 @@ func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOption
// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
- return fd.dentry().setxattr(ctx, auth.CredentialsFromContext(ctx), &opts)
+ d := fd.dentry()
+ if err := d.setxattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil {
+ return err
+ }
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
}
// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
- return fd.dentry().removexattr(ctx, auth.CredentialsFromContext(ctx), name)
+ d := fd.dentry()
+ if err := d.removexattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil {
+ return err
+ }
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ fd.lockLogging.Do(func() {
+ log.Infof("File lock using gofer file handled internally.")
+ })
+ return fd.LockFD.LockBSD(ctx, uid, t, block)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ fd.lockLogging.Do(func() {
+ log.Infof("Range lock using gofer file handled internally.")
+ })
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
}
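
The lockLogging field drives a log-once pattern: sync.Once guarantees the informational message is emitted a single time per file description, no matter how many lock requests arrive. A standalone sketch of the pattern (the message text is illustrative):

package main

import (
	"log"
	"sync"
)

type lockingFD struct {
	lockLogging sync.Once
}

func (f *lockingFD) lock() {
	f.lockLogging.Do(func() {
		log.Print("file lock handled internally") // runs at most once per fd
	})
	// ... acquire the lock ...
}

func main() {
	var f lockingFD
	f.lock()
	f.lock() // no second log line
}
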
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
index 724a3f1f7..8792ca4f2 100644
--- a/pkg/sentry/fsimpl/gofer/handle.go
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -126,11 +126,16 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o
}
func (h *handle) sync(ctx context.Context) error {
+ // Handle the most common case first.
if h.fd >= 0 {
ctx.UninterruptibleSleepStart(false)
err := syscall.Fsync(int(h.fd))
ctx.UninterruptibleSleepFinish(false)
return err
}
+ if h.file.isNil() {
+ // File hasn't been touched; there is nothing to sync.
+ return nil
+ }
return h.file.fsync(ctx)
}
diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
new file mode 100644
index 000000000..7294de7d6
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
@@ -0,0 +1,97 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Global pipe used by blockUntilNonblockingPipeHasWriter since we can't create
+// pipes after sentry initialization due to syscall filters.
+var (
+ tempPipeMu sync.Mutex
+ tempPipeReadFD int
+ tempPipeWriteFD int
+ tempPipeBuf [1]byte
+)
+
+func init() {
+ var pipeFDs [2]int
+ if err := unix.Pipe(pipeFDs[:]); err != nil {
+ panic(fmt.Sprintf("failed to create pipe for gofer.blockUntilNonblockingPipeHasWriter: %v", err))
+ }
+ tempPipeReadFD = pipeFDs[0]
+ tempPipeWriteFD = pipeFDs[1]
+}
+
+func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error {
+ for {
+ ok, err := nonblockingPipeHasWriter(fd)
+ if err != nil {
+ return err
+ }
+ if ok {
+ return nil
+ }
+ if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil {
+ return err
+ }
+ }
+}
+
+func nonblockingPipeHasWriter(fd int32) (bool, error) {
+ tempPipeMu.Lock()
+ defer tempPipeMu.Unlock()
+ // Copy 1 byte from fd into the temporary pipe.
+ n, err := unix.Tee(int(fd), tempPipeWriteFD, 1, unix.SPLICE_F_NONBLOCK)
+ if err == syserror.EAGAIN {
+ // The pipe represented by fd is empty, but has a writer.
+ return true, nil
+ }
+ if err != nil {
+ return false, err
+ }
+ if n == 0 {
+ // The pipe represented by fd is empty and has no writer.
+ return false, nil
+ }
+ // The pipe represented by fd is non-empty, so it either has, or has
+ // previously had, a writer. Remove the byte copied to the temporary pipe
+ // before returning.
+ if n, err := unix.Read(tempPipeReadFD, tempPipeBuf[:]); err != nil || n != 1 {
+ panic(fmt.Sprintf("failed to drain pipe for gofer.blockUntilNonblockingPipeHasWriter: got (%d, %v), wanted (1, nil)", n, err))
+ }
+ return true, nil
+}
+
+func sleepBetweenNamedPipeOpenChecks(ctx context.Context) error {
+ t := time.NewTimer(100 * time.Millisecond)
+ defer t.Stop()
+ cancel := ctx.SleepStart()
+ select {
+ case <-t.C:
+ ctx.SleepFinish(true)
+ return nil
+ case <-cancel:
+ ctx.SleepFinish(false)
+ return syserror.ErrInterrupted
+ }
+}
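
The tee(2) probe above is the subtle part: on an empty pipe, a nonblocking tee fails with EAGAIN only if a writer still exists, returns 0 if the write end has been closed, and copies a byte if data is queued. A minimal Linux-only demonstration of the same probe, assuming golang.org/x/sys/unix (the helper names are illustrative):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

// hasWriter probes the read end of a pipe with tee(2), mirroring
// nonblockingPipeHasWriter: EAGAIN means empty-but-writer-present, n == 0
// means empty-with-no-writer, and n > 0 means data is queued.
func hasWriter(pipeReadFD, scratchWriteFD int) (bool, error) {
	n, err := unix.Tee(pipeReadFD, scratchWriteFD, 1, unix.SPLICE_F_NONBLOCK)
	if err == unix.EAGAIN {
		return true, nil
	}
	if err != nil {
		return false, err
	}
	return n > 0, nil
}

func main() {
	var probe, scratch [2]int
	if err := unix.Pipe(probe[:]); err != nil {
		panic(err)
	}
	if err := unix.Pipe(scratch[:]); err != nil {
		panic(err)
	}
	ok, err := hasWriter(probe[0], scratch[1]) // probe's write end is open
	fmt.Println(ok, err)                       // true <nil>
}
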
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 0d10cf7ac..09f142cfc 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -24,11 +24,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -67,12 +67,46 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error {
return d.handle.file.flush(ctx)
}
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ d := fd.dentry()
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
+ size := offset + length
+
+ // Allocating a smaller size is a noop.
+ if size <= d.size {
+ return nil
+ }
+
+ d.handleMu.Lock()
+ defer d.handleMu.Unlock()
+
+ err := d.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
+ if err != nil {
+ return err
+ }
+ d.dataMu.Lock()
+ atomic.StoreUint64(&d.size, size)
+ d.dataMu.Unlock()
+ if !d.cachedMetadataAuthoritative() {
+ d.touchCMtimeLocked()
+ }
+ return nil
+}
+
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
if offset < 0 {
return 0, syserror.EINVAL
}
- if opts.Flags != 0 {
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
return 0, syserror.EOPNOTSUPP
}
@@ -120,21 +154,53 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, the final offset, and an
+// error. The final offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
- if opts.Flags != 0 {
- return 0, syserror.EOPNOTSUPP
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
+ d := fd.dentry()
+ // If the fd was opened with O_APPEND, make sure the file size is updated.
+ // There is a possible race here if size is modified externally after
+ // metadata cache is updated.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
+ }
+
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ // Set offset to file size if the fd was opened with O_APPEND.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Holding d.metadataMu is sufficient for reading d.size.
+ offset = int64(d.size)
}
limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(limit)
+ n, err := fd.pwriteLocked(ctx, src, offset, opts)
+ return n, offset + n, err
+}
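
The flag check in pwrite uses Go's bit-clear operator: opts.Flags &^ linux.RWF_HIPRI removes the one tolerated bit, and anything left over is an unsupported preadv2/pwritev2 flag. A standalone sketch of the idiom (the flag values match Linux's uapi definitions):

package main

import "fmt"

const (
	rwfHIPRI uint32 = 0x1 // linux.RWF_HIPRI
	rwfDSYNC uint32 = 0x2 // linux.RWF_DSYNC, not yet supported here
)

// flagsSupported reports whether flags contains only RWF_HIPRI: the &^
// operator clears the tolerated bit, and anything remaining is rejected.
func flagsSupported(flags uint32) bool {
	return flags&^rwfHIPRI == 0
}

func main() {
	fmt.Println(flagsSupported(0))                   // true
	fmt.Println(flagsSupported(rwfHIPRI))            // true
	fmt.Println(flagsSupported(rwfHIPRI | rwfDSYNC)) // false -> EOPNOTSUPP
}
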
+// Preconditions: fd.dentry().metadataMu must be locked.
+func (fd *regularFileFD) pwriteLocked(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
d := fd.dentry()
- d.metadataMu.Lock()
- defer d.metadataMu.Unlock()
if d.fs.opts.interop != InteropModeShared {
// Compare Linux's mm/filemap.c:__generic_file_write_iter() =>
// file_update_time(). This is d.touchCMtime(), but without locking
@@ -154,12 +220,12 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
return 0, syserror.EINVAL
}
mr := memmap.MappableRange{pgstart, pgend}
- var freed []platform.FileRange
+ var freed []memmap.FileRange
d.dataMu.Lock()
cseg := d.cache.LowerBoundSegment(mr.Start)
for cseg.Ok() && cseg.Start() < mr.End {
cseg = d.cache.Isolate(cseg, mr)
- freed = append(freed, platform.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()})
+ freed = append(freed, memmap.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()})
cseg = d.cache.Remove(cseg).NextSegment()
}
d.dataMu.Unlock()
@@ -197,8 +263,8 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
fd.mu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.mu.Unlock()
return n, err
}
@@ -489,15 +555,24 @@ func (d *dentry) writeback(ctx context.Context, offset, size int64) error {
func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
fd.mu.Lock()
defer fd.mu.Unlock()
+ newOffset, err := regularFileSeekLocked(ctx, fd.dentry(), fd.off, offset, whence)
+ if err != nil {
+ return 0, err
+ }
+ fd.off = newOffset
+ return newOffset, nil
+}
+
+// regularFileSeekLocked calculates the new offset for a seek operation on a
+// regular file.
+func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int64, whence int32) (int64, error) {
switch whence {
case linux.SEEK_SET:
// Use offset as specified.
case linux.SEEK_CUR:
- offset += fd.off
+ offset += fdOffset
case linux.SEEK_END, linux.SEEK_DATA, linux.SEEK_HOLE:
// Ensure file size is up to date.
- d := fd.dentry()
- if fd.filesystem().opts.interop == InteropModeShared {
+ if !d.cachedMetadataAuthoritative() {
if err := d.updateFromGetattr(ctx); err != nil {
return 0, err
}
@@ -525,7 +600,6 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (
if offset < 0 {
return 0, syserror.EINVAL
}
- fd.off = offset
return offset, nil
}
@@ -536,20 +610,19 @@ func (fd *regularFileFD) Sync(ctx context.Context) error {
func (d *dentry) syncSharedHandle(ctx context.Context) error {
d.handleMu.RLock()
- if !d.handleWritable {
- d.handleMu.RUnlock()
- return nil
- }
- d.dataMu.Lock()
- // Write dirty cached data to the remote file.
- err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt)
- d.dataMu.Unlock()
- if err == nil {
- // Sync the remote file.
- err = d.handle.sync(ctx)
+ defer d.handleMu.RUnlock()
+
+ if d.handleWritable {
+ d.dataMu.Lock()
+ // Write dirty cached data to the remote file.
+ err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt)
+ d.dataMu.Unlock()
+ if err != nil {
+ return err
+ }
}
- d.handleMu.RUnlock()
- return err
+ // Sync the remote file.
+ return d.handle.sync(ctx)
}
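
The rewritten syncSharedHandle makes the ordering explicit: dirty cached pages are flushed first (only when the handle is writable), and the remote file is then synced unconditionally, so read-only handles still propagate fsync to the remote end. A schematic sketch of that control flow, with the gofer-specific pieces reduced to function values:

package main

import "fmt"

// syncSketch mirrors the structure above: write-back happens only for
// writable handles, but the remote sync is no longer gated on writability.
func syncSketch(writable bool, writeBack, remoteSync func() error) error {
	if writable {
		if err := writeBack(); err != nil {
			return err // surface write-back errors before syncing
		}
	}
	return remoteSync()
}

func main() {
	err := syncSketch(false,
		func() error { return fmt.Errorf("unreachable for read-only") },
		func() error { fmt.Println("remote fsync"); return nil })
	fmt.Println(err) // the remote fsync still ran; err is <nil>
}
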
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
@@ -747,7 +820,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange
// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
func (d *dentry) InvalidateUnsavable(ctx context.Context) error {
- // Whether we have a host fd (and consequently what platform.File is
+ // Whether we have a host fd (and consequently what memmap.File is
// mapped) can change across save/restore, so invalidate all translations
// unconditionally.
d.mapsMu.Lock()
@@ -795,8 +868,8 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) {
}
}
-// dentryPlatformFile implements platform.File. It exists solely because dentry
-// cannot implement both vfs.DentryImpl.IncRef and platform.File.IncRef.
+// dentryPlatformFile implements memmap.File. It exists solely because dentry
+// cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef.
//
// dentryPlatformFile is only used when a host FD representing the remote file
// is available (i.e. dentry.handle.fd >= 0), and that FD is used for
@@ -804,7 +877,7 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) {
type dentryPlatformFile struct {
*dentry
- // fdRefs counts references on platform.File offsets. fdRefs is protected
+ // fdRefs counts references on memmap.File offsets. fdRefs is protected
// by dentry.dataMu.
fdRefs fsutil.FrameRefSet
@@ -816,29 +889,29 @@ type dentryPlatformFile struct {
hostFileMapperInitOnce sync.Once
}
-// IncRef implements platform.File.IncRef.
-func (d *dentryPlatformFile) IncRef(fr platform.FileRange) {
+// IncRef implements memmap.File.IncRef.
+func (d *dentryPlatformFile) IncRef(fr memmap.FileRange) {
d.dataMu.Lock()
d.fdRefs.IncRefAndAccount(fr)
d.dataMu.Unlock()
}
-// DecRef implements platform.File.DecRef.
-func (d *dentryPlatformFile) DecRef(fr platform.FileRange) {
+// DecRef implements memmap.File.DecRef.
+func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) {
d.dataMu.Lock()
d.fdRefs.DecRefAndAccount(fr)
d.dataMu.Unlock()
}
-// MapInternal implements platform.File.MapInternal.
-func (d *dentryPlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// MapInternal implements memmap.File.MapInternal.
+func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
d.handleMu.RLock()
bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write)
d.handleMu.RUnlock()
return bs, err
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (d *dentryPlatformFile) FD() int {
d.handleMu.RLock()
fd := d.handle.fd
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index a464e6a94..811528982 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -16,20 +16,22 @@ package gofer
import (
"sync"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
)
-// specialFileFD implements vfs.FileDescriptionImpl for files other than
-// regular files, directories, and symlinks: pipes, sockets, etc. It is also
-// used for regular files when filesystemOptions.specialRegularFiles is in
-// effect. specialFileFD differs from regularFileFD by using per-FD handles
-// instead of shared per-dentry handles, and never buffering I/O.
+// specialFileFD implements vfs.FileDescriptionImpl for pipes, sockets, device
+// special files, and (when filesystemOptions.regularFilesUseSpecialFileFD is
+// in effect) regular files. specialFileFD differs from regularFileFD by using
+// per-FD handles instead of shared per-dentry handles, and never buffering I/O.
type specialFileFD struct {
fileDescription
@@ -40,13 +42,48 @@ type specialFileFD struct {
// file offset is significant, i.e. a regular file. seekable is immutable.
seekable bool
+ // haveQueue is true if this file description represents a file for which
+ // queue may send I/O readiness events. haveQueue is immutable.
+ haveQueue bool
+ queue waiter.Queue
+
// If seekable is true, off is the file offset. off is protected by mu.
mu sync.Mutex
off int64
}
+func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) {
+ ftype := d.fileType()
+ seekable := ftype == linux.S_IFREG
+ haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0
+ fd := &specialFileFD{
+ handle: h,
+ seekable: seekable,
+ haveQueue: haveQueue,
+ }
+ fd.LockFD.Init(locks)
+ if haveQueue {
+ if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil {
+ return nil, err
+ }
+ }
+ if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ DenyPRead: !seekable,
+ DenyPWrite: !seekable,
+ }); err != nil {
+ if haveQueue {
+ fdnotifier.RemoveFD(h.fd)
+ }
+ return nil, err
+ }
+ return fd, nil
+}
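
newSpecialFileFD shows a rollback pattern worth noting: the fd is registered with fdnotifier before vfsfd.Init, so the registration must be undone if Init fails. A generic standalone sketch of register-then-rollback (all names illustrative):

package main

import (
	"errors"
	"fmt"
)

// construct registers early and rolls the registration back if a later
// initialization step fails, so no stale notifier state is left behind.
func construct(register func() error, unregister func(), initRest func() error) error {
	if err := register(); err != nil {
		return err
	}
	if err := initRest(); err != nil {
		unregister() // undo the successful registration
		return err
	}
	return nil
}

func main() {
	err := construct(
		func() error { fmt.Println("AddFD"); return nil },
		func() { fmt.Println("RemoveFD") },
		func() error { return errors.New("Init failed") },
	)
	fmt.Println(err)
}
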
+
// Release implements vfs.FileDescriptionImpl.Release.
func (fd *specialFileFD) Release() {
+ if fd.haveQueue {
+ fdnotifier.RemoveFD(fd.handle.fd)
+ }
fd.handle.close(context.Background())
fs := fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
fs.syncMu.Lock()
@@ -62,12 +99,44 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error {
return fd.handle.file.flush(ctx)
}
+// Readiness implements waiter.Waitable.Readiness.
+func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if fd.haveQueue {
+ return fdnotifier.NonBlockingPoll(fd.handle.fd, mask)
+ }
+ return fd.fileDescription.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ if fd.haveQueue {
+ fd.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(fd.handle.fd)
+ return
+ }
+ fd.fileDescription.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *specialFileFD) EventUnregister(e *waiter.Entry) {
+ if fd.haveQueue {
+ fd.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(fd.handle.fd)
+ return
+ }
+ fd.fileDescription.EventUnregister(e)
+}
+
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
if fd.seekable && offset < 0 {
return 0, syserror.EINVAL
}
- if opts.Flags != 0 {
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
return 0, syserror.EOPNOTSUPP
}
@@ -76,11 +145,14 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
// mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't
// hold here since specialFileFD doesn't client-cache data. Just buffer the
// read instead.
- if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ if d := fd.dentry(); d.cachedMetadataAuthoritative() {
d.touchAtime(fd.vfsfd.Mount())
}
buf := make([]byte, dst.NumBytes())
n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
+ if err == syserror.EAGAIN {
+ err = syserror.ErrWouldBlock
+ }
if n == 0 {
return 0, err
}
@@ -105,32 +177,76 @@ func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, the final offset, and an
+// error. The final offset should be ignored by PWrite.
+func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if fd.seekable && offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
- if opts.Flags != 0 {
- return 0, syserror.EOPNOTSUPP
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
+ d := fd.dentry()
+ // If the regular file fd was opened with O_APPEND, make sure the file size
+ // is updated. There is a possible race here if size is modified externally
+ // after metadata cache is updated.
+ if fd.seekable && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
}
if fd.seekable {
+ // We need to hold the metadataMu *while* writing to a regular file.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
+ // Set offset to file size if the regular file was opened with O_APPEND.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Holding d.metadataMu is sufficient for reading d.size.
+ offset = int64(d.size)
+ }
limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(limit)
}
// Do a buffered write. See rationale in PRead.
- if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ if d.cachedMetadataAuthoritative() {
d.touchCMtime()
}
buf := make([]byte, src.NumBytes())
// Don't do partial writes if we get a partial read from src.
if _, err := src.CopyIn(ctx, buf); err != nil {
- return 0, err
+ return 0, offset, err
}
n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
- return int64(n), err
+ if err == syserror.EAGAIN {
+ err = syserror.ErrWouldBlock
+ }
+ finalOff = offset
+ // Update file size for regular files.
+ if fd.seekable {
+ finalOff += int64(n)
+ // d.metadataMu is already locked at this point.
+ if uint64(finalOff) > d.size {
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ atomic.StoreUint64(&d.size, uint64(finalOff))
+ }
+ }
+ return int64(n), finalOff, err
}
// Write implements vfs.FileDescriptionImpl.Write.
@@ -140,8 +256,8 @@ func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts
}
fd.mu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.mu.Unlock()
return n, err
}
@@ -153,27 +269,15 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) (
}
fd.mu.Lock()
defer fd.mu.Unlock()
- switch whence {
- case linux.SEEK_SET:
- // Use offset as given.
- case linux.SEEK_CUR:
- offset += fd.off
- default:
- // SEEK_END, SEEK_DATA, and SEEK_HOLE aren't supported since it's not
- // clear that file size is even meaningful for these files.
- return 0, syserror.EINVAL
- }
- if offset < 0 {
- return 0, syserror.EINVAL
+ newOffset, err := regularFileSeekLocked(ctx, fd.dentry(), fd.off, offset, whence)
+ if err != nil {
+ return 0, err
}
- fd.off = offset
- return offset, nil
+ fd.off = newOffset
+ return newOffset, nil
}
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *specialFileFD) Sync(ctx context.Context) error {
- if !fd.vfsfd.IsWritable() {
- return nil
- }
- return fd.handle.sync(ctx)
+ return fd.dentry().syncSharedHandle(ctx)
}
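
Both PRead and pwrite above translate a host-side EAGAIN into the sentry's ErrWouldBlock so that generic vfs code can block the task and retry. A standalone sketch of that translation, using syscall.EAGAIN and a stand-in sentinel for syserror.ErrWouldBlock:

package main

import (
	"errors"
	"fmt"
	"syscall"
)

var errWouldBlock = errors.New("would block") // stand-in for syserror.ErrWouldBlock

// translate converts the host errno into the internal sentinel that the
// blocking machinery understands.
func translate(err error) error {
	if err == syscall.EAGAIN {
		return errWouldBlock
	}
	return err
}

func main() {
	fmt.Println(translate(syscall.EAGAIN)) // would block
	fmt.Println(translate(syscall.EINVAL)) // invalid argument
}
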
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
index 2608e7e1d..0eef4e16e 100644
--- a/pkg/sentry/fsimpl/gofer/time.go
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -36,8 +36,11 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp {
}
}
-// Preconditions: fs.interop != InteropModeShared.
+// Preconditions: d.cachedMetadataAuthoritative() == true.
func (d *dentry) touchAtime(mnt *vfs.Mount) {
+ if mnt.Flags.NoATime {
+ return
+ }
if err := mnt.CheckBeginWrite(); err != nil {
return
}
@@ -48,8 +51,8 @@ func (d *dentry) touchAtime(mnt *vfs.Mount) {
mnt.EndWrite()
}
-// Preconditions: fs.interop != InteropModeShared. The caller has successfully
-// called vfs.Mount.CheckBeginWrite().
+// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has
+// successfully called vfs.Mount.CheckBeginWrite().
func (d *dentry) touchCtime() {
now := d.fs.clock.Now().Nanoseconds()
d.metadataMu.Lock()
@@ -57,8 +60,8 @@ func (d *dentry) touchCtime() {
d.metadataMu.Unlock()
}
-// Preconditions: fs.interop != InteropModeShared. The caller has successfully
-// called vfs.Mount.CheckBeginWrite().
+// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has
+// successfully called vfs.Mount.CheckBeginWrite().
func (d *dentry) touchCMtime() {
now := d.fs.clock.Now().Nanoseconds()
d.metadataMu.Lock()
@@ -67,6 +70,8 @@ func (d *dentry) touchCMtime() {
d.metadataMu.Unlock()
}
+// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has
+// locked d.metadataMu.
func (d *dentry) touchCMtimeLocked() {
now := d.fs.clock.Now().Nanoseconds()
atomic.StoreInt64(&d.mtime, now)
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index ca0fe6d2b..bd701bbc7 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -22,17 +22,18 @@ go_library(
"//pkg/context",
"//pkg/fdnotifier",
"//pkg/fspath",
+ "//pkg/iovec",
"//pkg/log",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/hostfd",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
- "//pkg/sentry/platform",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 18b127521..c894f2ca0 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -90,7 +91,9 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
isTTY: opts.IsTTY,
wouldBlock: wouldBlock(uint32(fileType)),
seekable: seekable,
- canMap: canMap(uint32(fileType)),
+ // NOTE(b/38213152): Technically, some obscure char devices can be memory
+ // mapped, but we only allow regular files.
+ canMap: fileType == linux.S_IFREG,
}
i.pf.inode = i
@@ -182,6 +185,8 @@ type inode struct {
kernfs.InodeNotDirectory
kernfs.InodeNotSymlink
+ locks vfs.FileLocks
+
// When the reference count reaches zero, the host fd is closed.
refs.AtomicRefCount
@@ -254,7 +259,7 @@ func (i *inode) Mode() linux.FileMode {
}
// Stat implements kernfs.Inode.
-func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
if opts.Mask&linux.STATX__RESERVED != 0 {
return linux.Statx{}, syserror.EINVAL
}
@@ -368,7 +373,7 @@ func (i *inode) fstat(fs *filesystem) (linux.Statx, error) {
// SetStat implements kernfs.Inode.
func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
- s := opts.Stat
+ s := &opts.Stat
m := s.Mask
if m == 0 {
@@ -381,7 +386,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
if err := syscall.Fstat(i.hostFD, &hostStat); err != nil {
return err
}
- if err := vfs.CheckSetStat(ctx, creds, &s, linux.FileMode(hostStat.Mode&linux.PermissionsMask), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, &opts, linux.FileMode(hostStat.Mode), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil {
return err
}
@@ -391,6 +396,9 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
}
}
if m&linux.STATX_SIZE != 0 {
+ if hostStat.Mode&linux.S_IFMT != linux.S_IFREG {
+ return syserror.EINVAL
+ }
if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
return err
}
@@ -454,10 +462,12 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u
fileType := s.Mode & linux.FileTypeMask
// Constrain flags to a subset we can handle.
- // TODO(gvisor.dev/issue/1672): implement behavior corresponding to these allowed flags.
- flags &= syscall.O_ACCMODE | syscall.O_DIRECT | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND
+ //
+ // TODO(gvisor.dev/issue/2601): Support O_NONBLOCK by adding RWF_NOWAIT to pread/pwrite calls.
+ flags &= syscall.O_ACCMODE | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND
- if fileType == syscall.S_IFSOCK {
+ switch fileType {
+ case syscall.S_IFSOCK:
if i.isTTY {
log.Warningf("cannot use host socket fd %d as TTY", i.hostFD)
return nil, syserror.ENOTTY
@@ -468,35 +478,41 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u
return nil, err
}
// Currently, we only allow Unix sockets to be imported.
- return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d)
- }
+ return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d, &i.locks)
- // TODO(gvisor.dev/issue/1672): Whitelist specific file types here, so that
- // we don't allow importing arbitrary file types without proper support.
- if i.isTTY {
- fd := &TTYFileDescription{
- fileDescription: fileDescription{inode: i},
- termios: linux.DefaultSlaveTermios,
+ case syscall.S_IFREG, syscall.S_IFIFO, syscall.S_IFCHR:
+ if i.isTTY {
+ fd := &TTYFileDescription{
+ fileDescription: fileDescription{inode: i},
+ termios: linux.DefaultSlaveTermios,
+ }
+ fd.LockFD.Init(&i.locks)
+ vfsfd := &fd.vfsfd
+ if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
}
+
+ fd := &fileDescription{inode: i}
+ fd.LockFD.Init(&i.locks)
vfsfd := &fd.vfsfd
if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return vfsfd, nil
- }
- fd := &fileDescription{inode: i}
- vfsfd := &fd.vfsfd
- if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
- return nil, err
+ default:
+ log.Warningf("cannot import host fd %d with file type %o", i.hostFD, fileType)
+ return nil, syserror.EPERM
}
- return vfsfd, nil
}
// fileDescription is embedded by host fd implementations of FileDescriptionImpl.
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
// inode is vfsfd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode), but
// cached to reduce indirections and casting. fileDescription does not hold
@@ -521,8 +537,8 @@ func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
}
// Stat implements vfs.FileDescriptionImpl.
-func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) {
- return f.inode.Stat(f.vfsfd.Mount().Filesystem(), opts)
+func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ return f.inode.Stat(ctx, f.vfsfd.Mount().Filesystem(), opts)
}
// Release implements vfs.FileDescriptionImpl.
@@ -530,6 +546,16 @@ func (f *fileDescription) Release() {
// noop
}
+// Allocate implements vfs.FileDescriptionImpl.
+func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ if !f.inode.seekable {
+ return syserror.ESPIPE
+ }
+
+ // TODO(gvisor.dev/issue/2923): Implement Allocate for non-pipe hostfds.
+ return syserror.EOPNOTSUPP
+}
+
// PRead implements FileDescriptionImpl.
func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
i := f.inode
@@ -556,7 +582,7 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts
}
return n, err
}
- // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so.
+
f.offsetMu.Lock()
n, err := readFromHostFD(ctx, i.hostFD, dst, f.offset, opts.Flags)
f.offset += n
@@ -565,8 +591,10 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts
}
func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) {
- // TODO(gvisor.dev/issue/1672): Support select preadv2 flags.
- if flags != 0 {
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if flags&^linux.RWF_HIPRI != 0 {
return 0, syserror.EOPNOTSUPP
}
reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags)
@@ -577,41 +605,58 @@ func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, off
// PWrite implements FileDescriptionImpl.
func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
- i := f.inode
- if !i.seekable {
+ if !f.inode.seekable {
return 0, syserror.ESPIPE
}
- return writeToHostFD(ctx, i.hostFD, src, offset, opts.Flags)
+ return f.writeToHostFD(ctx, src, offset, opts.Flags)
}
// Write implements FileDescriptionImpl.
func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
i := f.inode
if !i.seekable {
- n, err := writeToHostFD(ctx, i.hostFD, src, -1, opts.Flags)
+ n, err := f.writeToHostFD(ctx, src, -1, opts.Flags)
if isBlockError(err) {
err = syserror.ErrWouldBlock
}
return n, err
}
- // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so.
- // TODO(gvisor.dev/issue/1672): Write to end of file and update offset if O_APPEND is set on this file.
+
f.offsetMu.Lock()
- n, err := writeToHostFD(ctx, i.hostFD, src, f.offset, opts.Flags)
+ // NOTE(gvisor.dev/issue/2983): O_APPEND may cause file data corruption if
+ // another process modifies the host file between retrieving the file size
+ // and writing to the host fd. This is an unavoidable race condition because
+ // we cannot enforce synchronization on the host.
+ if f.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ f.offsetMu.Unlock()
+ return 0, err
+ }
+ f.offset = s.Size
+ }
+ n, err := f.writeToHostFD(ctx, src, f.offset, opts.Flags)
f.offset += n
f.offsetMu.Unlock()
return n, err
}
-func writeToHostFD(ctx context.Context, hostFD int, src usermem.IOSequence, offset int64, flags uint32) (int64, error) {
- // TODO(gvisor.dev/issue/1672): Support select pwritev2 flags.
+func (f *fileDescription) writeToHostFD(ctx context.Context, src usermem.IOSequence, offset int64, flags uint32) (int64, error) {
+ hostFD := f.inode.hostFD
+ // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
if flags != 0 {
return 0, syserror.EOPNOTSUPP
}
writer := hostfd.GetReadWriterAt(int32(hostFD), offset, flags)
n, err := src.CopyInTo(ctx, writer)
hostfd.PutReadWriterAt(writer)
+ // NOTE(gvisor.dev/issue/2979): We always sync everything, even for O_DSYNC.
+ if n > 0 && f.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 {
+ if syncErr := unix.Fsync(hostFD); syncErr != nil {
+ return int64(n), syncErr
+ }
+ }
return int64(n), err
}
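
The O_DSYNC/O_SYNC handling added to writeToHostFD follows write-then-fsync: if any bytes landed and a sync flag is set, fsync before returning (and, per the NOTE, O_DSYNC currently gets full O_SYNC treatment). A minimal standalone version of the same pattern against a temp file, assuming golang.org/x/sys/unix:

package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

// writeMaybeSync writes buf and fsyncs the fd when O_DSYNC or O_SYNC was
// requested, reporting a sync error in preference to the write error.
func writeMaybeSync(fd int, buf []byte, statusFlags int) (int, error) {
	n, err := unix.Write(fd, buf)
	if n > 0 && statusFlags&(unix.O_DSYNC|unix.O_SYNC) != 0 {
		if syncErr := unix.Fsync(fd); syncErr != nil {
			return n, syncErr
		}
	}
	return n, err
}

func main() {
	f, err := os.CreateTemp("", "sync-demo")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())
	n, err := writeMaybeSync(int(f.Fd()), []byte("hello"), unix.O_SYNC)
	fmt.Println(n, err) // 5 <nil>
}
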
@@ -682,7 +727,7 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i
// Sync implements FileDescriptionImpl.
func (f *fileDescription) Sync(context.Context) error {
- // TODO(gvisor.dev/issue/1672): Currently we do not support the SyncData optimization, so we always sync everything.
+ // TODO(gvisor.dev/issue/1897): Currently, we always sync everything.
return unix.Fsync(f.inode.hostFD)
}
@@ -712,3 +757,13 @@ func (f *fileDescription) EventUnregister(e *waiter.Entry) {
func (f *fileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
return fdnotifier.NonBlockingPoll(int32(f.inode.hostFD), mask)
}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (f *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return f.Locks().LockPOSIX(ctx, &f.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (f *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return f.Locks().UnlockPOSIX(ctx, &f.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go
index 8545a82f0..65d3af38c 100644
--- a/pkg/sentry/fsimpl/host/mmap.go
+++ b/pkg/sentry/fsimpl/host/mmap.go
@@ -19,13 +19,12 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
-// inodePlatformFile implements platform.File. It exists solely because inode
-// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef.
+// inodePlatformFile implements memmap.File. It exists solely because inode
+// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef.
//
// inodePlatformFile should only be used if inode.canMap is true.
type inodePlatformFile struct {
@@ -34,7 +33,7 @@ type inodePlatformFile struct {
// fdRefsMu protects fdRefs.
fdRefsMu sync.Mutex
- // fdRefs counts references on platform.File offsets. It is used solely for
+ // fdRefs counts references on memmap.File offsets. It is used solely for
// memory accounting.
fdRefs fsutil.FrameRefSet
@@ -45,32 +44,32 @@ type inodePlatformFile struct {
fileMapperInitOnce sync.Once
}
-// IncRef implements platform.File.IncRef.
+// IncRef implements memmap.File.IncRef.
//
// Precondition: i.inode.canMap must be true.
-func (i *inodePlatformFile) IncRef(fr platform.FileRange) {
+func (i *inodePlatformFile) IncRef(fr memmap.FileRange) {
i.fdRefsMu.Lock()
i.fdRefs.IncRefAndAccount(fr)
i.fdRefsMu.Unlock()
}
-// DecRef implements platform.File.DecRef.
+// DecRef implements memmap.File.DecRef.
//
// Precondition: i.inode.canMap must be true.
-func (i *inodePlatformFile) DecRef(fr platform.FileRange) {
+func (i *inodePlatformFile) DecRef(fr memmap.FileRange) {
i.fdRefsMu.Lock()
i.fdRefs.DecRefAndAccount(fr)
i.fdRefsMu.Unlock()
}
-// MapInternal implements platform.File.MapInternal.
+// MapInternal implements memmap.File.MapInternal.
//
// Precondition: i.inode.canMap must be true.
-func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return i.fileMapper.MapInternal(fr, i.hostFD, at.Write)
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (i *inodePlatformFile) FD() int {
return i.hostFD
}
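
Collecting the methods renamed in this file and in regular_file.go gives the shape of the memmap.File contract that inodePlatformFile and dentryPlatformFile both satisfy: range-based reference counting, internal mappings, and a backing host fd. A simplified sketch, with FileRange, AccessType, and BlockSeq reduced to placeholders (the real definitions live under pkg/sentry/memmap, pkg/usermem, and pkg/safemem):

package main

import "fmt"

type fileRange struct{ start, end uint64 } // stand-in for memmap.FileRange
type accessType struct{ read, write bool } // stand-in for usermem.AccessType
type blockSeq struct{}                     // stand-in for safemem.BlockSeq

// file sketches memmap.File as exercised by this diff.
type file interface {
	IncRef(fr fileRange)
	DecRef(fr fileRange)
	MapInternal(fr fileRange, at accessType) (blockSeq, error)
	FD() int
}

type hostBacked struct{ hostFD int }

func (h *hostBacked) IncRef(fileRange) {}
func (h *hostBacked) DecRef(fileRange) {}
func (h *hostBacked) MapInternal(fileRange, accessType) (blockSeq, error) {
	return blockSeq{}, nil
}
func (h *hostBacked) FD() int { return h.hostFD }

func main() {
	var f file = &hostBacked{hostFD: 3}
	fmt.Println(f.FD())
}
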
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index 38f1fbfba..fd16bd92d 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -47,11 +47,6 @@ func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transpor
return ep, nil
}
-// maxSendBufferSize is the maximum host send buffer size allowed for endpoint.
-//
-// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max).
-const maxSendBufferSize = 8 << 20
-
// ConnectedEndpoint is an implementation of transport.ConnectedEndpoint and
// transport.Receiver. It is backed by a host fd that was imported at sentry
// startup. This fd is shared with a hostfs inode, which retains ownership of
@@ -114,10 +109,6 @@ func (c *ConnectedEndpoint) init() *syserr.Error {
if err != nil {
return syserr.FromError(err)
}
- if sndbuf > maxSendBufferSize {
- log.Warningf("Socket send buffer too large: %d", sndbuf)
- return syserr.ErrInvalidEndpointState
- }
c.stype = linux.SockType(stype)
c.sndbuf = int64(sndbuf)
diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go
index 584c247d2..fc0d5fd38 100644
--- a/pkg/sentry/fsimpl/host/socket_iovec.go
+++ b/pkg/sentry/fsimpl/host/socket_iovec.go
@@ -17,13 +17,10 @@ package host
import (
"syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/syserror"
)
-// maxIovs is the maximum number of iovecs to pass to the host.
-var maxIovs = linux.UIO_MAXIOV
-
// copyToMulti copies as many bytes from src to dst as possible.
func copyToMulti(dst [][]byte, src []byte) {
for _, d := range dst {
@@ -74,7 +71,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec
}
}
- if iovsRequired > maxIovs {
+ if iovsRequired > iovec.MaxIovs {
// The kernel will reject our call if we pass this many iovs.
// Use a single intermediate buffer instead.
b := make([]byte, stopLen)
diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go
index 68af6e5af..4ee9270cc 100644
--- a/pkg/sentry/fsimpl/host/tty.go
+++ b/pkg/sentry/fsimpl/host/tty.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -325,9 +326,9 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal)
task := kernel.TaskFromContext(ctx)
if task == nil {
// No task? Linux does not have an analog for this case, but
- // tty_check_change is more of a blacklist of cases than a
- // whitelist, and is surprisingly permissive. Allowing the
- // change seems most appropriate.
+ // tty_check_change only blocks specific cases and is
+ // surprisingly permissive. Allowing the change seems
+ // appropriate.
return nil
}
@@ -377,3 +378,13 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal)
_ = pg.SendSignal(kernel.SignalInfoPriv(sig))
return kernel.ERESTARTSYS
}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (t *TTYFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, typ fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return t.Locks().LockPOSIX(ctx, &t.vfsfd, uid, typ, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (t *TTYFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return t.Locks().UnlockPOSIX(ctx, &t.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go
index 2bc757b1a..412bdb2eb 100644
--- a/pkg/sentry/fsimpl/host/util.go
+++ b/pkg/sentry/fsimpl/host/util.go
@@ -49,16 +49,6 @@ func wouldBlock(fileType uint32) bool {
return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK
}
-// canMap returns true if a file with fileType is allowed to be memory mapped.
-// This is ported over from VFS1, but it's probably not the best way for us
-// to check if a file can be memory mapped.
-func canMap(fileType uint32) bool {
- // TODO(gvisor.dev/issue/1672): Also allow "special files" to be mapped (see fs/host:canMap()).
- //
- // TODO(b/38213152): Some obscure character devices can be mapped.
- return fileType == syscall.S_IFREG
-}
-
// isBlockError checks if an error is EAGAIN or EWOULDBLOCK.
// If so, it can be transformed into syserror.ErrWouldBlock.
func isBlockError(err error) bool {
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index ef34cb28a..3835557fe 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -45,6 +45,7 @@ go_library(
"//pkg/fspath",
"//pkg/log",
"//pkg/refs",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
@@ -69,6 +70,6 @@ go_test(
"//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/usermem",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 1568a9d49..c6c4472e7 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -38,7 +39,8 @@ type DynamicBytesFile struct {
InodeNotDirectory
InodeNotSymlink
- data vfs.DynamicBytesSource
+ locks vfs.FileLocks
+ data vfs.DynamicBytesSource
}
var _ Inode = (*DynamicBytesFile)(nil)
@@ -55,7 +57,7 @@ func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint
// Open implements Inode.Open.
func (f *DynamicBytesFile) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &DynamicBytesFD{}
- if err := fd.Init(rp.Mount(), vfsd, f.data, opts.Flags); err != nil {
+ if err := fd.Init(rp.Mount(), vfsd, f.data, &f.locks, opts.Flags); err != nil {
return nil, err
}
return &fd.vfsfd, nil
@@ -77,13 +79,15 @@ func (*DynamicBytesFile) SetStat(context.Context, *vfs.Filesystem, *auth.Credent
type DynamicBytesFD struct {
vfs.FileDescriptionDefaultImpl
vfs.DynamicBytesFileDescriptionImpl
+ vfs.LockFD
vfsfd vfs.FileDescription
inode Inode
}
// Init initializes a DynamicBytesFD.
-func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, flags uint32) error {
+func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *vfs.FileLocks, flags uint32) error {
+ fd.LockFD.Init(locks)
if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
return err
}
@@ -97,12 +101,12 @@ func (fd *DynamicBytesFD) Seek(ctx context.Context, offset int64, whence int32)
return fd.DynamicBytesFileDescriptionImpl.Seek(ctx, offset, whence)
}
-// Read implmenets vfs.FileDescriptionImpl.Read.
+// Read implements vfs.FileDescriptionImpl.Read.
func (fd *DynamicBytesFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
return fd.DynamicBytesFileDescriptionImpl.Read(ctx, dst, opts)
}
-// PRead implmenets vfs.FileDescriptionImpl.PRead.
+// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *DynamicBytesFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
return fd.DynamicBytesFileDescriptionImpl.PRead(ctx, dst, offset, opts)
}
@@ -123,7 +127,7 @@ func (fd *DynamicBytesFD) Release() {}
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
- return fd.inode.Stat(fs, opts)
+ return fd.inode.Stat(ctx, fs, opts)
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
@@ -131,3 +135,13 @@ func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error {
// DynamicBytesFiles are immutable.
return syserror.EPERM
}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *DynamicBytesFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *DynamicBytesFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 8284e76a7..1d37ccb98 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -42,6 +43,7 @@ import (
type GenericDirectoryFD struct {
vfs.FileDescriptionDefaultImpl
vfs.DirectoryFileDescriptionDefaultImpl
+ vfs.LockFD
vfsfd vfs.FileDescription
children *OrderedChildren
@@ -55,9 +57,9 @@ type GenericDirectoryFD struct {
// NewGenericDirectoryFD creates a new GenericDirectoryFD and returns its
// dentry.
-func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) {
+func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) {
fd := &GenericDirectoryFD{}
- if err := fd.Init(children, opts); err != nil {
+ if err := fd.Init(children, locks, opts); err != nil {
return nil, err
}
if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
@@ -69,11 +71,12 @@ func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildre
// Init initializes a GenericDirectoryFD. Use it when overriding
// GenericDirectoryFD. Caller must call fd.VFSFileDescription.Init() with the
// correct implementation.
-func (fd *GenericDirectoryFD) Init(children *OrderedChildren, opts *vfs.OpenOptions) error {
+func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) error {
if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 {
// Can't open directories for writing.
return syserror.EISDIR
}
+ fd.LockFD.Init(locks)
fd.children = children
return nil
}
@@ -109,7 +112,7 @@ func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence
return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts)
}
-// Release implements vfs.FileDecriptionImpl.Release.
+// Release implements vfs.FileDescriptionImpl.Release.
func (fd *GenericDirectoryFD) Release() {}
func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem {
@@ -120,7 +123,7 @@ func (fd *GenericDirectoryFD) inode() Inode {
return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
}
-// IterDirents implements vfs.FileDecriptionImpl.IterDirents. IterDirents holds
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds
// o.mu when calling cb.
func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
fd.mu.Lock()
@@ -129,7 +132,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
opts := vfs.StatOptions{Mask: linux.STATX_INO}
// Handle ".".
if fd.off == 0 {
- stat, err := fd.inode().Stat(fd.filesystem(), opts)
+ stat, err := fd.inode().Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
}
@@ -149,7 +152,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
if fd.off == 1 {
vfsd := fd.vfsfd.VirtualDentry().Dentry()
parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode
- stat, err := parentInode.Stat(fd.filesystem(), opts)
+ stat, err := parentInode.Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
}
@@ -173,7 +176,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
childIdx := fd.off - 2
for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() {
inode := it.Dentry.Impl().(*Dentry).inode
- stat, err := inode.Stat(fd.filesystem(), opts)
+ stat, err := inode.Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
}
@@ -195,7 +198,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
return err
}
-// Seek implements vfs.FileDecriptionImpl.Seek.
+// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
fd.mu.Lock()
defer fd.mu.Unlock()
@@ -223,7 +226,7 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int
func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := fd.filesystem()
inode := fd.inode()
- return inode.Stat(fs, opts)
+ return inode.Stat(ctx, fs, opts)
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
@@ -232,3 +235,18 @@ func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptio
inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
return inode.SetStat(ctx, fd.filesystem(), creds, opts)
}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *GenericDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return fd.DirectoryFileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *GenericDirectoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *GenericDirectoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
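
Per the Init contract documented above, a type that overrides GenericDirectoryFD initializes the generic machinery first (which now also wires up LockFD) and then initializes the vfs.FileDescription with itself as the implementation. A hedged sketch; dirFD and newDirFD are illustrative names:

type dirFD struct {
	kernfs.GenericDirectoryFD
	// Embedder-specific state would go here.
}

func newDirFD(m *vfs.Mount, d *vfs.Dentry, children *kernfs.OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) (*dirFD, error) {
	fd := &dirFD{}
	// Initializes directory iteration state and fd.LockFD.
	if err := fd.Init(children, locks, opts); err != nil {
		return nil, err
	}
	// Register the overriding type as the FileDescription implementation.
	if err := fd.VFSFileDescription().Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
		return nil, err
	}
	return fd, nil
}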
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 4a12ae245..61a36cff9 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -35,7 +35,7 @@ import (
// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done().
//
// Postcondition: Caller must call fs.processDeferredDecRefs*.
-func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) (*vfs.Dentry, error) {
+func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, mayFollowSymlinks bool) (*vfs.Dentry, error) {
d := vfsd.Impl().(*Dentry)
if !d.isDir() {
return nil, syserror.ENOTDIR
@@ -81,7 +81,7 @@ afterSymlink:
return nil, err
}
// Resolve any symlink at current path component.
- if rp.ShouldFollowSymlink() && next.isSymlink() {
+ if mayFollowSymlinks && rp.ShouldFollowSymlink() && next.isSymlink() {
targetVD, targetPathname, err := next.inode.Getlink(ctx, rp.Mount())
if err != nil {
return nil, err
@@ -152,7 +152,7 @@ func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingP
vfsd := rp.Start()
for !rp.Done() {
var err error
- vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd)
+ vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */)
if err != nil {
return nil, nil, err
}
@@ -178,7 +178,7 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
vfsd := rp.Start()
for !rp.Final() {
var err error
- vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd)
+ vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */)
if err != nil {
return nil, nil, err
}
@@ -449,7 +449,7 @@ afterTrailingSymlink:
return nil, syserror.ENAMETOOLONG
}
// Determine whether or not we need to create a file.
- childVFSD, err := fs.stepExistingLocked(ctx, rp, parentVFSD)
+ childVFSD, err := fs.stepExistingLocked(ctx, rp, parentVFSD, false /* mayFollowSymlinks */)
if err == syserror.ENOENT {
// Already checked for searchability above; now check for writability.
if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
@@ -684,7 +684,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return linux.Statx{}, err
}
- return inode.Stat(fs.VFSFilesystem(), opts)
+ return inode.Stat(ctx, fs.VFSFilesystem(), opts)
}
// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 982daa2e6..579e627f0 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -243,7 +243,7 @@ func (a *InodeAttrs) Mode() linux.FileMode {
// Stat partially implements Inode.Stat. Note that this function doesn't provide
// all the stat fields, and the embedder should consider extending the result
// with filesystem-specific fields.
-func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) {
+func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) {
var stat linux.Statx
stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK
stat.DevMajor = a.devMajor
@@ -267,7 +267,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
return syserror.EPERM
}
- if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
return err
}
@@ -293,6 +293,8 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
// inode numbers are immutable after node creation.
// TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
+ // Also, STATX_SIZE will need some special handling, because read-only static
+ // files should return EIO for truncate operations.
return nil
}
@@ -469,6 +471,8 @@ func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.De
if err := o.checkExistingLocked(name, child); err != nil {
return err
}
+
+ // TODO(gvisor.dev/issue/3027): Check sticky bit before removing.
o.removeLocked(name)
return nil
}
@@ -516,6 +520,8 @@ func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, c
if err := o.checkExistingLocked(oldname, child); err != nil {
return nil, err
}
+
+ // TODO(gvisor.dev/issue/3027): Check sticky bit before removing.
replaced := dst.replaceChildLocked(newname, child)
return replaced, nil
}
@@ -555,6 +561,8 @@ type StaticDirectory struct {
InodeAttrs
InodeNoDynamicLookup
OrderedChildren
+
+ locks vfs.FileLocks
}
var _ Inode = (*StaticDirectory)(nil)
@@ -584,7 +592,7 @@ func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint3
// Open implements kernfs.Inode.
func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &opts)
+ fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &s.locks, &opts)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index a83151ad3..46f207664 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -225,9 +225,24 @@ func (d *Dentry) destroy() {
}
}
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+//
+// Although Linux technically supports inotify on pseudo filesystems (inotify
+// is implemented at the vfs layer), it is not particularly useful. It is left
+// unimplemented until someone actually needs it.
+func (d *Dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *Dentry) Watches() *vfs.Watches {
+ return nil
+}
+
+// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
+func (d *Dentry) OnZeroWatches() {}
+
// InsertChild inserts child into the vfs dentry cache with the given name under
// this dentry. This does not update the directory inode, so calling this on
-// it's own isn't sufficient to insert a child into a directory. InsertChild
+// its own isn't sufficient to insert a child into a directory. InsertChild
// updates the link count on d if required.
//
// Precondition: d must represent a directory inode.
@@ -331,7 +346,7 @@ type inodeMetadata interface {
// Stat returns the metadata for this inode. This corresponds to
// vfs.FilesystemImpl.StatAt.
- Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error)
+ Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error)
// SetStat updates the metadata for this inode. This corresponds to
// vfs.FilesystemImpl.SetStatAt. Implementations are responsible for checking
@@ -413,10 +428,10 @@ type inodeDynamicLookup interface {
// IterDirents is used to iterate over dynamically created entries. It invokes
// cb on each entry in the directory represented by the FileDescription.
// 'offset' is the offset for the entire IterDirents call, which may include
- // results from the caller. 'relOffset' is the offset inside the entries
- // returned by this IterDirents invocation. In other words,
- // 'offset+relOffset+1' is the value that should be set in vfs.Dirent.NextOff,
- // while 'relOffset' is the place where iteration should start from.
+ // results from the caller (e.g. "." and ".."). 'relOffset' is the offset
+ // inside the entries returned by this IterDirents invocation. In other words,
+ // 'offset' should be used to calculate each vfs.Dirent.NextOff as well as
+ // the return value, while 'relOffset' is the place to start iteration.
IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error)
}
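
To make the offset/relOffset contract concrete: if the caller already emitted "." and "..", the dynamic inode is invoked with offset=2 and relOffset=0, and each emitted entry's NextOff (as well as the final return value) is derived from offset alone. A hedged sketch; dynamicDirInode and its helpers are illustrative:

func (i *dynamicDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
	names := i.entryNames() // hypothetical: dynamic entries in a stable order
	for _, name := range names[relOffset:] {
		dirent := vfs.Dirent{
			Name:    name,
			Type:    linux.DT_REG,
			Ino:     i.inoFor(name), // hypothetical
			NextOff: offset + 1,     // next position in the overall listing
		}
		if err := cb.Handle(dirent); err != nil {
			return offset, err
		}
		offset++
	}
	return offset, nil
}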
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 412cf6ac9..dc407eb1d 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -100,8 +100,10 @@ type readonlyDir struct {
kernfs.InodeNotSymlink
kernfs.InodeNoDynamicLookup
kernfs.InodeDirectoryNoNewChildren
-
kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
dentry kernfs.Dentry
}
@@ -117,7 +119,7 @@ func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMod
}
func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
if err != nil {
return nil, err
}
@@ -128,10 +130,12 @@ type dir struct {
attrs
kernfs.InodeNotSymlink
kernfs.InodeNoDynamicLookup
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
fs *filesystem
dentry kernfs.Dentry
- kernfs.OrderedChildren
}
func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
@@ -147,7 +151,7 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte
}
func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD
new file mode 100644
index 000000000..8cf5b35d3
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/BUILD
@@ -0,0 +1,41 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "fstree",
+ out = "fstree.go",
+ package = "overlay",
+ prefix = "generic",
+ template = "//pkg/sentry/vfs/genericfstree:generic_fstree",
+ types = {
+ "Dentry": "dentry",
+ },
+)
+
+go_library(
+ name = "overlay",
+ srcs = [
+ "copy_up.go",
+ "directory.go",
+ "filesystem.go",
+ "fstree.go",
+ "non_directory.go",
+ "overlay.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
new file mode 100644
index 000000000..8f8dcfafe
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -0,0 +1,262 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (d *dentry) isCopiedUp() bool {
+ return atomic.LoadUint32(&d.copiedUp) != 0
+}
+
+// copyUpLocked ensures that d exists on the upper layer, i.e. d.upperVD.Ok().
+//
+// Preconditions: filesystem.renameMu must be locked.
+func (d *dentry) copyUpLocked(ctx context.Context) error {
+ // Fast path.
+ if d.isCopiedUp() {
+ return nil
+ }
+
+ ftype := atomic.LoadUint32(&d.mode) & linux.S_IFMT
+ switch ftype {
+ case linux.S_IFREG, linux.S_IFDIR, linux.S_IFLNK, linux.S_IFBLK, linux.S_IFCHR:
+ // Can be copied-up.
+ default:
+ // Can't be copied-up.
+ return syserror.EPERM
+ }
+
+ // Ensure that our parent directory is copied-up.
+ if d.parent == nil {
+ // d is a filesystem root with no upper layer.
+ return syserror.EROFS
+ }
+ if err := d.parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ d.copyMu.Lock()
+ defer d.copyMu.Unlock()
+ if d.upperVD.Ok() {
+ // Raced with another call to d.copyUpLocked().
+ return nil
+ }
+ if d.vfsd.IsDead() {
+ // Raced with deletion of d.
+ return syserror.ENOENT
+ }
+
+ // Perform copy-up.
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ newpop := vfs.PathOperation{
+ Root: d.parent.upperVD,
+ Start: d.parent.upperVD,
+ Path: fspath.Parse(d.name),
+ }
+ cleanupUndoCopyUp := func() {
+ var err error
+ if ftype == linux.S_IFDIR {
+ err = vfsObj.RmdirAt(ctx, d.fs.creds, &newpop)
+ } else {
+ err = vfsObj.UnlinkAt(ctx, d.fs.creds, &newpop)
+ }
+ if err != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after copy-up error: %v", err)
+ }
+ }
+ switch ftype {
+ case linux.S_IFREG:
+ oldFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVDs[0],
+ Start: d.lowerVDs[0],
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ return err
+ }
+ defer oldFD.DecRef()
+ newFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &newpop, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY | linux.O_CREAT | linux.O_EXCL,
+ Mode: linux.FileMode(d.mode &^ linux.S_IFMT),
+ })
+ if err != nil {
+ return err
+ }
+ defer newFD.DecRef()
+ bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size
+ for {
+ readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{})
+ if readErr != nil && readErr != io.EOF {
+ cleanupUndoCopyUp()
+ return readErr
+ }
+ total := int64(0)
+ for total < readN {
+ writeN, writeErr := newFD.Write(ctx, bufIOSeq.DropFirst64(total), vfs.WriteOptions{})
+ total += writeN
+ if writeErr != nil {
+ cleanupUndoCopyUp()
+ return writeErr
+ }
+ }
+ if readErr == io.EOF {
+ break
+ }
+ }
+ if err := newFD.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = newFD.VirtualDentry()
+ d.upperVD.IncRef()
+
+ case linux.S_IFDIR:
+ if err := vfsObj.MkdirAt(ctx, d.fs.creds, &newpop, &vfs.MkdirOptions{
+ Mode: linux.FileMode(d.mode &^ linux.S_IFMT),
+ }); err != nil {
+ return err
+ }
+ if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ upperVD, err := vfsObj.GetDentryAt(ctx, d.fs.creds, &newpop, &vfs.GetDentryOptions{})
+ if err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = upperVD
+
+ case linux.S_IFLNK:
+ target, err := vfsObj.ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVDs[0],
+ Start: d.lowerVDs[0],
+ })
+ if err != nil {
+ return err
+ }
+ if err := vfsObj.SymlinkAt(ctx, d.fs.creds, &newpop, target); err != nil {
+ return err
+ }
+ if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID,
+ Mode: uint16(d.mode),
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ upperVD, err := vfsObj.GetDentryAt(ctx, d.fs.creds, &newpop, &vfs.GetDentryOptions{})
+ if err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = upperVD
+
+ case linux.S_IFBLK, linux.S_IFCHR:
+ lowerStat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVDs[0],
+ Start: d.lowerVDs[0],
+ }, &vfs.StatOptions{})
+ if err != nil {
+ return err
+ }
+ if err := vfsObj.MknodAt(ctx, d.fs.creds, &newpop, &vfs.MknodOptions{
+ Mode: linux.FileMode(d.mode),
+ DevMajor: lowerStat.RdevMajor,
+ DevMinor: lowerStat.RdevMinor,
+ }); err != nil {
+ return err
+ }
+ if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ upperVD, err := vfsObj.GetDentryAt(ctx, d.fs.creds, &newpop, &vfs.GetDentryOptions{})
+ if err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = upperVD
+
+ default:
+ // Should have been rejected by the file type check at the beginning of
+ // this function.
+ panic(fmt.Sprintf("unexpected file type %o", ftype))
+ }
+
+ // TODO(gvisor.dev/issue/1199): copy up xattrs
+
+ // Update the dentry's device and inode numbers (except for directories,
+ // for which these remain overlay-assigned).
+ if ftype != linux.S_IFDIR {
+ upperStat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.upperVD,
+ Start: d.upperVD,
+ }, &vfs.StatOptions{
+ Mask: linux.STATX_INO,
+ })
+ if err != nil {
+ d.upperVD.DecRef()
+ d.upperVD = vfs.VirtualDentry{}
+ cleanupUndoCopyUp()
+ return err
+ }
+ if upperStat.Mask&linux.STATX_INO == 0 {
+ d.upperVD.DecRef()
+ d.upperVD = vfs.VirtualDentry{}
+ cleanupUndoCopyUp()
+ return syserror.EREMOTE
+ }
+ atomic.StoreUint32(&d.devMajor, upperStat.DevMajor)
+ atomic.StoreUint32(&d.devMinor, upperStat.DevMinor)
+ atomic.StoreUint64(&d.ino, upperStat.Ino)
+ }
+
+ atomic.StoreUint32(&d.copiedUp, 1)
+ return nil
+}
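
The regular-file arm above has to tolerate short writes: vfs.FileDescription.Write may consume fewer bytes than requested, so each chunk read from the lower layer is drained by an inner loop before the next read. The same pattern over plain io interfaces, as a standalone, hedged sketch (not the gVisor code itself):

package main

import "io"

// copyAll mirrors the copy-up data loop: read a chunk, then write until
// the whole chunk is consumed, retrying short writes.
func copyAll(dst io.Writer, src io.Reader) error {
	buf := make([]byte, 32*1024) // same arbitrary buffer size as above
	for {
		n, readErr := src.Read(buf)
		if readErr != nil && readErr != io.EOF {
			return readErr
		}
		for total := 0; total < n; {
			w, writeErr := dst.Write(buf[total:n])
			total += w
			if writeErr != nil {
				return writeErr
			}
		}
		if readErr == io.EOF {
			return nil
		}
	}
}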
diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go
new file mode 100644
index 000000000..f5c2462a5
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/directory.go
@@ -0,0 +1,287 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func (d *dentry) isDir() bool {
+ return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR
+}
+
+// Preconditions: d.dirMu must be locked. d.isDir().
+func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string]bool, error) {
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ var readdirErr error
+ whiteouts := make(map[string]bool)
+ var maybeWhiteouts []string
+ d.iterLayers(func(layerVD vfs.VirtualDentry, isUpper bool) bool {
+ layerFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_DIRECTORY,
+ })
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ defer layerFD.DecRef()
+
+ // Reuse slice allocated for maybeWhiteouts from a previous layer to
+ // reduce allocations.
+ maybeWhiteouts = maybeWhiteouts[:0]
+ if err := layerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
+ if dirent.Name == "." || dirent.Name == ".." {
+ return nil
+ }
+ if _, ok := whiteouts[dirent.Name]; ok {
+ // This file has been whited-out in a previous layer.
+ return nil
+ }
+ if dirent.Type == linux.DT_CHR {
+ // We have to determine if this is a whiteout, which doesn't
+ // count against the directory's emptiness. However, we can't
+ // do so while holding locks held by layerFD.IterDirents().
+ maybeWhiteouts = append(maybeWhiteouts, dirent.Name)
+ return nil
+ }
+ // Non-whiteout file in the directory prevents rmdir.
+ return syserror.ENOTEMPTY
+ })); err != nil {
+ readdirErr = err
+ return false
+ }
+
+ for _, maybeWhiteoutName := range maybeWhiteouts {
+ stat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ Path: fspath.Parse(maybeWhiteoutName),
+ }, &vfs.StatOptions{})
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ if stat.RdevMajor != 0 || stat.RdevMinor != 0 {
+ // This file is a real character device, not a whiteout.
+ readdirErr = syserror.ENOTEMPTY
+ return false
+ }
+ whiteouts[maybeWhiteoutName] = isUpper
+ }
+ // Continue iteration since we haven't found any non-whiteout files in
+ // this directory yet.
+ return true
+ })
+ return whiteouts, readdirErr
+}
+
+type directoryFD struct {
+ fileDescription
+ vfs.DirectoryFileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+
+ mu sync.Mutex
+ off int64
+ dirents []vfs.Dirent
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *directoryFD) Release() {
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ d := fd.dentry()
+ if fd.dirents == nil {
+ ds, err := d.getDirents(ctx)
+ if err != nil {
+ return err
+ }
+ fd.dirents = ds
+ }
+
+ for fd.off < int64(len(fd.dirents)) {
+ if err := cb.Handle(fd.dirents[fd.off]); err != nil {
+ return err
+ }
+ fd.off++
+ }
+ return nil
+}
+
+// Preconditions: d.isDir().
+func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
+ d.fs.renameMu.RLock()
+ defer d.fs.renameMu.RUnlock()
+ d.dirMu.Lock()
+ defer d.dirMu.Unlock()
+
+ if d.dirents != nil {
+ return d.dirents, nil
+ }
+
+ parent := genericParentOrSelf(d)
+ dirents := []vfs.Dirent{
+ {
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: d.ino,
+ NextOff: 1,
+ },
+ {
+ Name: "..",
+ Type: uint8(atomic.LoadUint32(&parent.mode) >> 12),
+ Ino: parent.ino,
+ NextOff: 2,
+ },
+ }
+
+ // Merge dirents from all layers comprising this directory.
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ var readdirErr error
+ prevDirents := make(map[string]struct{})
+ var maybeWhiteouts []vfs.Dirent
+ d.iterLayers(func(layerVD vfs.VirtualDentry, isUpper bool) bool {
+ layerFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_DIRECTORY,
+ })
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ defer layerFD.DecRef()
+
+ // Reuse slice allocated for maybeWhiteouts from a previous layer to
+ // reduce allocations.
+ maybeWhiteouts = maybeWhiteouts[:0]
+ if err := layerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
+ if dirent.Name == "." || dirent.Name == ".." {
+ return nil
+ }
+ if _, ok := prevDirents[dirent.Name]; ok {
+ // This file is hidden by, or merged with, another file with
+ // the same name in a previous layer.
+ return nil
+ }
+ prevDirents[dirent.Name] = struct{}{}
+ if dirent.Type == linux.DT_CHR {
+ // We can't determine if this file is a whiteout while holding
+ // locks held by layerFD.IterDirents().
+ maybeWhiteouts = append(maybeWhiteouts, dirent)
+ return nil
+ }
+ dirent.NextOff = int64(len(dirents) + 1)
+ dirents = append(dirents, dirent)
+ return nil
+ })); err != nil {
+ readdirErr = err
+ return false
+ }
+
+ for _, dirent := range maybeWhiteouts {
+ stat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ Path: fspath.Parse(dirent.Name),
+ }, &vfs.StatOptions{})
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ if stat.RdevMajor == 0 && stat.RdevMinor == 0 {
+ // This file is a whiteout; don't emit a dirent for it.
+ continue
+ }
+ dirent.NextOff = int64(len(dirents) + 1)
+ dirents = append(dirents, dirent)
+ }
+ return true
+ })
+ if readdirErr != nil {
+ return nil, readdirErr
+ }
+
+ // Cache dirents for future directoryFDs.
+ d.dirents = dirents
+ return dirents, nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset == 0 {
+ // Ensure that the next call to fd.IterDirents() calls
+ // fd.dentry().getDirents().
+ fd.dirents = nil
+ }
+ fd.off = offset
+ return fd.off, nil
+ case linux.SEEK_CUR:
+ offset += fd.off
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ // Don't clear fd.dirents in this case, even if offset == 0.
+ fd.off = offset
+ return fd.off, nil
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync. Forwards sync to the upper
+// layer, if there is one. The lower layer doesn't need to sync because it
+// never changes.
+func (fd *directoryFD) Sync(ctx context.Context) error {
+ d := fd.dentry()
+ if !d.isCopiedUp() {
+ return nil
+ }
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: d.upperVD,
+ Start: d.upperVD,
+ }
+ upperFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &pop, &vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_DIRECTORY})
+ if err != nil {
+ return err
+ }
+ err = upperFD.Sync(ctx)
+ upperFD.DecRef()
+ return err
+}
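
Seek above gives rewinddir() its usual semantics: SEEK_SET to offset 0 drops the cached dirents, so the next IterDirents re-merges the layers and observes entries created since the last pass. A hedged usage sketch from the caller's side, where fd is a *vfs.FileDescription opened on an overlay directory:

// Rewind, then re-read; the overlay rebuilds its merged listing.
if _, err := fd.Seek(ctx, 0, linux.SEEK_SET); err != nil {
	return err
}
if err := fd.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(d vfs.Dirent) error {
	// d.NextOff is 1 + the dirent's index in the merged listing.
	return nil
})); err != nil {
	return err
}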
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
new file mode 100644
index 000000000..6b705e955
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -0,0 +1,1364 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// _OVL_XATTR_OPAQUE is an extended attribute key whose value is set to "y" for
+// opaque directories.
+// Linux: fs/overlayfs/overlayfs.h:OVL_XATTR_OPAQUE
+const _OVL_XATTR_OPAQUE = "trusted.overlay.opaque"
+
+func isWhiteout(stat *linux.Statx) bool {
+ return stat.Mode&linux.S_IFMT == linux.S_IFCHR && stat.RdevMajor == 0 && stat.RdevMinor == 0
+}
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ if fs.opts.UpperRoot.Ok() {
+ return fs.opts.UpperRoot.Mount().Filesystem().Impl().Sync(ctx)
+ }
+ return nil
+}
+
+var dentrySlicePool = sync.Pool{
+ New: func() interface{} {
+ ds := make([]*dentry, 0, 4) // arbitrary non-zero initial capacity
+ return &ds
+ },
+}
+
+func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry {
+ if ds == nil {
+ ds = dentrySlicePool.Get().(*[]*dentry)
+ }
+ *ds = append(*ds, d)
+ return ds
+}
+
+// Preconditions: ds != nil.
+func putDentrySlice(ds *[]*dentry) {
+ // Allow dentries to be GC'd.
+ for i := range *ds {
+ (*ds)[i] = nil
+ }
+ *ds = (*ds)[:0]
+ dentrySlicePool.Put(ds)
+}
+
+// renameMuRUnlockAndCheckDrop calls fs.renameMu.RUnlock(), then calls
+// dentry.checkDropLocked on all dentries in *ds with fs.renameMu locked for
+// writing.
+//
+// ds is a pointer-to-pointer since defer evaluates its arguments immediately,
+// but dentry slices are allocated lazily, and it's much easier to say "defer
+// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
+// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
+func (fs *filesystem) renameMuRUnlockAndCheckDrop(ds **[]*dentry) {
+ fs.renameMu.RUnlock()
+ if *ds == nil {
+ return
+ }
+ if len(**ds) != 0 {
+ fs.renameMu.Lock()
+ for _, d := range **ds {
+ d.checkDropLocked()
+ }
+ fs.renameMu.Unlock()
+ }
+ putDentrySlice(*ds)
+}
+
+func (fs *filesystem) renameMuUnlockAndCheckDrop(ds **[]*dentry) {
+ if *ds == nil {
+ fs.renameMu.Unlock()
+ return
+ }
+ for _, d := range **ds {
+ d.checkDropLocked()
+ }
+ fs.renameMu.Unlock()
+ putDentrySlice(*ds)
+}
+
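The pointer-to-pointer convention above exists because Go evaluates a deferred call's arguments at the defer statement, not when the call runs. A standalone demonstration of the gotcha (illustrative only, not gVisor code):

package main

import "fmt"

func main() {
	var ds *[]string
	// Argument evaluated now: the deferred call captures nil.
	defer func(p *[]string) { fmt.Println("by value:  ", p) }(ds)
	// Pointer-to-pointer: dereferenced when the deferred call runs,
	// so it observes the later assignment.
	defer func(pp **[]string) { fmt.Println("by pointer:", *pp) }(&ds)
	s := []string{"dentry"}
	ds = &s
	// Prints "by pointer: &[dentry]" then "by value:   <nil>".
}
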
+// stepLocked resolves rp.Component() to an existing file, starting from the
+// given directory.
+//
+// Dentries which may have a reference count of zero, and which therefore
+// should be dropped once traversal is complete, are appended to ds.
+//
+// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+// !rp.Done().
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ name := rp.Component()
+ if name == "." {
+ rp.Advance()
+ return d, nil
+ }
+ if name == ".." {
+ if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil {
+ return nil, err
+ } else if isRoot || d.parent == nil {
+ rp.Advance()
+ return d, nil
+ }
+ if err := rp.CheckMount(&d.parent.vfsd); err != nil {
+ return nil, err
+ }
+ rp.Advance()
+ return d.parent, nil
+ }
+ child, err := fs.getChildLocked(ctx, d, name, ds)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.CheckMount(&child.vfsd); err != nil {
+ return nil, err
+ }
+ if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ goto afterSymlink // don't check the current directory again
+ }
+ rp.Advance()
+ return child, nil
+}
+
+// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
+ if child, ok := parent.children[name]; ok {
+ return child, nil
+ }
+ child, err := fs.lookupLocked(ctx, parent, name)
+ if err != nil {
+ return nil, err
+ }
+ if parent.children == nil {
+ parent.children = make(map[string]*dentry)
+ }
+ parent.children[name] = child
+ // child's refcount is initially 0, so it may be dropped after traversal.
+ *ds = appendDentry(*ds, child)
+ return child, nil
+}
+
+// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
+ childPath := fspath.Parse(name)
+ child := fs.newDentry()
+ existsOnAnyLayer := false
+ var lookupErr error
+
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ parent.iterLayers(func(parentVD vfs.VirtualDentry, isUpper bool) bool {
+ childVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: parentVD,
+ Start: parentVD,
+ Path: childPath,
+ }, &vfs.GetDentryOptions{})
+ if err == syserror.ENOENT || err == syserror.ENAMETOOLONG {
+ // The file doesn't exist on this layer. Proceed to the next one.
+ return true
+ }
+ if err != nil {
+ lookupErr = err
+ return false
+ }
+
+ mask := uint32(linux.STATX_TYPE)
+ if !existsOnAnyLayer {
+ // Mode, UID, GID, and (for non-directories) inode number come from
+ // the topmost layer on which the file exists.
+ mask |= linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
+ }
+ stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: childVD,
+ Start: childVD,
+ }, &vfs.StatOptions{
+ Mask: mask,
+ })
+ if err != nil {
+ lookupErr = err
+ return false
+ }
+ if stat.Mask&mask != mask {
+ lookupErr = syserror.EREMOTE
+ return false
+ }
+
+ if isWhiteout(&stat) {
+ // This is a whiteout, so it "doesn't exist" on this layer, and
+ // layers below this one are ignored.
+ return false
+ }
+ isDir := stat.Mode&linux.S_IFMT == linux.S_IFDIR
+ if existsOnAnyLayer && !isDir {
+ // Directories are not merged with non-directory files from lower
+ // layers; instead, layers including and below the first
+ // non-directory file are ignored. (This file must be a directory
+ // on previous layers, since lower layers aren't searched for
+ // non-directory files.)
+ return false
+ }
+
+ // Update child to include this layer.
+ if isUpper {
+ child.upperVD = childVD
+ child.copiedUp = 1
+ } else {
+ child.lowerVDs = append(child.lowerVDs, childVD)
+ }
+ if !existsOnAnyLayer {
+ existsOnAnyLayer = true
+ child.mode = uint32(stat.Mode)
+ child.uid = stat.UID
+ child.gid = stat.GID
+ child.devMajor = stat.DevMajor
+ child.devMinor = stat.DevMinor
+ child.ino = stat.Ino
+ }
+
+ // For non-directory files, only the topmost layer that contains a file
+ // matters.
+ if !isDir {
+ return false
+ }
+
+ // Directories are merged with directories from lower layers if they
+ // are not explicitly opaque.
+ opaqueVal, err := vfsObj.GetxattrAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: childVD,
+ Start: childVD,
+ }, &vfs.GetxattrOptions{
+ Name: _OVL_XATTR_OPAQUE,
+ Size: 1,
+ })
+ return !(err == nil && opaqueVal == "y")
+ })
+
+ if lookupErr != nil {
+ child.destroyLocked()
+ return nil, lookupErr
+ }
+ if !existsOnAnyLayer {
+ child.destroyLocked()
+ return nil, syserror.ENOENT
+ }
+
+ // Device and inode numbers were copied from the topmost layer above;
+ // override them if necessary.
+ if child.isDir() {
+ child.devMajor = linux.UNNAMED_MAJOR
+ child.devMinor = fs.dirDevMinor
+ child.ino = fs.newDirIno()
+ } else if !child.upperVD.Ok() {
+ child.devMajor = linux.UNNAMED_MAJOR
+ child.devMinor = fs.lowerDevMinors[child.lowerVDs[0].Mount().Filesystem()]
+ }
+
+ parent.IncRef()
+ child.parent = parent
+ child.name = name
+ return child, nil
+}
+
+// lookupLayerLocked is similar to lookupLocked, but only returns information
+// about the file rather than a dentry.
+//
+// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+func (fs *filesystem) lookupLayerLocked(ctx context.Context, parent *dentry, name string) (lookupLayer, error) {
+ childPath := fspath.Parse(name)
+ lookupLayer := lookupLayerNone
+ var lookupErr error
+
+ parent.iterLayers(func(parentVD vfs.VirtualDentry, isUpper bool) bool {
+ stat, err := fs.vfsfs.VirtualFilesystem().StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: parentVD,
+ Start: parentVD,
+ Path: childPath,
+ }, &vfs.StatOptions{
+ Mask: linux.STATX_TYPE,
+ })
+ if err == syserror.ENOENT || err == syserror.ENAMETOOLONG {
+ // The file doesn't exist on this layer. Proceed to the next
+ // one.
+ return true
+ }
+ if err != nil {
+ lookupErr = err
+ return false
+ }
+ if stat.Mask&linux.STATX_TYPE == 0 {
+ // Linux's overlayfs tends to return EREMOTE in cases where a file
+ // is unusable for reasons that are not better captured by another
+ // errno.
+ lookupErr = syserror.EREMOTE
+ return false
+ }
+ if isWhiteout(&stat) {
+ // This is a whiteout, so it "doesn't exist" on this layer, and
+ // layers below this one are ignored.
+ if isUpper {
+ lookupLayer = lookupLayerUpperWhiteout
+ }
+ return false
+ }
+ // The file exists; we can stop searching.
+ if isUpper {
+ lookupLayer = lookupLayerUpper
+ } else {
+ lookupLayer = lookupLayerLower
+ }
+ return false
+ })
+
+ return lookupLayer, lookupErr
+}
+
+type lookupLayer int
+
+const (
+ // lookupLayerNone indicates that no file exists at the given path on the
+ // upper layer, and that the path is either whited out or absent on lower
+ // layers.
+ // Therefore, the file does not exist in the overlay filesystem, and file
+ // creation may proceed normally (if an upper layer exists).
+ lookupLayerNone lookupLayer = iota
+
+ // lookupLayerLower indicates that no file exists at the given path on the
+ // upper layer, but exists on a lower layer. Therefore, the file exists in
+ // the overlay filesystem, but must be copied-up before mutation.
+ lookupLayerLower
+
+ // lookupLayerUpper indicates that a non-whiteout file exists at the given
+ // path on the upper layer. Therefore, the file exists in the overlay
+ // filesystem, and is already copied-up.
+ lookupLayerUpper
+
+ // lookupLayerUpperWhiteout indicates that a whiteout exists at the given
+ // path on the upper layer. Therefore, the file does not exist in the
+ // overlay filesystem, and file creation must remove the whiteout before
+ // proceeding.
+ lookupLayerUpperWhiteout
+)
+
+func (ll lookupLayer) existsInOverlay() bool {
+ return ll == lookupLayerLower || ll == lookupLayerUpper
+}
+
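Taken together, the four constants form a decision table for file creation; a hedged distillation (createDisposition is an illustrative helper, and doCreateAt below implements the real flow):

// createDisposition summarizes what creating a file at a looked-up name
// must do for each lookupLayer result.
func createDisposition(ll lookupLayer) (removeWhiteoutFirst bool, err error) {
	switch ll {
	case lookupLayerUpper, lookupLayerLower:
		return false, syserror.EEXIST // already visible in the overlay
	case lookupLayerUpperWhiteout:
		return true, nil // unlink the whiteout on the upper layer, then create
	default: // lookupLayerNone
		return false, nil // create directly on the (copied-up) upper layer
	}
}
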
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory, starting from the given directory (which is usually
+// rp.Start().Impl().(*dentry)). It does not check that the returned directory
+// is searchable by the provider of rp.
+//
+// Preconditions: fs.renameMu must be locked. !rp.Done().
+func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+ for !rp.Final() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// resolveLocked resolves rp to an existing file.
+//
+// Preconditions: fs.renameMu must be locked.
+func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
+ d := rp.Start().Impl().(*dentry)
+ for !rp.Done() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if rp.MustBeDir() && !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// doCreateAt checks that creating a file at rp is permitted, then invokes
+// create to do so.
+//
+// Preconditions: !rp.Done(). For the final path component in rp,
+// !rp.ShouldFollowSymlink().
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ start := rp.Start().Impl().(*dentry)
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EEXIST
+ }
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+ if parent.vfsd.IsDead() {
+ return syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ // Determine if a file already exists at name.
+ if _, ok := parent.children[name]; ok {
+ return syserror.EEXIST
+ }
+ childLayer, err := fs.lookupLayerLocked(ctx, parent, name)
+ if err != nil {
+ return err
+ }
+ if childLayer.existsInOverlay() {
+ return syserror.EEXIST
+ }
+
+ // Ensure that the parent directory is copied-up so that we can create the
+ // new file in the upper layer.
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ // Finally create the new file.
+ if err := create(parent, name, childLayer == lookupLayerUpperWhiteout); err != nil {
+ return err
+ }
+ parent.dirents = nil
+ return nil
+}
+
+// Preconditions: pop's parent directory has been copied up.
+func (fs *filesystem) createWhiteout(ctx context.Context, vfsObj *vfs.VirtualFilesystem, pop *vfs.PathOperation) error {
+ return vfsObj.MknodAt(ctx, fs.creds, pop, &vfs.MknodOptions{
+ Mode: linux.S_IFCHR, // permissions == include/linux/fs.h:WHITEOUT_MODE == 0
+ // DevMajor == DevMinor == 0, from include/linux/fs.h:WHITEOUT_DEV
+ })
+}
+
+func (fs *filesystem) cleanupRecreateWhiteout(ctx context.Context, vfsObj *vfs.VirtualFilesystem, pop *vfs.PathOperation) {
+ if err := fs.createWhiteout(ctx, vfsObj, pop); err != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate whiteout after failed file creation: %v", err)
+ }
+}
+
+// AccessAt implements vfs.FilesystemImpl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return d.checkPermissions(creds, ats)
+}
+
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ layerVD := d.topLayer()
+ return fs.vfsfs.VirtualFilesystem().BoundEndpointAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &opts)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if opts.CheckSearchable {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ start := rp.Start().Impl().(*dentry)
+ d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ old := vd.Dentry().Impl().(*dentry)
+ if old.isDir() {
+ return syserror.EPERM
+ }
+ if err := old.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ newpop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.LinkAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: old.upperVD,
+ Start: old.upperVD,
+ }, &newpop); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &newpop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after LinkAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &newpop)
+ }
+ return err
+ }
+ return nil
+ })
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.MkdirAt(ctx, fs.creds, &pop, &opts); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ if haveUpperWhiteout {
+ // There may be directories on lower layers (previously hidden by
+ // the whiteout) that the new directory should not be merged with.
+ // Mark it opaque to prevent merging.
+ if err := vfsObj.SetxattrAt(ctx, fs.creds, &pop, &vfs.SetxattrOptions{
+ Name: _OVL_XATTR_OPAQUE,
+ Value: "y",
+ }); err != nil {
+ if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt set-opaque failure: %v", cleanupErr)
+ } else {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ }
+ return nil
+ })
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ // Disallow attempts to create whiteouts.
+ if opts.Mode&linux.S_IFMT == linux.S_IFCHR && opts.DevMajor == 0 && opts.DevMinor == 0 {
+ return syserror.EPERM
+ }
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.MknodAt(ctx, fs.creds, &pop, &opts); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ return nil
+ })
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ mayCreate := opts.Flags&linux.O_CREAT != 0
+ mustCreate := opts.Flags&(linux.O_CREAT|linux.O_EXCL) == (linux.O_CREAT | linux.O_EXCL)
+
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+
+ start := rp.Start().Impl().(*dentry)
+ if rp.Done() {
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ return start.openLocked(ctx, rp, &opts)
+ }
+
+afterTrailingSymlink:
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ // Determine whether or not we need to create a file.
+ parent.dirMu.Lock()
+ child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ if err == syserror.ENOENT && mayCreate {
+ fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds)
+ parent.dirMu.Unlock()
+ return fd, err
+ }
+ if err != nil {
+ parent.dirMu.Unlock()
+ return nil, err
+ }
+ // Open existing child or follow symlink.
+ parent.dirMu.Unlock()
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ if child.isSymlink() && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ start = parent
+ goto afterTrailingSymlink
+ }
+ return child.openLocked(ctx, rp, &opts)
+}
+
+// Preconditions: fs.renameMu must be locked.
+func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ if ats.MayWrite() {
+ if err := d.copyUpLocked(ctx); err != nil {
+ return nil, err
+ }
+ }
+ mnt := rp.Mount()
+
+ // Directory FDs open FDs from each layer when directory entries are read,
+ // so they don't require opening an FD from d.topLayer() up front.
+ ftype := atomic.LoadUint32(&d.mode) & linux.S_IFMT
+ if ftype == linux.S_IFDIR {
+ // Can't open directories with O_CREAT.
+ if opts.Flags&linux.O_CREAT != 0 {
+ return nil, syserror.EISDIR
+ }
+ // Can't open directories writably.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EISDIR
+ }
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ fd := &directoryFD{}
+ fd.LockFD.Init(&d.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ }
+
+ layerVD, isUpper := d.topLayerInfo()
+ layerFD, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, opts)
+ if err != nil {
+ return nil, err
+ }
+ layerFlags := layerFD.StatusFlags()
+ fd := &nonDirectoryFD{
+ copiedUp: isUpper,
+ cachedFD: layerFD,
+ cachedFlags: layerFlags,
+ }
+ fd.LockFD.Init(&d.locks)
+ layerFDOpts := layerFD.Options()
+ if err := fd.vfsfd.Init(fd, layerFlags, mnt, &d.vfsd, &layerFDOpts); err != nil {
+ layerFD.DecRef()
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Preconditions: parent.dirMu must be locked. parent does not already contain
+// a child named rp.Component().
+func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) {
+ creds := rp.Credentials()
+ if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if parent.vfsd.IsDead() {
+ return nil, syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer mnt.EndWrite()
+
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return nil, err
+ }
+
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ childName := rp.Component()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ // We don't know if a whiteout exists on the upper layer; speculatively
+ // unlink it.
+ //
+ // TODO(gvisor.dev/issue/1199): Modify OpenAt => stepLocked so that we do
+ // know whether a whiteout exists.
+ var haveUpperWhiteout bool
+ switch err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err {
+ case nil:
+ haveUpperWhiteout = true
+ case syserror.ENOENT:
+ haveUpperWhiteout = false
+ default:
+ return nil, err
+ }
+ // Create the file on the upper layer, and get an FD representing it.
+ upperFD, err := vfsObj.OpenAt(ctx, fs.creds, &pop, &vfs.OpenOptions{
+ Flags: opts.Flags&^vfs.FileCreationFlags | linux.O_CREAT | linux.O_EXCL,
+ Mode: opts.Mode,
+ })
+ if err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return nil, err
+ }
+ // Change the file's owner to the caller. We can't use upperFD.SetStat()
+ // because it will pick up creds from ctx.
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return nil, err
+ }
+ // Re-lookup to get a dentry representing the new file, which is needed for
+ // the returned FD.
+ child, err := fs.getChildLocked(ctx, parent, childName, ds)
+ if err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return nil, err
+ }
+ // Finally construct the overlay FD.
+ upperFlags := upperFD.StatusFlags()
+ fd := &nonDirectoryFD{
+ copiedUp: true,
+ cachedFD: upperFD,
+ cachedFlags: upperFlags,
+ }
+ fd.LockFD.Init(&child.locks)
+ upperFDOpts := upperFD.Options()
+ if err := fd.vfsfd.Init(fd, upperFlags, mnt, &child.vfsd, &upperFDOpts); err != nil {
+ upperFD.DecRef()
+ // Don't bother with cleanup; the file was created successfully, but we
+ // just can't open it anymore for some reason.
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ layerVD := d.topLayer()
+ return fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ })
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ if opts.Flags != 0 {
+ return syserror.EINVAL
+ }
+
+ var ds *[]*dentry
+ fs.renameMu.Lock()
+ defer fs.renameMuUnlockAndCheckDrop(&ds)
+ newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds)
+ if err != nil {
+ return err
+ }
+ newName := rp.Component()
+ if newName == "." || newName == ".." {
+ return syserror.EBUSY
+ }
+ mnt := rp.Mount()
+ if mnt != oldParentVD.Mount() {
+ return syserror.EXDEV
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ // FIXME(gvisor.dev/issue/1199): Actually implement rename.
+ _ = newParent
+ return syserror.EXDEV
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ start := rp.Start().Impl().(*dentry)
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ name := rp.Component()
+ if name == "." {
+ return syserror.EINVAL
+ }
+ if name == ".." {
+ return syserror.ENOTEMPTY
+ }
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ // Ensure that parent is copied-up before potentially holding child.copyMu
+ // below.
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ // Unlike UnlinkAt, we need a dentry representing the child directory being
+ // removed in order to verify that it's empty.
+ child, err := fs.getChildLocked(ctx, parent, name, &ds)
+ if err != nil {
+ return err
+ }
+ if !child.isDir() {
+ return syserror.ENOTDIR
+ }
+ child.dirMu.Lock()
+ defer child.dirMu.Unlock()
+ whiteouts, err := child.collectWhiteoutsForRmdirLocked(ctx)
+ if err != nil {
+ return err
+ }
+ child.copyMu.RLock()
+ defer child.copyMu.RUnlock()
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(name),
+ }
+ if child.upperVD.Ok() {
+ cleanupRecreateWhiteouts := func() {
+ if !child.upperVD.Ok() {
+ return
+ }
+ for whiteoutName, whiteoutUpper := range whiteouts {
+ if !whiteoutUpper {
+ continue
+ }
+ if err := fs.createWhiteout(ctx, vfsObj, &vfs.PathOperation{
+ Root: child.upperVD,
+ Start: child.upperVD,
+ Path: fspath.Parse(whiteoutName),
+ }); err != nil && err != syserror.EEXIST {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err)
+ }
+ }
+ }
+ // Remove existing whiteouts on the upper layer.
+ for whiteoutName, whiteoutUpper := range whiteouts {
+ if !whiteoutUpper {
+ continue
+ }
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: child.upperVD,
+ Start: child.upperVD,
+ Path: fspath.Parse(whiteoutName),
+ }); err != nil {
+ cleanupRecreateWhiteouts()
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return err
+ }
+ }
+ // Remove the existing directory on the upper layer.
+ if err := vfsObj.RmdirAt(ctx, fs.creds, &pop); err != nil {
+ cleanupRecreateWhiteouts()
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return err
+ }
+ }
+ if err := fs.createWhiteout(ctx, vfsObj, &pop); err != nil {
+ // Don't attempt to recover from this: the original directory is
+ // already gone, so any dentries representing it are invalid, and
+ // creating a new directory won't undo that.
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during RmdirAt: %v", err)
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return err
+ }
+
+ vfsObj.CommitDeleteDentry(&child.vfsd)
+ delete(parent.children, name)
+ ds = appendDentry(ds, child)
+ parent.dirents = nil
+ return nil
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ return err
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ if err := d.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ // Changes to d's attributes are serialized by d.copyMu.
+ d.copyMu.Lock()
+ defer d.copyMu.Unlock()
+ if err := d.fs.vfsfs.VirtualFilesystem().SetStatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.upperVD,
+ Start: d.upperVD,
+ }, &opts); err != nil {
+ return err
+ }
+ d.updateAfterSetStatLocked(&opts)
+ return nil
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+
+ var stat linux.Statx
+ if layerMask := opts.Mask &^ statInternalMask; layerMask != 0 {
+ layerVD := d.topLayer()
+ stat, err = fs.vfsfs.VirtualFilesystem().StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &vfs.StatOptions{
+ Mask: layerMask,
+ Sync: opts.Sync,
+ })
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ }
+ d.statInternalTo(ctx, &opts, &stat)
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ return fs.statFS(ctx)
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.SymlinkAt(ctx, fs.creds, &pop, target); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after SymlinkAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ return nil
+ })
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ start := rp.Start().Impl().(*dentry)
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EISDIR
+ }
+ if rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ // Ensure that parent is copied-up before potentially holding child.copyMu
+ // below.
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ child := parent.children[name]
+ var childLayer lookupLayer
+ if child != nil {
+ if child.isDir() {
+ return syserror.EISDIR
+ }
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+ // Hold child.copyMu to prevent it from being copied-up during
+ // deletion.
+ child.copyMu.RLock()
+ defer child.copyMu.RUnlock()
+ if child.upperVD.Ok() {
+ childLayer = lookupLayerUpper
+ } else {
+ childLayer = lookupLayerLower
+ }
+ } else {
+ // Determine if the file being unlinked actually exists. Holding
+ // parent.dirMu prevents a dentry from being instantiated for the file,
+ // which in turn prevents it from being copied-up, so this result is
+ // stable.
+ childLayer, err = fs.lookupLayerLocked(ctx, parent, name)
+ if err != nil {
+ return err
+ }
+ if !childLayer.existsInOverlay() {
+ return syserror.ENOENT
+ }
+ }
+
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(name),
+ }
+ if childLayer == lookupLayerUpper {
+ // Remove the existing file on the upper layer.
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ if child != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ }
+ return err
+ }
+ }
+ if err := fs.createWhiteout(ctx, vfsObj, &pop); err != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during UnlinkAt: %v", err)
+ if child != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ }
+ return err
+ }
+
+ if child != nil {
+ vfsObj.CommitDeleteDentry(&child.vfsd)
+ delete(parent.children, name)
+ ds = appendDentry(ds, child)
+ }
+ parent.dirents = nil
+ return nil
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ // TODO(gvisor.dev/issue/1199): Linux overlayfs actually allows listxattr,
+ // but not any other xattr syscalls. For now we just reject all of them.
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(&ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.renameMu.RLock()
+ defer fs.renameMu.RUnlock()
+ return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
+}
diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go
new file mode 100644
index 000000000..c0749e711
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/non_directory.go
@@ -0,0 +1,266 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (d *dentry) isSymlink() bool {
+ return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFLNK
+}
+
+func (d *dentry) readlink(ctx context.Context) (string, error) {
+ layerVD := d.topLayer()
+ return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ })
+}
+
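+// nonDirectoryFD implements vfs.FileDescriptionImpl for non-directory files.
+// It wraps an FD opened from one of the overlay's layers and transparently
+// switches to an upper-layer FD the first time it is used after its dentry
+// has been copied up; see currentFDLocked.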
+type nonDirectoryFD struct {
+ fileDescription
+
+ // If copiedUp is false, cachedFD represents
+ // fileDescription.dentry().lowerVDs[0]; otherwise, cachedFD represents
+ // fileDescription.dentry().upperVD. cachedFlags is the last known value of
+ // cachedFD.StatusFlags(). copiedUp, cachedFD, and cachedFlags are
+ // protected by mu.
+ mu sync.Mutex
+ copiedUp bool
+ cachedFD *vfs.FileDescription
+ cachedFlags uint32
+}
+
+func (fd *nonDirectoryFD) getCurrentFD(ctx context.Context) (*vfs.FileDescription, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return nil, err
+ }
+ wrappedFD.IncRef()
+ return wrappedFD, nil
+}
+
+func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescription, error) {
+ d := fd.dentry()
+ statusFlags := fd.vfsfd.StatusFlags()
+ if !fd.copiedUp && d.isCopiedUp() {
+ // Switch to the copied-up file.
+ upperVD := d.topLayer()
+ upperFD, err := fd.filesystem().vfsfs.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: upperVD,
+ Start: upperVD,
+ }, &vfs.OpenOptions{
+ Flags: statusFlags,
+ })
+ if err != nil {
+ return nil, err
+ }
+ oldOff, oldOffErr := fd.cachedFD.Seek(ctx, 0, linux.SEEK_CUR)
+ if oldOffErr == nil {
+ if _, err := upperFD.Seek(ctx, oldOff, linux.SEEK_SET); err != nil {
+ upperFD.DecRef()
+ return nil, err
+ }
+ }
+ fd.cachedFD.DecRef()
+ fd.copiedUp = true
+ fd.cachedFD = upperFD
+ fd.cachedFlags = statusFlags
+ } else if fd.cachedFlags != statusFlags {
+ if err := fd.cachedFD.SetStatusFlags(ctx, d.fs.creds, statusFlags); err != nil {
+ return nil, err
+ }
+ fd.cachedFlags = statusFlags
+ }
+ return fd.cachedFD, nil
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *nonDirectoryFD) Release() {
+ fd.cachedFD.DecRef()
+ fd.cachedFD = nil
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *nonDirectoryFD) OnClose(ctx context.Context) error {
+ // Linux doesn't define ovl_file_operations.flush at all (i.e., its
+ // equivalent of OnClose is a no-op). We pass through to
+ // fd.cachedFD.OnClose() without upgrading even if fd.dentry() has been
+ // copied up, since OnClose is mostly used to define post-close writeback,
+ // and if fd.cachedFD hasn't been updated then it can't have been used to
+ // mutate fd.dentry() anyway.
+ fd.mu.Lock()
+ if statusFlags := fd.vfsfd.StatusFlags(); fd.cachedFlags != statusFlags {
+ if err := fd.cachedFD.SetStatusFlags(ctx, fd.filesystem().creds, statusFlags); err != nil {
+ fd.mu.Unlock()
+ return err
+ }
+ fd.cachedFlags = statusFlags
+ }
+ wrappedFD := fd.cachedFD
+ // Take a reference before dropping fd.mu so that wrappedFD can't be
+ // released out from under the OnClose call below.
+ wrappedFD.IncRef()
+ defer wrappedFD.DecRef()
+ fd.mu.Unlock()
+ return wrappedFD.OnClose(ctx)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ var stat linux.Statx
+ if layerMask := opts.Mask &^ statInternalMask; layerMask != 0 {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ stat, err = wrappedFD.Stat(ctx, vfs.StatOptions{
+ Mask: layerMask,
+ Sync: opts.Sync,
+ })
+ wrappedFD.DecRef()
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ }
+ fd.dentry().statInternalTo(ctx, &opts, &stat)
+ return stat, nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ d := fd.dentry()
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ return err
+ }
+ mnt := fd.vfsfd.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ if err := d.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ // Changes to d's attributes are serialized by d.copyMu.
+ d.copyMu.Lock()
+ defer d.copyMu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return err
+ }
+ if err := wrappedFD.SetStat(ctx, opts); err != nil {
+ return err
+ }
+ d.updateAfterSetStatLocked(&opts)
+ return nil
+}
+
+// StatFS implements vfs.FileDescriptionImpl.StatFS.
+func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) {
+ return fd.filesystem().statFS(ctx)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return 0, err
+ }
+ defer wrappedFD.DecRef()
+ return wrappedFD.PRead(ctx, dst, offset, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *nonDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Hold fd.mu during the read to serialize the file offset.
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return 0, err
+ }
+ return wrappedFD.Read(ctx, dst, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return 0, err
+ }
+ defer wrappedFD.DecRef()
+ return wrappedFD.PWrite(ctx, src, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *nonDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // Hold fd.mu during the write to serialize the file offset.
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return 0, err
+ }
+ return wrappedFD.Write(ctx, src, opts)
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *nonDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Hold fd.mu during the seek to serialize the file offset.
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return 0, err
+ }
+ return wrappedFD.Seek(ctx, offset, whence)
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *nonDirectoryFD) Sync(ctx context.Context) error {
+ fd.mu.Lock()
+ if !fd.dentry().isCopiedUp() {
+ fd.mu.Unlock()
+ return nil
+ }
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ fd.mu.Unlock()
+ return err
+ }
+ wrappedFD.IncRef()
+ defer wrappedFD.DecRef()
+ fd.mu.Unlock()
+ return wrappedFD.Sync(ctx)
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return err
+ }
+ defer wrappedFD.DecRef()
+ return wrappedFD.ConfigureMMap(ctx, opts)
+}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
new file mode 100644
index 000000000..e720d4825
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -0,0 +1,627 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package overlay provides an overlay filesystem implementation, which
+// synthesizes a filesystem by composing one or more immutable filesystems
+// ("lower layers") with an optional mutable filesystem ("upper layer").
+//
+// Lock order:
+//
+// directoryFD.mu / nonDirectoryFD.mu
+//   filesystem.renameMu
+//     dentry.dirMu
+//       dentry.copyMu
+//
+// Locking dentry.dirMu in multiple dentries requires that parent dentries are
+// locked before child dentries, and that filesystem.renameMu is locked to
+// stabilize this relationship.
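+//
+// For example, a writable open of a lower-layer file acquires
+// filesystem.renameMu (in OpenAt) before dentry.copyMu (in copyUpLocked),
+// never the reverse.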
+package overlay
+
+import (
+ "strings"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "overlay"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// FilesystemOptions may be passed as vfs.GetFilesystemOptions.InternalData to
+// FilesystemType.GetFilesystem.
+type FilesystemOptions struct {
+ // Callers passing FilesystemOptions to
+ // overlay.FilesystemType.GetFilesystem() are responsible for ensuring that
+ // the vfs.Mounts comprising the layers of the overlay filesystem do not
+ // contain submounts.
+
+ // If UpperRoot.Ok(), it is the root of the writable upper layer of the
+ // overlay.
+ UpperRoot vfs.VirtualDentry
+
+ // LowerRoots contains the roots of the immutable lower layers of the
+ // overlay. LowerRoots is immutable.
+ LowerRoots []vfs.VirtualDentry
+}
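+
+// As a usage sketch (upperRoot and lowerRoot are hypothetical, already-held
+// vfs.VirtualDentry references), a sentry caller might configure a two-layer
+// overlay like so:
+//
+//	fsopts := overlay.FilesystemOptions{
+//		UpperRoot:  upperRoot,
+//		LowerRoots: []vfs.VirtualDentry{lowerRoot},
+//	}
+//	fs, root, err := overlay.FilesystemType{}.GetFilesystem(ctx, vfsObj, creds,
+//		"", vfs.GetFilesystemOptions{InternalData: fsopts})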
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ vfsfs vfs.Filesystem
+
+ // Immutable options.
+ opts FilesystemOptions
+
+ // creds is a copy of the filesystem's creator's credentials, which are
+ // used for accesses to the filesystem's layers. creds is immutable.
+ creds *auth.Credentials
+
+ // dirDevMinor is the device minor number used for directories. dirDevMinor
+ // is immutable.
+ dirDevMinor uint32
+
+ // lowerDevMinors maps lower layer filesystems to device minor numbers
+ // assigned to non-directory files originating from that filesystem.
+ // lowerDevMinors is immutable.
+ lowerDevMinors map[*vfs.Filesystem]uint32
+
+ // renameMu synchronizes renaming with non-renaming operations in order to
+ // ensure consistent lock ordering between dentry.dirMu in different
+ // dentries.
+ renameMu sync.RWMutex
+
+ // lastDirIno is the last inode number assigned to a directory. lastDirIno
+ // is accessed using atomic memory operations.
+ lastDirIno uint64
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ fsoptsRaw := opts.InternalData
+ fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions)
+ if fsoptsRaw != nil && !haveFSOpts {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw)
+ return nil, nil, syserror.EINVAL
+ }
+ if haveFSOpts {
+ if len(fsopts.LowerRoots) == 0 {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty")
+ return nil, nil, syserror.EINVAL
+ }
+ if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified")
+ return nil, nil, syserror.EINVAL
+ }
+ // We don't enforce a maximum number of lower layers on this path, which
+ // isn't reachable from application mount(2) calls; the sandbox owner can
+ // have an overlay filesystem with any number of lower layers.
+ } else {
+ vfsroot := vfs.RootFromContext(ctx)
+ defer vfsroot.DecRef()
+ upperPathname, ok := mopts["upperdir"]
+ if ok {
+ delete(mopts, "upperdir")
+ // Linux overlayfs also requires a workdir when upperdir is
+ // specified; we don't, so silently ignore this option.
+ delete(mopts, "workdir")
+ upperPath := fspath.Parse(upperPathname)
+ if !upperPath.Absolute {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
+ return nil, nil, syserror.EINVAL
+ }
+ upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: upperPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
+ }
+ defer upperRoot.DecRef()
+ privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */)
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
+ }
+ defer privateUpperRoot.DecRef()
+ fsopts.UpperRoot = privateUpperRoot
+ }
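+ // lowerdir uses Linux's colon-separated syntax; the leftmost directory
+ // is the topmost lower layer, e.g. "lowerdir=/lower_top:/lower_bottom".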
+ lowerPathnamesStr, ok := mopts["lowerdir"]
+ if !ok {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: missing required option lowerdir")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "lowerdir")
+ lowerPathnames := strings.Split(lowerPathnamesStr, ":")
+ const maxLowerLayers = 500 // Linux: fs/overlayfs/super.c:OVL_MAX_STACK
+ if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified")
+ return nil, nil, syserror.EINVAL
+ }
+ if len(lowerPathnames) > maxLowerLayers {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers)
+ return nil, nil, syserror.EINVAL
+ }
+ for _, lowerPathname := range lowerPathnames {
+ lowerPath := fspath.Parse(lowerPathname)
+ if !lowerPath.Absolute {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname)
+ return nil, nil, syserror.EINVAL
+ }
+ lowerRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: lowerPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err)
+ return nil, nil, err
+ }
+ defer lowerRoot.DecRef()
+ privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */)
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err)
+ return nil, nil, err
+ }
+ defer privateLowerRoot.DecRef()
+ fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot)
+ }
+ }
+ if len(mopts) != 0 {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Allocate device numbers.
+ dirDevMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+ lowerDevMinors := make(map[*vfs.Filesystem]uint32)
+ for _, lowerRoot := range fsopts.LowerRoots {
+ lowerFS := lowerRoot.Mount().Filesystem()
+ if _, ok := lowerDevMinors[lowerFS]; !ok {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ vfsObj.PutAnonBlockDevMinor(dirDevMinor)
+ for _, lowerDevMinor := range lowerDevMinors {
+ vfsObj.PutAnonBlockDevMinor(lowerDevMinor)
+ }
+ return nil, nil, err
+ }
+ lowerDevMinors[lowerFS] = devMinor
+ }
+ }
+
+ // Take extra references held by the filesystem.
+ if fsopts.UpperRoot.Ok() {
+ fsopts.UpperRoot.IncRef()
+ }
+ for _, lowerRoot := range fsopts.LowerRoots {
+ lowerRoot.IncRef()
+ }
+
+ fs := &filesystem{
+ opts: fsopts,
+ creds: creds.Fork(),
+ dirDevMinor: dirDevMinor,
+ lowerDevMinors: lowerDevMinors,
+ }
+ fs.vfsfs.Init(vfsObj, &fstype, fs)
+
+ // Construct the root dentry.
+ root := fs.newDentry()
+ root.refs = 1
+ if fs.opts.UpperRoot.Ok() {
+ fs.opts.UpperRoot.IncRef()
+ root.copiedUp = 1
+ root.upperVD = fs.opts.UpperRoot
+ }
+ for _, lowerRoot := range fs.opts.LowerRoots {
+ lowerRoot.IncRef()
+ root.lowerVDs = append(root.lowerVDs, lowerRoot)
+ }
+ rootTopVD := root.topLayer()
+ // Get metadata from the topmost layer. See fs.lookupLocked().
+ const rootStatMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
+ rootStat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: rootTopVD,
+ Start: rootTopVD,
+ }, &vfs.StatOptions{
+ Mask: rootStatMask,
+ })
+ if err != nil {
+ root.destroyLocked()
+ fs.vfsfs.DecRef()
+ return nil, nil, err
+ }
+ if rootStat.Mask&rootStatMask != rootStatMask {
+ root.destroyLocked()
+ fs.vfsfs.DecRef()
+ return nil, nil, syserror.EREMOTE
+ }
+ if isWhiteout(&rootStat) {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout")
+ root.destroyLocked()
+ fs.vfsfs.DecRef()
+ return nil, nil, syserror.EINVAL
+ }
+ root.mode = uint32(rootStat.Mode)
+ root.uid = rootStat.UID
+ root.gid = rootStat.GID
+ if rootStat.Mode&linux.S_IFMT == linux.S_IFDIR {
+ root.devMajor = linux.UNNAMED_MAJOR
+ root.devMinor = fs.dirDevMinor
+ root.ino = fs.newDirIno()
+ } else if !root.upperVD.Ok() {
+ root.devMajor = linux.UNNAMED_MAJOR
+ root.devMinor = fs.lowerDevMinors[root.lowerVDs[0].Mount().Filesystem()]
+ root.ino = rootStat.Ino
+ } else {
+ root.devMajor = rootStat.DevMajor
+ root.devMinor = rootStat.DevMinor
+ root.ino = rootStat.Ino
+ }
+
+ return &fs.vfsfs, &root.vfsd, nil
+}
+
+// clonePrivateMount creates a non-recursive bind mount rooted at vd, not
+// associated with any MountNamespace, and returns the root of the new mount.
+// (This is required to ensure that each layer of an overlay comprises only a
+// single mount, and therefore can't cross into e.g. the overlay filesystem
+// itself, risking lock recursion.) A reference is held on the returned
+// VirtualDentry.
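+//
+// GetFilesystem uses clonePrivateMount for upperdir and lowerdir; layers
+// passed in via FilesystemOptions must already satisfy this requirement
+// (see the comment in FilesystemOptions).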
+func clonePrivateMount(vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, forceReadOnly bool) (vfs.VirtualDentry, error) {
+ oldmnt := vd.Mount()
+ opts := oldmnt.Options()
+ if forceReadOnly {
+ opts.ReadOnly = true
+ }
+ newmnt, err := vfsObj.NewDisconnectedMount(oldmnt.Filesystem(), vd.Dentry(), &opts)
+ if err != nil {
+ return vfs.VirtualDentry{}, err
+ }
+ return vfs.MakeVirtualDentry(newmnt, vd.Dentry()), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ vfsObj.PutAnonBlockDevMinor(fs.dirDevMinor)
+ for _, lowerDevMinor := range fs.lowerDevMinors {
+ vfsObj.PutAnonBlockDevMinor(lowerDevMinor)
+ }
+ if fs.opts.UpperRoot.Ok() {
+ fs.opts.UpperRoot.DecRef()
+ }
+ for _, lowerRoot := range fs.opts.LowerRoots {
+ lowerRoot.DecRef()
+ }
+}
+
+func (fs *filesystem) statFS(ctx context.Context) (linux.Statfs, error) {
+ // Always statfs the root of the topmost layer. Compare Linux's
+ // fs/overlayfs/super.c:ovl_statfs().
+ var rootVD vfs.VirtualDentry
+ if fs.opts.UpperRoot.Ok() {
+ rootVD = fs.opts.UpperRoot
+ } else {
+ rootVD = fs.opts.LowerRoots[0]
+ }
+ fsstat, err := fs.vfsfs.VirtualFilesystem().StatFSAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: rootVD,
+ Start: rootVD,
+ })
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ fsstat.Type = linux.OVERLAYFS_SUPER_MAGIC
+ return fsstat, nil
+}
+
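+// newDirIno returns an inode number for a new directory. Directory inode
+// numbers are synthesized by the overlay rather than taken from a layer,
+// since a merged directory doesn't correspond to any single layer's inode.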
+func (fs *filesystem) newDirIno() uint64 {
+ return atomic.AddUint64(&fs.lastDirIno, 1)
+}
+
+// dentry implements vfs.DentryImpl.
+type dentry struct {
+ vfsd vfs.Dentry
+
+ refs int64
+
+ // fs is the owning filesystem. fs is immutable.
+ fs *filesystem
+
+ // mode, uid, and gid are the file mode, owner, and group of the file in
+ // the topmost layer (and therefore the overlay file as well), and are used
+ // for permission checks on this dentry. These fields are protected by
+ // copyMu and accessed using atomic memory operations.
+ mode uint32
+ uid uint32
+ gid uint32
+
+ // copiedUp is 1 if this dentry has been copied-up (i.e. upperVD.Ok()) and
+ // 0 otherwise. copiedUp is accessed using atomic memory operations.
+ copiedUp uint32
+
+ // parent is the dentry corresponding to this dentry's parent directory.
+ // name is this dentry's name in parent. If this dentry is a filesystem
+ // root, parent is nil and name is the empty string. parent and name are
+ // protected by fs.renameMu.
+ parent *dentry
+ name string
+
+ // If this dentry represents a directory, children maps the names of
+ // children for which dentries have been instantiated to those dentries,
+ // and dirents (if not nil) is a cache of dirents as returned by
+ // directoryFDs representing this directory. children is protected by
+ // dirMu.
+ dirMu sync.Mutex
+ children map[string]*dentry
+ dirents []vfs.Dirent
+
+ // upperVD and lowerVDs are the files from the overlay filesystem's layers
+ // that comprise the file on the overlay filesystem.
+ //
+ // If !upperVD.Ok(), it can transition to a valid vfs.VirtualDentry (i.e.
+ // be copied up) with copyMu locked for writing; otherwise, it is
+ // immutable. lowerVDs is always immutable.
+ copyMu sync.RWMutex
+ upperVD vfs.VirtualDentry
+ lowerVDs []vfs.VirtualDentry
+
+ // inlineLowerVDs backs lowerVDs in the common case where len(lowerVDs) <=
+ // len(inlineLowerVDs).
+ inlineLowerVDs [1]vfs.VirtualDentry
+
+ // devMajor, devMinor, and ino are the device major/minor and inode numbers
+ // used by this dentry. These fields are protected by copyMu and accessed
+ // using atomic memory operations.
+ devMajor uint32
+ devMinor uint32
+ ino uint64
+
+ locks vfs.FileLocks
+}
+
+// newDentry creates a new dentry. The dentry initially has no references; it
+// is the caller's responsibility to set the dentry's reference count and/or
+// call dentry.destroy() as appropriate. The dentry is initially invalid in
+// that it contains no layers; the caller is responsible for setting them.
+func (fs *filesystem) newDentry() *dentry {
+ d := &dentry{
+ fs: fs,
+ }
+ d.lowerVDs = d.inlineLowerVDs[:0]
+ d.vfsd.Init(d)
+ return d
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *dentry) IncRef() {
+ // d.refs may be 0 if d.fs.renameMu is locked, which serializes against
+ // d.checkDropLocked().
+ atomic.AddInt64(&d.refs, 1)
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *dentry) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&d.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *dentry) DecRef() {
+ if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ d.fs.renameMu.Lock()
+ d.checkDropLocked()
+ d.fs.renameMu.Unlock()
+ } else if refs < 0 {
+ panic("overlay.dentry.DecRef() called without holding a reference")
+ }
+}
+
+// checkDropLocked should be called after d's reference count becomes 0 or
+// after d is deleted.
+//
+// Preconditions: d.fs.renameMu must be locked for writing.
+func (d *dentry) checkDropLocked() {
+ // Dentries with a positive reference count must be retained. (The only way
+ // to obtain a reference on a dentry with zero references is via path
+ // resolution, which requires renameMu, so if d.refs is zero then it will
+ // remain zero while we hold renameMu for writing.) Dentries with a
+ // negative reference count have already been destroyed.
+ if atomic.LoadInt64(&d.refs) != 0 {
+ return
+ }
+ // Refs is still zero; destroy it.
+ d.destroyLocked()
+}
+
+// destroyLocked destroys the dentry.
+//
+// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0.
+func (d *dentry) destroyLocked() {
+ switch atomic.LoadInt64(&d.refs) {
+ case 0:
+ // Mark the dentry destroyed.
+ atomic.StoreInt64(&d.refs, -1)
+ case -1:
+ panic("overlay.dentry.destroyLocked() called on already destroyed dentry")
+ default:
+ panic("overlay.dentry.destroyLocked() called with references on the dentry")
+ }
+
+ if d.upperVD.Ok() {
+ d.upperVD.DecRef()
+ }
+ for _, lowerVD := range d.lowerVDs {
+ lowerVD.DecRef()
+ }
+
+ if d.parent != nil {
+ d.parent.dirMu.Lock()
+ if !d.vfsd.IsDead() {
+ delete(d.parent.children, d.name)
+ }
+ d.parent.dirMu.Unlock()
+ // Drop the reference held by d on its parent without recursively
+ // locking d.fs.renameMu.
+ if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
+ d.parent.checkDropLocked()
+ } else if refs < 0 {
+ panic("overlay.dentry.DecRef() called without holding a reference")
+ }
+ }
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {
+ // TODO(gvisor.dev/issue/1479): Implement inotify.
+}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *dentry) Watches() *vfs.Watches {
+ // TODO(gvisor.dev/issue/1479): Implement inotify.
+ return nil
+}
+
+// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
+//
+// TODO(gvisor.dev/issue/1479): Implement inotify.
+func (d *dentry) OnZeroWatches() {}
+
+// iterLayers invokes yield on each layer comprising d, from top to bottom. If
+// any call to yield returns false, iterLayers stops iterating.
+func (d *dentry) iterLayers(yield func(vd vfs.VirtualDentry, isUpper bool) bool) {
+ if d.isCopiedUp() {
+ if !yield(d.upperVD, true) {
+ return
+ }
+ }
+ for _, lowerVD := range d.lowerVDs {
+ if !yield(lowerVD, false) {
+ return
+ }
+ }
+}
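+
+// For example, a caller might locate the topmost layer containing a given
+// name (a sketch; the per-layer lookup is elided):
+//
+//	d.iterLayers(func(vd vfs.VirtualDentry, isUpper bool) bool {
+//		// Look up the name under vd; return false to stop at the first
+//		// layer that has it.
+//		return true
+//	})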
+
+func (d *dentry) topLayerInfo() (vd vfs.VirtualDentry, isUpper bool) {
+ if d.isCopiedUp() {
+ return d.upperVD, true
+ }
+ return d.lowerVDs[0], false
+}
+
+func (d *dentry) topLayer() vfs.VirtualDentry {
+ vd, _ := d.topLayerInfo()
+ return vd
+}
+
+func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
+}
+
+// statInternalMask is the set of stat fields that is set by
+// dentry.statInternalTo().
+const statInternalMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
+
+// statInternalTo writes fields to stat that are stored in d, and therefore do
+// not require invoking StatAt on the overlay's layers.
+func (d *dentry) statInternalTo(ctx context.Context, opts *vfs.StatOptions, stat *linux.Statx) {
+ stat.Mask |= statInternalMask
+ if d.isDir() {
+ // Linux sets nlink to 1 for merged directories
+ // (fs/overlayfs/inode.c:ovl_getattr()); we set it to 2 because this is
+ // correct more often ("." and the directory's entry in its parent),
+ // and some of our tests expect this.
+ stat.Nlink = 2
+ }
+ stat.UID = atomic.LoadUint32(&d.uid)
+ stat.GID = atomic.LoadUint32(&d.gid)
+ stat.Mode = uint16(atomic.LoadUint32(&d.mode))
+ stat.Ino = atomic.LoadUint64(&d.ino)
+ stat.DevMajor = atomic.LoadUint32(&d.devMajor)
+ stat.DevMinor = atomic.LoadUint32(&d.devMinor)
+}
+
+// Preconditions: d.copyMu must be locked for writing.
+func (d *dentry) updateAfterSetStatLocked(opts *vfs.SetStatOptions) {
+ if opts.Stat.Mask&linux.STATX_MODE != 0 {
+ atomic.StoreUint32(&d.mode, (d.mode&linux.S_IFMT)|uint32(opts.Stat.Mode&^linux.S_IFMT))
+ }
+ if opts.Stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&d.uid, opts.Stat.UID)
+ }
+ if opts.Stat.Mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&d.gid, opts.Stat.GID)
+ }
+}
+
+// fileDescription is embedded by overlay implementations of
+// vfs.FileDescriptionImpl.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+}
+
+func (fd *fileDescription) filesystem() *filesystem {
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+}
+
+func (fd *fileDescription) dentry() *dentry {
+ return fd.vfsfd.Dentry().Impl().(*dentry)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index cab771211..811f80a5f 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -81,7 +81,8 @@ type inode struct {
kernfs.InodeNotSymlink
kernfs.InodeNoopRefCount
- pipe *pipe.VFSPipe
+ locks vfs.FileLocks
+ pipe *pipe.VFSPipe
ino uint64
uid auth.KUID
@@ -114,7 +115,7 @@ func (i *inode) Mode() linux.FileMode {
}
// Stat implements kernfs.Inode.Stat.
-func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+func (i *inode) Stat(_ context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds())
return linux.Statx{
Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS,
@@ -147,7 +148,7 @@ func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.
// Open implements kernfs.Inode.Open.
func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags)
+ return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags, &i.locks)
}
// NewConnectedPipeFDs returns a pair of FileDescriptions representing the read
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 17c1342b5..6014138ff 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -22,6 +22,7 @@ go_library(
"//pkg/log",
"//pkg/refs",
"//pkg/safemem",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/inet",
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index 36a911db4..79c2725f3 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -37,6 +37,8 @@ type subtasksInode struct {
kernfs.OrderedChildren
kernfs.AlwaysValid
+ locks vfs.FileLocks
+
fs *filesystem
task *kernel.Task
pidns *kernel.PIDNamespace
@@ -126,7 +128,7 @@ func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallbac
return fd.GenericDirectoryFD.IterDirents(ctx, cb)
}
-// Seek implements vfs.FileDecriptionImpl.Seek.
+// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
if fd.task.ExitState() >= kernel.TaskExitZombie {
return 0, syserror.ENOENT
@@ -153,7 +155,7 @@ func (fd *subtasksFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) erro
// Open implements kernfs.Inode.
func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &subtasksFD{task: i.task}
- if err := fd.Init(&i.OrderedChildren, &opts); err != nil {
+ if err := fd.Init(&i.OrderedChildren, &i.locks, &opts); err != nil {
return nil, err
}
if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
@@ -163,8 +165,8 @@ func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *v
}
// Stat implements kernfs.Inode.
-func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- stat, err := i.InodeAttrs.Stat(vsfs, opts)
+func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts)
if err != nil {
return linux.Statx{}, err
}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index 482055db1..a5c7aa470 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -38,6 +38,8 @@ type taskInode struct {
kernfs.InodeAttrs
kernfs.OrderedChildren
+ locks vfs.FileLocks
+
task *kernel.Task
}
@@ -103,7 +105,7 @@ func (i *taskInode) Valid(ctx context.Context) bool {
// Open implements kernfs.Inode.
func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
if err != nil {
return nil, err
}
@@ -154,8 +156,8 @@ func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.
}
// Stat implements kernfs.Inode.
-func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- stat, err := i.Inode.Stat(fs, opts)
+func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.Inode.Stat(ctx, fs, opts)
if err != nil {
return linux.Statx{}, err
}
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 44ccc9e4a..fea29e5f0 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -53,6 +53,8 @@ func taskFDExists(t *kernel.Task, fd int32) bool {
}
type fdDir struct {
+ locks vfs.FileLocks
+
fs *filesystem
task *kernel.Task
@@ -62,7 +64,7 @@ type fdDir struct {
}
// IterDirents implements kernfs.inodeDynamicLookup.
-func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, absOffset, relOffset int64) (int64, error) {
+func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
var fds []int32
i.task.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
@@ -70,7 +72,6 @@ func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, abs
}
})
- offset := absOffset + relOffset
typ := uint8(linux.DT_REG)
if i.produceSymlink {
typ = linux.DT_LNK
@@ -143,7 +144,7 @@ func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro
// Open implements kernfs.Inode.
func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
if err != nil {
return nil, err
}
@@ -270,7 +271,7 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry,
// Open implements kernfs.Inode.
func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 2f297e48a..859b7d727 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -34,6 +35,10 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// "There is an (arbitrary) limit on the number of lines in the file. As at
+// Linux 3.18, the limit is five lines." - user_namespaces(7)
+const maxIDMapLines = 5
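+
+// Each line of uid_map/gid_map has the form "ID-inside-ns ID-outside-ns
+// length"; e.g. writing "0 1000 1" maps exactly one ID, 0 inside the
+// namespace to 1000 outside it. See idMapData.Write below.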
+
// mm gets the kernel task's MemoryManager. No additional reference is taken on
// mm here. This is safe because MemoryManager.destroy is required to leave the
// MemoryManager in a state where it's still usable as a DynamicBytesSource.
@@ -226,8 +231,9 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// Linux will return envp up to and including the first NULL character,
// so find it.
- if end := bytes.IndexByte(buf.Bytes()[ar.Length():], 0); end != -1 {
- buf.Truncate(end)
+ envStart := int(ar.Length())
+ if nullIdx := bytes.IndexByte(buf.Bytes()[envStart:], 0); nullIdx != -1 {
+ buf.Truncate(envStart + nullIdx)
}
}
@@ -282,7 +288,8 @@ func (d *commData) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
-// idMapData implements vfs.DynamicBytesSource for /proc/[pid]/{gid_map|uid_map}.
+// idMapData implements vfs.WritableDynamicBytesSource for
+// /proc/[pid]/{gid_map|uid_map}.
//
// +stateify savable
type idMapData struct {
@@ -294,7 +301,7 @@ type idMapData struct {
var _ dynamicInode = (*idMapData)(nil)
-// Generate implements vfs.DynamicBytesSource.Generate.
+// Generate implements vfs.WritableDynamicBytesSource.Generate.
func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error {
var entries []auth.IDMapEntry
if d.gids {
@@ -308,6 +315,60 @@ func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // "In addition, the number of bytes written to the file must be less than
+ // the system page size, and the write must be performed at the start of
+ // the file ..." - user_namespaces(7)
+ srclen := src.NumBytes()
+ if srclen >= usermem.PageSize || offset != 0 {
+ return 0, syserror.EINVAL
+ }
+ b := make([]byte, srclen)
+ if _, err := src.CopyIn(ctx, b); err != nil {
+ return 0, err
+ }
+
+ // Truncate from the first NULL byte.
+ nul := int64(bytes.IndexByte(b, 0))
+ if nul == -1 {
+ nul = srclen
+ }
+ b = b[:nul]
+ // Remove the last \n.
+ if nul >= 1 && b[nul-1] == '\n' {
+ b = b[:nul-1]
+ }
+ lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1)
+ if len(lines) > maxIDMapLines {
+ return 0, syserror.EINVAL
+ }
+
+ entries := make([]auth.IDMapEntry, len(lines))
+ for i, l := range lines {
+ var e auth.IDMapEntry
+ _, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length)
+ if err != nil {
+ return 0, syserror.EINVAL
+ }
+ entries[i] = e
+ }
+ var err error
+ if d.gids {
+ err = d.task.UserNamespace().SetGIDMap(ctx, entries)
+ } else {
+ err = d.task.UserNamespace().SetUIDMap(ctx, entries)
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ // On success, Linux's kernel/user_namespace.c:map_write() always returns
+ // count, even if fewer bytes were used.
+ return int64(srclen), nil
+}
+
// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
//
// +stateify savable
@@ -775,6 +836,8 @@ type namespaceInode struct {
kernfs.InodeNoopRefCount
kernfs.InodeNotDirectory
kernfs.InodeNotSymlink
+
+ locks vfs.FileLocks
}
var _ kernfs.Inode = (*namespaceInode)(nil)
@@ -791,6 +854,7 @@ func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32
func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &namespaceFD{inode: i}
i.IncRef()
+ fd.LockFD.Init(&i.locks)
if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
@@ -801,6 +865,7 @@ func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *
// /proc/[pid]/ns/*.
type namespaceFD struct {
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
vfsfd vfs.FileDescription
inode *namespaceInode
@@ -811,7 +876,7 @@ var _ vfs.FileDescriptionImpl = (*namespaceFD)(nil)
// Stat implements FileDescriptionImpl.
func (fd *namespaceFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
- return fd.inode.Stat(vfs, opts)
+ return fd.inode.Stat(ctx, vfs, opts)
}
// SetStat implements FileDescriptionImpl.
@@ -826,7 +891,12 @@ func (fd *namespaceFD) Release() {
fd.inode.DecRef()
}
-// OnClose implements FileDescriptionImpl.
-func (*namespaceFD) OnClose(context.Context) error {
- return nil
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *namespaceFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *namespaceFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
}
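This LockFD wiring is the pattern repeated throughout the rest of this change: the inode owns a vfs.FileLocks, Open binds each new FD to it via LockFD.Init, and the POSIX lock hooks delegate to the shared set. A condensed sketch of the shape, assuming the gVisor vfs and fs/lock packages (the type names here are hypothetical):

package example

import (
	"gvisor.dev/gvisor/pkg/context"
	fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

// someInode owns the advisory lock state shared by all FDs opened on it.
type someInode struct {
	locks vfs.FileLocks
}

// someFD embeds vfs.LockFD and forwards the POSIX hooks to the shared set.
type someFD struct {
	vfs.FileDescriptionDefaultImpl
	vfs.LockFD
	vfsfd vfs.FileDescription
}

func open(i *someInode) *someFD {
	fd := &someFD{}
	fd.LockFD.Init(&i.locks) // bind this FD to the inode's lock set
	return fd
}

// LockPOSIX delegates to the inode-level lock set, as in the change above.
func (fd *someFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
	return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
}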
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index b51d43954..6d2b90a8b 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -43,6 +43,8 @@ type tasksInode struct {
kernfs.OrderedChildren
kernfs.AlwaysValid
+ locks vfs.FileLocks
+
fs *filesystem
pidns *kernel.PIDNamespace
@@ -197,15 +199,15 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback
// Open implements kernfs.Inode.
func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
if err != nil {
return nil, err
}
return fd.VFSFileDescription(), nil
}
-func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- stat, err := i.InodeAttrs.Stat(vsfs, opts)
+func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts)
if err != nil {
return linux.Statx{}, err
}
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index d29ef3f83..242ba9b5d 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -31,6 +31,7 @@ type SignalFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
// target is the original signal target task.
//
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
index a741e2bb6..1b548ccd4 100644
--- a/pkg/sentry/fsimpl/sys/BUILD
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -29,6 +29,6 @@ go_test(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 0af373604..01ce30a4d 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -98,8 +98,10 @@ type dir struct {
kernfs.InodeNoDynamicLookup
kernfs.InodeNotSymlink
kernfs.InodeDirectoryNoNewChildren
-
kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
dentry kernfs.Dentry
}
@@ -121,7 +123,7 @@ func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.Set
// Open implements kernfs.Inode.Open.
func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
if err != nil {
return nil, err
}
@@ -136,7 +138,7 @@ type cpuFile struct {
// Generate implements vfs.DynamicBytesSource.Generate.
func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "0-%d", c.maxCores-1)
+ fmt.Fprintf(buf, "0-%d\n", c.maxCores-1)
return nil
}
diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go
index 4b3602d47..242d5fd12 100644
--- a/pkg/sentry/fsimpl/sys/sys_test.go
+++ b/pkg/sentry/fsimpl/sys/sys_test.go
@@ -51,7 +51,7 @@ func TestReadCPUFile(t *testing.T) {
k := kernel.KernelFromContext(s.Ctx)
maxCPUCores := k.ApplicationCores()
- expected := fmt.Sprintf("0-%d", maxCPUCores-1)
+ expected := fmt.Sprintf("0-%d\n", maxCPUCores-1)
for _, fname := range []string{"online", "possible", "present"} {
pop := s.PathOpAtRoot(fmt.Sprintf("devices/system/cpu/%s", fname))
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
index 0e4053a46..400a97996 100644
--- a/pkg/sentry/fsimpl/testutil/BUILD
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -32,6 +32,6 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/usermem",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index c16a36cdb..e743e8114 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -62,6 +62,7 @@ func Boot() (*kernel.Kernel, error) {
return nil, fmt.Errorf("creating platform: %v", err)
}
+ kernel.VFS2Enabled = true
k := &kernel.Kernel{
Platform: plat,
}
@@ -73,7 +74,7 @@ func Boot() (*kernel.Kernel, error) {
k.SetMemoryFile(mf)
// Pass k as the platform since it is savable, unlike the actual platform.
- vdso, err := loader.PrepareVDSO(nil, k)
+ vdso, err := loader.PrepareVDSO(k)
if err != nil {
return nil, fmt.Errorf("creating vdso: %v", err)
}
@@ -103,11 +104,6 @@ func Boot() (*kernel.Kernel, error) {
return nil, fmt.Errorf("initializing kernel: %v", err)
}
- kernel.VFS2Enabled = true
-
- if err := k.VFS().Init(); err != nil {
- return nil, fmt.Errorf("VFS init: %v", err)
- }
k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
AllowUserList: true,
diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go
index 60c92d626..2dc90d484 100644
--- a/pkg/sentry/fsimpl/timerfd/timerfd.go
+++ b/pkg/sentry/fsimpl/timerfd/timerfd.go
@@ -32,6 +32,7 @@ type TimerFileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
events waiter.Queue
timer *ktime.Timer
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 007be1572..e73732a6b 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -59,9 +59,9 @@ go_library(
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
"//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/uniqueid",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
- "//pkg/sentry/vfs/lock",
"//pkg/sentry/vfs/memxattr",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go
index 83bf885ee..ac54d420d 100644
--- a/pkg/sentry/fsimpl/tmpfs/device_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/device_file.go
@@ -29,7 +29,7 @@ type deviceFile struct {
minor uint32
}
-func (fs *filesystem) newDeviceFile(creds *auth.Credentials, mode linux.FileMode, kind vfs.DeviceKind, major, minor uint32) *inode {
+func (fs *filesystem) newDeviceFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, kind vfs.DeviceKind, major, minor uint32) *inode {
file := &deviceFile{
kind: kind,
major: major,
@@ -43,7 +43,7 @@ func (fs *filesystem) newDeviceFile(creds *auth.Credentials, mode linux.FileMode
default:
panic(fmt.Sprintf("invalid DeviceKind: %v", kind))
}
- file.inode.init(file, fs, creds, mode)
+ file.inode.init(file, fs, kuid, kgid, mode)
file.inode.nlink = 1 // from parent directory
return &file.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index f2399981b..0a1ad4765 100644
--- a/pkg/sentry/fsimpl/tmpfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -48,9 +48,9 @@ type directory struct {
childList dentryList
}
-func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *directory {
+func (fs *filesystem) newDirectory(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *directory {
dir := &directory{}
- dir.inode.init(dir, fs, creds, linux.S_IFDIR|mode)
+ dir.inode.init(dir, fs, kuid, kgid, linux.S_IFDIR|mode)
dir.inode.nlink = 2 // from "." and parent directory or ".." for root
dir.dentry.inode = &dir.inode
dir.dentry.vfsd.Init(&dir.dentry)
@@ -81,6 +81,10 @@ func (dir *directory) removeChildLocked(child *dentry) {
dir.iterMu.Unlock()
}
+func (dir *directory) mayDelete(creds *auth.Credentials, child *dentry) error {
+ return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&dir.inode.mode)), auth.KUID(atomic.LoadUint32(&child.inode.uid)))
+}
+
type directoryFD struct {
fileDescription
vfs.DirectoryFileDescriptionDefaultImpl
@@ -106,6 +110,8 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
fs := fd.filesystem()
dir := fd.inode().impl.(*directory)
+ defer fd.dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+
// fs.mu is required to read d.parent and dentry.name.
fs.mu.RLock()
defer fs.mu.RUnlock()
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 80fa7b29d..ef210a69b 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -79,7 +79,7 @@ afterSymlink:
}
if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
// Symlink traversal updates access time.
- atomic.StoreInt64(&d.inode.atime, d.inode.fs.clock.Now().Nanoseconds())
+ child.inode.touchAtime(rp.Mount())
if err := rp.HandleSymlink(symlink.target); err != nil {
return nil, err
}
@@ -177,6 +177,12 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa
if err := create(parentDir, name); err != nil {
return err
}
+
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parentDir.inode.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
parentDir.inode.touchCMtime()
return nil
}
@@ -231,17 +237,22 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return syserror.EXDEV
}
d := vd.Dentry().Impl().(*dentry)
- if d.inode.isDir() {
+ i := d.inode
+ if i.isDir() {
return syserror.EPERM
}
- if d.inode.nlink == 0 {
+ if err := vfs.MayLink(auth.CredentialsFromContext(ctx), linux.FileMode(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ return err
+ }
+ if i.nlink == 0 {
return syserror.ENOENT
}
- if d.inode.nlink == maxLinks {
+ if i.nlink == maxLinks {
return syserror.EMLINK
}
- d.inode.incLinksLocked()
- parentDir.insertChildLocked(fs.newDentry(d.inode), name)
+ i.incLinksLocked()
+ i.watches.Notify("", linux.IN_ATTRIB, 0, vfs.InodeEvent, false /* unlinked */)
+ parentDir.insertChildLocked(fs.newDentry(i), name)
return nil
})
}
@@ -249,11 +260,12 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
return fs.doCreateAt(rp, true /* dir */, func(parentDir *directory, name string) error {
+ creds := rp.Credentials()
if parentDir.inode.nlink == maxLinks {
return syserror.EMLINK
}
parentDir.inode.incLinksLocked() // from child's ".."
- childDir := fs.newDirectory(rp.Credentials(), opts.Mode)
+ childDir := fs.newDirectory(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
parentDir.insertChildLocked(&childDir.dentry, name)
return nil
})
@@ -262,18 +274,19 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// MknodAt implements vfs.FilesystemImpl.MknodAt.
func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error {
+ creds := rp.Credentials()
var childInode *inode
switch opts.Mode.FileType() {
- case 0, linux.S_IFREG:
- childInode = fs.newRegularFile(rp.Credentials(), opts.Mode)
+ case linux.S_IFREG:
+ childInode = fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
case linux.S_IFIFO:
- childInode = fs.newNamedPipe(rp.Credentials(), opts.Mode)
+ childInode = fs.newNamedPipe(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
case linux.S_IFBLK:
- childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.BlockDevice, opts.DevMajor, opts.DevMinor)
+ childInode = fs.newDeviceFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, vfs.BlockDevice, opts.DevMajor, opts.DevMinor)
case linux.S_IFCHR:
- childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor)
+ childInode = fs.newDeviceFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor)
case linux.S_IFSOCK:
- childInode = fs.newSocketFile(rp.Credentials(), opts.Mode, opts.Endpoint)
+ childInode = fs.newSocketFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, opts.Endpoint)
default:
return syserror.EINVAL
}
@@ -348,15 +361,20 @@ afterTrailingSymlink:
}
defer rp.Mount().EndWrite()
// Create and open the child.
- child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode))
+ creds := rp.Credentials()
+ child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode))
parentDir.insertChildLocked(child, name)
fd, err := child.open(ctx, rp, &opts, true)
if err != nil {
return nil, err
}
+ parentDir.inode.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */)
parentDir.inode.touchCMtime()
return fd, nil
}
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
// Is the file mounted over?
if err := rp.CheckMount(&child.vfsd); err != nil {
return nil, err
@@ -364,7 +382,7 @@ afterTrailingSymlink:
// Do we need to resolve a trailing symlink?
if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
// Symlink traversal updates access time.
- atomic.StoreInt64(&child.inode.atime, child.inode.fs.clock.Now().Nanoseconds())
+ child.inode.touchAtime(rp.Mount())
if err := rp.HandleSymlink(symlink.target); err != nil {
return nil, err
}
@@ -388,10 +406,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
switch impl := d.inode.impl.(type) {
case *regularFile:
var fd regularFileFD
- if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ fd.LockFD.Init(&d.inode.locks)
+ if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{AllowDirectIO: true}); err != nil {
return nil, err
}
- if opts.Flags&linux.O_TRUNC != 0 {
+ if !afterCreate && opts.Flags&linux.O_TRUNC != 0 {
if _, err := impl.truncate(0); err != nil {
return nil, err
}
@@ -403,15 +422,16 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
return nil, syserror.EISDIR
}
var fd directoryFD
- if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ fd.LockFD.Init(&d.inode.locks)
+ if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{AllowDirectIO: true}); err != nil {
return nil, err
}
return &fd.vfsfd, nil
case *symlink:
- // Can't open symlinks without O_PATH (which is unimplemented).
+ // TODO(gvisor.dev/issue/2782): Can't open symlinks without O_PATH.
return nil, syserror.ELOOP
case *namedPipe:
- return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags)
+ return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags, &d.inode.locks)
case *deviceFile:
return rp.VirtualFilesystem().OpenDeviceSpecialFile(ctx, rp.Mount(), &d.vfsd, impl.kind, impl.major, impl.minor, opts)
case *socketFile:
@@ -472,6 +492,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if !ok {
return syserror.ENOENT
}
+ if err := oldParentDir.mayDelete(rp.Credentials(), renamed); err != nil {
+ return err
+ }
// Note that we don't need to call rp.CheckMount(), since if renamed is a
// mount point then we want to rename the mount point, not anything in the
// mounted filesystem.
@@ -559,6 +582,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
newParentDir.inode.touchCMtime()
}
renamed.inode.touchCtime()
+
+ vfs.InotifyRename(ctx, &renamed.inode.watches, &oldParentDir.inode.watches, &newParentDir.inode.watches, oldName, newName, renamed.inode.isDir())
return nil
}
@@ -584,6 +609,9 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
if !ok {
return syserror.ENOENT
}
+ if err := parentDir.mayDelete(rp.Credentials(), child); err != nil {
+ return err
+ }
childDir, ok := child.inode.impl.(*directory)
if !ok {
return syserror.ENOTDIR
@@ -603,8 +631,11 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
parentDir.removeChildLocked(child)
- parentDir.inode.decLinksLocked() // from child's ".."
+ parentDir.inode.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */)
+ // Remove links for child, child/., and child/..
+ child.inode.decLinksLocked()
child.inode.decLinksLocked()
+ parentDir.inode.decLinksLocked()
vfsObj.CommitDeleteDentry(&child.vfsd)
parentDir.inode.touchCMtime()
return nil
@@ -613,12 +644,21 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
fs.mu.RLock()
- defer fs.mu.RUnlock()
d, err := resolveLocked(rp)
if err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ if err := d.inode.setStat(ctx, rp.Credentials(), &opts); err != nil {
+ fs.mu.RUnlock()
return err
}
- return d.inode.setStat(ctx, rp.Credentials(), &opts.Stat)
+ fs.mu.RUnlock()
+
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ d.InotifyWithParent(ev, 0, vfs.InodeEvent)
+ }
+ return nil
}
// StatAt implements vfs.FilesystemImpl.StatAt.
@@ -656,7 +696,8 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error {
- child := fs.newDentry(fs.newSymlink(rp.Credentials(), target))
+ creds := rp.Credentials()
+ child := fs.newDentry(fs.newSymlink(creds.EffectiveKUID, creds.EffectiveKGID, 0777, target))
parentDir.insertChildLocked(child, name)
return nil
})
@@ -681,6 +722,9 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if !ok {
return syserror.ENOENT
}
+ if err := parentDir.mayDelete(rp.Credentials(), child); err != nil {
+ return err
+ }
if child.inode.isDir() {
return syserror.EISDIR
}
@@ -698,6 +742,12 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
return err
}
+
+ // Generate inotify events. Note that this must take place before the link
+ // count of the child is decremented, or else the watches may be dropped
+ // before these events are added.
+ vfs.InotifyRemoveChild(&child.inode.watches, &parentDir.inode.watches, name)
+
parentDir.removeChildLocked(child)
child.inode.decLinksLocked()
vfsObj.CommitDeleteDentry(&child.vfsd)
@@ -749,23 +799,37 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
fs.mu.RLock()
- defer fs.mu.RUnlock()
d, err := resolveLocked(rp)
if err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ if err := d.inode.setxattr(rp.Credentials(), &opts); err != nil {
+ fs.mu.RUnlock()
return err
}
- return d.inode.setxattr(rp.Credentials(), &opts)
+ fs.mu.RUnlock()
+
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
}
// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
fs.mu.RLock()
- defer fs.mu.RUnlock()
d, err := resolveLocked(rp)
if err != nil {
+ fs.mu.RUnlock()
return err
}
- return d.inode.removexattr(rp.Credentials(), name)
+ if err := d.inode.removexattr(rp.Credentials(), name); err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ fs.mu.RUnlock()
+
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
}
// PrependPath implements vfs.FilesystemImpl.PrependPath.
diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index 8d77b3fa8..739350cf0 100644
--- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -30,9 +30,9 @@ type namedPipe struct {
// Preconditions:
// * fs.mu must be locked.
// * rp.Mount().CheckBeginWrite() has been called successfully.
-func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode {
+func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)}
- file.inode.init(file, fs, creds, linux.S_IFIFO|mode)
+ file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode)
file.inode.nlink = 1 // Only the parent has a link.
return &file.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 3f433d666..abbaa5d60 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -85,12 +84,12 @@ type regularFile struct {
size uint64
}
-func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
+func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
file := &regularFile{
memFile: fs.memFile,
seals: linux.F_SEAL_SEAL,
}
- file.inode.init(file, fs, creds, linux.S_IFREG|mode)
+ file.inode.init(file, fs, kuid, kgid, linux.S_IFREG|mode)
file.inode.nlink = 1 // from parent directory
return &file.inode
}
@@ -275,11 +274,35 @@ func (fd *regularFileFD) Release() {
// noop
}
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ f := fd.inode().impl.(*regularFile)
+
+ f.inode.mu.Lock()
+ defer f.inode.mu.Unlock()
+ oldSize := f.size
+ size := offset + length
+ if oldSize >= size {
+ return nil
+ }
+ _, err := f.truncateLocked(size)
+ return err
+}
+
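A standalone sketch of the grow-only semantics above, with a plain uint64 standing in for the real file (hypothetical helper): an allocation ending at or before the current size is a no-op; otherwise the file grows to offset+length.

package main

import "fmt"

// allocate mirrors regularFileFD.Allocate above with a bare size field.
func allocate(size *uint64, offset, length uint64) {
	if end := offset + length; end > *size {
		*size = end // the real code calls truncateLocked(end) here
	}
}

func main() {
	size := uint64(100)
	allocate(&size, 50, 20) // ends at 70 <= 100: no-op
	allocate(&size, 90, 30) // ends at 120 > 100: grow
	fmt.Println(size)       // 120
}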
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
if offset < 0 {
return 0, syserror.EINVAL
}
+
+ // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since
+ // all state is in-memory.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
if dst.NumBytes() == 0 {
return 0, nil
}
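The flag check above uses Go's &^ (AND NOT) operator: clearing the supported bits leaves exactly the unsupported ones. A standalone illustration with the Linux preadv2 flag values (RWF_HIPRI=0x1, RWF_DSYNC=0x2, RWF_SYNC=0x4):

package main

import "fmt"

const (
	RWF_HIPRI = 0x1
	RWF_DSYNC = 0x2
	RWF_SYNC  = 0x4
)

func main() {
	flags := uint32(RWF_HIPRI | 0x8) // 0x8 (RWF_NOWAIT) is unsupported here
	if flags&^(RWF_HIPRI|RWF_DSYNC|RWF_SYNC) != 0 {
		fmt.Println("EOPNOTSUPP") // leftover bits are unsupported flags
	}
}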
@@ -302,40 +325,60 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, the final offset, and an
+// error. The final offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
+
+ // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since
+ // all state is in-memory.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
srclen := src.NumBytes()
if srclen == 0 {
- return 0, nil
+ return 0, offset, nil
}
f := fd.inode().impl.(*regularFile)
+ f.inode.mu.Lock()
+ defer f.inode.mu.Unlock()
+ // If the file is opened with O_APPEND, update offset to file size.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Locking f.inode.mu is sufficient for reading f.size.
+ offset = int64(f.size)
+ }
if end := offset + srclen; end < offset {
// Overflow.
- return 0, syserror.EFBIG
+ return 0, offset, syserror.EINVAL
}
- var err error
srclen, err = vfs.CheckLimit(ctx, offset, srclen)
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(srclen)
- f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
- fd.inode().touchCMtimeLocked()
- f.inode.mu.Unlock()
+ f.inode.touchCMtimeLocked()
putRegularFileReadWriter(rw)
- return n, err
+ return n, n + offset, err
}
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
fd.offMu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.offMu.Unlock()
return n, err
}
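The Write/pwrite split exists because O_APPEND recomputes the offset under the inode lock, so the caller cannot reconstruct the final offset from its own starting offset plus n. A minimal sketch of that flow, with hypothetical stand-ins for the file and lock:

package main

import "fmt"

// appendWrite mirrors pwrite above: under the (notional) inode lock, an
// O_APPEND write starts at the current size and reports the final offset.
func appendWrite(size *uint64, off, n int64, appendMode bool) (written, finalOff int64) {
	if appendMode {
		off = int64(*size) // offset recomputed under the lock
	}
	*size = uint64(off + n)
	return n, off + n
}

func main() {
	size := uint64(10)
	n, off := appendWrite(&size, 0, 5, true)
	fmt.Println(n, off) // 5 15: fd.off must become 15, not 0+n
}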
@@ -361,33 +404,6 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (
return offset, nil
}
-// Sync implements vfs.FileDescriptionImpl.Sync.
-func (fd *regularFileFD) Sync(ctx context.Context) error {
- return nil
-}
-
-// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
-func (fd *regularFileFD) LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error {
- return fd.inode().lockBSD(uid, t, block)
-}
-
-// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
-func (fd *regularFileFD) UnlockBSD(ctx context.Context, uid lock.UniqueID) error {
- fd.inode().unlockBSD(uid)
- return nil
-}
-
-// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error {
- return fd.inode().lockPOSIX(uid, t, rng, block)
-}
-
-// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
-func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error {
- fd.inode().unlockPOSIX(uid, rng)
- return nil
-}
-
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
file := fd.inode().impl.(*regularFile)
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
index 64e1c40ad..146c7fdfe 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -138,48 +138,37 @@ func TestLocks(t *testing.T) {
}
defer cleanup()
- var (
- uid1 lock.UniqueID
- uid2 lock.UniqueID
- // Non-blocking.
- block lock.Blocker
- )
-
- uid1 = 123
- uid2 = 456
-
- if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, block); err != nil {
+ uid1 := 123
+ uid2 := 456
+ if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, nil); err != nil {
t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
}
- if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, block); err != nil {
+ if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, nil); err != nil {
t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
}
- if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, block), syserror.ErrWouldBlock; got != want {
+ if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want {
t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want)
}
if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil {
t.Fatalf("fd.Impl().UnlockBSD failed: err = %v", err)
}
- if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, block); err != nil {
+ if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil); err != nil {
t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
}
- rng1 := lock.LockRange{0, 1}
- rng2 := lock.LockRange{1, 2}
-
- if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, rng1, block); err != nil {
+ if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, 0, 1, linux.SEEK_SET, nil); err != nil {
t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
}
- if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, rng2, block); err != nil {
+ if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 1, 2, linux.SEEK_SET, nil); err != nil {
t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
}
- if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, rng1, block); err != nil {
+ if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, 0, 1, linux.SEEK_SET, nil); err != nil {
t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
}
- if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, rng1, block), syserror.ErrWouldBlock; got != want {
+ if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 0, 1, linux.SEEK_SET, nil), syserror.ErrWouldBlock; got != want {
t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want)
}
- if err := fd.Impl().UnlockPOSIX(ctx, uid1, rng1); err != nil {
+ if err := fd.Impl().UnlockPOSIX(ctx, uid1, 0, 1, linux.SEEK_SET); err != nil {
t.Fatalf("fd.Impl().UnlockPOSIX failed: err = %v", err)
}
}
diff --git a/pkg/sentry/fsimpl/tmpfs/socket_file.go b/pkg/sentry/fsimpl/tmpfs/socket_file.go
index 25c2321af..3ed650474 100644
--- a/pkg/sentry/fsimpl/tmpfs/socket_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/socket_file.go
@@ -26,9 +26,9 @@ type socketFile struct {
ep transport.BoundEndpoint
}
-func (fs *filesystem) newSocketFile(creds *auth.Credentials, mode linux.FileMode, ep transport.BoundEndpoint) *inode {
+func (fs *filesystem) newSocketFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, ep transport.BoundEndpoint) *inode {
file := &socketFile{ep: ep}
- file.inode.init(file, fs, creds, mode)
+ file.inode.init(file, fs, kuid, kgid, mode)
file.inode.nlink = 1 // from parent directory
return &file.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go
index 47e075ed4..b0de5fabe 100644
--- a/pkg/sentry/fsimpl/tmpfs/symlink.go
+++ b/pkg/sentry/fsimpl/tmpfs/symlink.go
@@ -24,11 +24,11 @@ type symlink struct {
target string // immutable
}
-func (fs *filesystem) newSymlink(creds *auth.Credentials, target string) *inode {
+func (fs *filesystem) newSymlink(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, target string) *inode {
link := &symlink{
target: target,
}
- link.inode.init(link, fs, creds, linux.S_IFLNK|0777)
+ link.inode.init(link, fs, kuid, kgid, linux.S_IFLNK|mode)
link.inode.nlink = 1 // from parent directory
return &link.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 1e781aecd..2545d88e9 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -30,6 +30,7 @@ package tmpfs
import (
"fmt"
"math"
+ "strconv"
"strings"
"sync/atomic"
@@ -40,7 +41,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/sentry/vfs/memxattr"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -112,6 +112,58 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
}
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ rootMode := linux.FileMode(0777)
+ if rootFileType == linux.S_IFDIR {
+ rootMode = 01777
+ }
+ modeStr, ok := mopts["mode"]
+ if ok {
+ delete(mopts, "mode")
+ mode, err := strconv.ParseUint(modeStr, 8, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid mode: %q", modeStr)
+ return nil, nil, syserror.EINVAL
+ }
+ rootMode = linux.FileMode(mode & 07777)
+ }
+ rootKUID := creds.EffectiveKUID
+ uidStr, ok := mopts["uid"]
+ if ok {
+ delete(mopts, "uid")
+ uid, err := strconv.ParseUint(uidStr, 10, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid uid: %q", uidStr)
+ return nil, nil, syserror.EINVAL
+ }
+ kuid := creds.UserNamespace.MapToKUID(auth.UID(uid))
+ if !kuid.Ok() {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped uid: %d", uid)
+ return nil, nil, syserror.EINVAL
+ }
+ rootKUID = kuid
+ }
+ rootKGID := creds.EffectiveKGID
+ gidStr, ok := mopts["gid"]
+ if ok {
+ delete(mopts, "gid")
+ gid, err := strconv.ParseUint(gidStr, 10, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid gid: %q", gidStr)
+ return nil, nil, syserror.EINVAL
+ }
+ kgid := creds.UserNamespace.MapToKGID(auth.GID(gid))
+ if !kgid.Ok() {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped gid: %d", gid)
+ return nil, nil, syserror.EINVAL
+ }
+ rootKGID = kgid
+ }
+ if len(mopts) != 0 {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
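These options arrive in the mount data string, e.g. mount -t tmpfs -o mode=0755,uid=1000,gid=1000 none /mnt. A standalone sketch of the same parse-and-validate steps using only the standard library (the real code goes through vfs.GenericParseMountOptions and the credential KUID/KGID mappings):

package main

import (
	"fmt"
	"strconv"
	"strings"
)

func main() {
	data := "mode=0755,uid=1000,gid=1000"
	mopts := map[string]string{}
	for _, kv := range strings.Split(data, ",") {
		parts := strings.SplitN(kv, "=", 2)
		mopts[parts[0]] = parts[1]
	}
	// Mode is parsed as octal and masked to the permission bits, as above.
	mode, err := strconv.ParseUint(mopts["mode"], 8, 32)
	if err != nil {
		panic(err) // the real code warns and returns EINVAL
	}
	uid, _ := strconv.ParseUint(mopts["uid"], 10, 32)
	fmt.Printf("mode=%#o uid=%d gid=%s\n", mode&07777, uid, mopts["gid"])
}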
devMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
@@ -127,11 +179,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
var root *dentry
switch rootFileType {
case linux.S_IFREG:
- root = fs.newDentry(fs.newRegularFile(creds, 0777))
+ root = fs.newDentry(fs.newRegularFile(rootKUID, rootKGID, rootMode))
case linux.S_IFLNK:
- root = fs.newDentry(fs.newSymlink(creds, tmpfsOpts.RootSymlinkTarget))
+ root = fs.newDentry(fs.newSymlink(rootKUID, rootKGID, rootMode, tmpfsOpts.RootSymlinkTarget))
case linux.S_IFDIR:
- root = &fs.newDirectory(creds, 01777).dentry
+ root = &fs.newDirectory(rootKUID, rootKGID, rootMode).dentry
default:
fs.vfsfs.DecRef()
return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType)
@@ -201,6 +253,33 @@ func (d *dentry) DecRef() {
d.inode.decRef()
}
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {
+ if d.inode.isDir() {
+ events |= linux.IN_ISDIR
+ }
+
+ // tmpfs never calls VFS.InvalidateDentry(), so d.vfsd.IsDead() indicates
+ // that d was deleted.
+ deleted := d.vfsd.IsDead()
+
+ d.inode.fs.mu.RLock()
+ // The ordering below is important: Linux always notifies the parent first.
+ if d.parent != nil {
+ d.parent.inode.watches.Notify(d.name, events, cookie, et, deleted)
+ }
+ d.inode.watches.Notify("", events, cookie, et, deleted)
+ d.inode.fs.mu.RUnlock()
+}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *dentry) Watches() *vfs.Watches {
+ return &d.inode.watches
+}
+
+// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
+func (d *dentry) OnZeroWatches() {}
+
// inode represents a filesystem object.
type inode struct {
// fs is the owning filesystem. fs is immutable.
@@ -209,11 +288,9 @@ type inode struct {
// refs is a reference count. refs is accessed using atomic memory
// operations.
//
- // A reference is held on all inodes that are reachable in the filesystem
- // tree. For non-directories (which may have multiple hard links), this
- // means that a reference is dropped when nlink reaches 0. For directories,
- // nlink never reaches 0 due to the "." entry; instead,
- // filesystem.RmdirAt() drops the reference.
+ // A reference is held on all inodes as long as they are reachable in the
+ // filesystem tree, i.e. nlink is nonzero. This reference is dropped when
+ // nlink reaches 0.
refs int64
// xattrs implements extended attributes.
@@ -235,23 +312,25 @@ type inode struct {
ctime int64 // nanoseconds
mtime int64 // nanoseconds
- // Advisory file locks, which lock at the inode level.
- locks lock.FileLocks
+ locks vfs.FileLocks
+
+ // Inotify watches for this inode.
+ watches vfs.Watches
impl interface{} // immutable
}
const maxLinks = math.MaxUint32
-func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) {
+func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) {
if mode.FileType() == 0 {
panic("file type is required in FileMode")
}
i.fs = fs
i.refs = 1
i.mode = uint32(mode)
- i.uid = uint32(creds.EffectiveKUID)
- i.gid = uint32(creds.EffectiveKGID)
+ i.uid = uint32(kuid)
+ i.gid = uint32(kgid)
i.ino = atomic.AddUint64(&fs.nextInoMinusOne, 1)
// Tmpfs creation sets atime, ctime, and mtime to current time.
now := fs.clock.Now().Nanoseconds()
@@ -276,14 +355,17 @@ func (i *inode) incLinksLocked() {
atomic.AddUint32(&i.nlink, 1)
}
-// decLinksLocked decrements i's link count.
+// decLinksLocked decrements i's link count. If the link count reaches 0, we
+// remove a reference on i as well.
//
// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0.
func (i *inode) decLinksLocked() {
if i.nlink == 0 {
panic("tmpfs.inode.decLinksLocked() called with no existing links")
}
- atomic.AddUint32(&i.nlink, ^uint32(0))
+ if atomic.AddUint32(&i.nlink, ^uint32(0)) == 0 {
+ i.decRef()
+ }
}
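A worked example of the nlink/reference coupling above, using bare counters in place of real inodes (the increments come from MkdirAt, the decrements from RmdirAt): the final decrement to zero is what triggers decRef.

package main

import "fmt"

func main() {
	// mkdir /a/b: b starts with 2 links ("." plus its entry in a),
	// and a gains one for b's "..".
	a, b := uint32(2), uint32(2)
	a++

	// rmdir /a/b: drop b's entry in a and b's ".", then a's "..".
	b -= 2 // b.nlink == 0: decLinksLocked drops b's reference here
	a--
	fmt.Println(a, b) // 2 0
}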
func (i *inode) incRef() {
@@ -306,6 +388,7 @@ func (i *inode) tryIncRef() bool {
func (i *inode) decRef() {
if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
+ i.watches.HandleDeletion()
if regFile, ok := i.impl.(*regularFile); ok {
// Release memory used by regFile to store data. Since regFile is
// no longer usable, we don't need to grab any locks or update any
@@ -369,7 +452,8 @@ func (i *inode) statTo(stat *linux.Statx) {
}
}
-func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error {
+func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions) error {
+ stat := &opts.Stat
if stat.Mask == 0 {
return nil
}
@@ -377,7 +461,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu
return syserror.EPERM
}
mode := linux.FileMode(atomic.LoadUint32(&i.mode))
- if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
return err
}
i.mu.Lock()
@@ -455,44 +539,6 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu
return nil
}
-// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
-func (i *inode) lockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
- switch i.impl.(type) {
- case *regularFile:
- return i.locks.LockBSD(uid, t, block)
- }
- return syserror.EBADF
-}
-
-// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
-func (i *inode) unlockBSD(uid fslock.UniqueID) error {
- switch i.impl.(type) {
- case *regularFile:
- i.locks.UnlockBSD(uid)
- return nil
- }
- return syserror.EBADF
-}
-
-// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
-func (i *inode) lockPOSIX(uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error {
- switch i.impl.(type) {
- case *regularFile:
- return i.locks.LockPOSIX(uid, t, rng, block)
- }
- return syserror.EBADF
-}
-
-// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
-func (i *inode) unlockPOSIX(uid fslock.UniqueID, rng fslock.LockRange) error {
- switch i.impl.(type) {
- case *regularFile:
- i.locks.UnlockPOSIX(uid, rng)
- return nil
- }
- return syserror.EBADF
-}
-
// allocatedBlocksForSize returns the number of 512B blocks needed to
// accommodate the given size in bytes, as appropriate for struct
// stat::st_blocks and struct statx::stx_blocks. (Note that this 512B block
@@ -531,6 +577,9 @@ func (i *inode) isDir() bool {
}
func (i *inode) touchAtime(mnt *vfs.Mount) {
+ if mnt.Flags.NoATime {
+ return
+ }
if err := mnt.CheckBeginWrite(); err != nil {
return
}
@@ -621,14 +670,19 @@ func (i *inode) userXattrSupported() bool {
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
}
func (fd *fileDescription) filesystem() *filesystem {
return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
}
+func (fd *fileDescription) dentry() *dentry {
+ return fd.vfsfd.Dentry().Impl().(*dentry)
+}
+
func (fd *fileDescription) inode() *inode {
- return fd.vfsfd.Dentry().Impl().(*dentry).inode
+ return fd.dentry().inode
}
// Stat implements vfs.FileDescriptionImpl.Stat.
@@ -641,7 +695,15 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
creds := auth.CredentialsFromContext(ctx)
- return fd.inode().setStat(ctx, creds, &opts.Stat)
+ d := fd.dentry()
+ if err := d.inode.setStat(ctx, creds, &opts); err != nil {
+ return err
+ }
+
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ d.InotifyWithParent(ev, 0, vfs.InodeEvent)
+ }
+ return nil
}
// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
@@ -656,12 +718,26 @@ func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOption
// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
- return fd.inode().setxattr(auth.CredentialsFromContext(ctx), &opts)
+ d := fd.dentry()
+ if err := d.inode.setxattr(auth.CredentialsFromContext(ctx), &opts); err != nil {
+ return err
+ }
+
+ // Generate inotify events.
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
}
// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
- return fd.inode().removexattr(auth.CredentialsFromContext(ctx), name)
+ d := fd.dentry()
+ if err := d.inode.removexattr(auth.CredentialsFromContext(ctx), name); err != nil {
+ return err
+ }
+
+ // Generate inotify events.
+ d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
}
// NewMemfd creates a new tmpfs regular file and file description that can back
@@ -674,8 +750,7 @@ func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name s
// Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with
// S_IRWXUGO.
- mode := linux.FileMode(0777)
- inode := fs.newRegularFile(creds, mode)
+ inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777)
rf := inode.impl.(*regularFile)
if allowSeals {
rf.seals = 0
@@ -688,9 +763,26 @@ func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name s
// Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with
// FMODE_READ | FMODE_WRITE.
var fd regularFileFD
+ fd.Init(&inode.locks)
flags := uint32(linux.O_RDWR)
if err := fd.vfsfd.Init(&fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return &fd.vfsfd, nil
}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync. It does nothing because all
+// filesystem state is in-memory.
+func (*fileDescription) Sync(context.Context) error {
+ return nil
+}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index a28eab8b8..f6886a758 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -85,6 +85,7 @@ go_library(
name = "kernel",
srcs = [
"abstract_socket_namespace.go",
+ "aio.go",
"context.go",
"fd_table.go",
"fd_table_unsafe.go",
@@ -131,6 +132,7 @@ go_library(
"task_stop.go",
"task_syscall.go",
"task_usermem.go",
+ "task_work.go",
"thread_group.go",
"threads.go",
"timekeeper.go",
@@ -199,6 +201,7 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/state",
"//pkg/state/statefile",
+ "//pkg/state/wire",
"//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
diff --git a/pkg/sentry/kernel/aio.go b/pkg/sentry/kernel/aio.go
new file mode 100644
index 000000000..0ac78c0b8
--- /dev/null
+++ b/pkg/sentry/kernel/aio.go
@@ -0,0 +1,81 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// AIOCallback is a function that does asynchronous I/O on behalf of a task.
+type AIOCallback func(context.Context)
+
+// QueueAIO queues an AIOCallback which will be run asynchronously.
+func (t *Task) QueueAIO(cb AIOCallback) {
+ ctx := taskAsyncContext{t: t}
+ wg := &t.TaskSet().aioGoroutines
+ wg.Add(1)
+ go func() {
+ cb(ctx)
+ wg.Done()
+ }()
+}
+
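A hedged usage sketch (abbreviated; t is assumed to be a live *kernel.Task): the callback runs on its own goroutine with a task-derived context, and the TaskSet WaitGroup keeps teardown from racing with in-flight I/O.

// Assuming t is a live *kernel.Task:
t.QueueAIO(func(ctx context.Context) {
	// Perform the I/O with ctx; deliver completion through whatever
	// mechanism the submitter chose (e.g. an AIO completion ring).
})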
+type taskAsyncContext struct {
+ context.NoopSleeper
+ t *Task
+}
+
+// Debugf implements log.Logger.Debugf.
+func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) {
+ ctx.t.Debugf(format, v...)
+}
+
+// Infof implements log.Logger.Infof.
+func (ctx taskAsyncContext) Infof(format string, v ...interface{}) {
+ ctx.t.Infof(format, v...)
+}
+
+// Warningf implements log.Logger.Warningf.
+func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) {
+ ctx.t.Warningf(format, v...)
+}
+
+// IsLogging implements log.Logger.IsLogging.
+func (ctx taskAsyncContext) IsLogging(level log.Level) bool {
+ return ctx.t.IsLogging(level)
+}
+
+// Deadline implements context.Context.Deadline.
+func (ctx taskAsyncContext) Deadline() (time.Time, bool) {
+ return ctx.t.Deadline()
+}
+
+// Done implements context.Context.Done.
+func (ctx taskAsyncContext) Done() <-chan struct{} {
+ return ctx.t.Done()
+}
+
+// Err implements context.Context.Err.
+func (ctx taskAsyncContext) Err() error {
+ return ctx.t.Err()
+}
+
+// Value implements context.Context.Value.
+func (ctx taskAsyncContext) Value(key interface{}) interface{} {
+ return ctx.t.Value(key)
+}
diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go
index e057d2c6d..6862f2ef5 100644
--- a/pkg/sentry/kernel/auth/credentials.go
+++ b/pkg/sentry/kernel/auth/credentials.go
@@ -232,3 +232,31 @@ func (c *Credentials) UseGID(gid GID) (KGID, error) {
}
return NoID, syserror.EPERM
}
+
+// SetUID translates the provided uid to the root user namespace and updates c's
+// uids to it. This performs no permissions or capabilities checks; the caller
+// is responsible for ensuring the calling context is permitted to modify c.
+func (c *Credentials) SetUID(uid UID) error {
+ kuid := c.UserNamespace.MapToKUID(uid)
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ c.RealKUID = kuid
+ c.EffectiveKUID = kuid
+ c.SavedKUID = kuid
+ return nil
+}
+
+// SetGID translates the provided gid to the root user namespace and updates c's
+// gids to it. This performs no permissions or capabilities checks; the caller
+// is responsible for ensuring the calling context is permitted to modify c.
+func (c *Credentials) SetGID(gid GID) error {
+ kgid := c.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ c.RealKGID = kgid
+ c.EffectiveKGID = kgid
+ c.SavedKGID = kgid
+ return nil
+}
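A brief call sketch, assuming creds is a *auth.Credentials the caller is permitted to mutate (per the comments above, no permission checks happen here):

// Map UID 1000 through creds.UserNamespace and set all three UIDs to it.
if err := creds.SetUID(auth.UID(1000)); err != nil {
	// EINVAL: UID 1000 has no mapping in this user namespace.
	return err
}
// RealKUID, EffectiveKUID, and SavedKUID now all equal the mapped KUID.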
diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go
index 0c40bf315..dd5f0f5fa 100644
--- a/pkg/sentry/kernel/context.go
+++ b/pkg/sentry/kernel/context.go
@@ -18,7 +18,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/log"
)
// contextID is the kernel package's type for context.Context.Value keys.
@@ -113,55 +112,3 @@ func (*Task) Done() <-chan struct{} {
func (*Task) Err() error {
return nil
}
-
-// AsyncContext returns a context.Context that may be used by goroutines that
-// do work on behalf of t and therefore share its contextual values, but are
-// not t's task goroutine (e.g. asynchronous I/O).
-func (t *Task) AsyncContext() context.Context {
- return taskAsyncContext{t: t}
-}
-
-type taskAsyncContext struct {
- context.NoopSleeper
- t *Task
-}
-
-// Debugf implements log.Logger.Debugf.
-func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) {
- ctx.t.Debugf(format, v...)
-}
-
-// Infof implements log.Logger.Infof.
-func (ctx taskAsyncContext) Infof(format string, v ...interface{}) {
- ctx.t.Infof(format, v...)
-}
-
-// Warningf implements log.Logger.Warningf.
-func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) {
- ctx.t.Warningf(format, v...)
-}
-
-// IsLogging implements log.Logger.IsLogging.
-func (ctx taskAsyncContext) IsLogging(level log.Level) bool {
- return ctx.t.IsLogging(level)
-}
-
-// Deadline implements context.Context.Deadline.
-func (ctx taskAsyncContext) Deadline() (time.Time, bool) {
- return ctx.t.Deadline()
-}
-
-// Done implements context.Context.Done.
-func (ctx taskAsyncContext) Done() <-chan struct{} {
- return ctx.t.Done()
-}
-
-// Err implements context.Context.Err.
-func (ctx taskAsyncContext) Err() error {
- return ctx.t.Err()
-}
-
-// Value implements context.Context.Value.
-func (ctx taskAsyncContext) Value(key interface{}) interface{} {
- return ctx.t.Value(key)
-}
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 3d78cd48f..4c0f1e41f 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -107,7 +107,7 @@ type EventPoll struct {
// different lock to avoid circular lock acquisition order involving
// the wait queue mutexes and mu. The full order is mu, observed file
// wait queue mutex, then listsMu; this allows listsMu to be acquired
- // when readyCallback is called.
+ // when (*pollEntry).Callback is called.
//
// An entry is always in one of the following lists:
// readyList -- when there's a chance that it's ready to have
@@ -116,7 +116,7 @@ type EventPoll struct {
// readEvents() functions always call the entry's file
// Readiness() function to confirm it's ready.
// waitingList -- when there's no chance that the entry is ready,
- // so it's waiting for the readyCallback to be called
+ // so it's waiting for the (*pollEntry).Callback to be called
// on it before it gets moved to the readyList.
// disabledList -- when the entry is disabled. This happens when
// a one-shot entry gets delivered via readEvents().
@@ -269,21 +269,19 @@ func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent {
return ret
}
-// readyCallback is called when one of the files we're polling becomes ready. It
-// moves said file to the readyList if it's currently in the waiting list.
-type readyCallback struct{}
-
// Callback implements waiter.EntryCallback.Callback.
-func (*readyCallback) Callback(w *waiter.Entry) {
- entry := w.Context.(*pollEntry)
- e := entry.epoll
+//
+// Callback is called when one of the files we're polling becomes ready. It
+// moves said file to the readyList if it's currently in the waiting list.
+func (p *pollEntry) Callback(*waiter.Entry) {
+ e := p.epoll
e.listsMu.Lock()
- if entry.curList == &e.waitingList {
- e.waitingList.Remove(entry)
- e.readyList.PushBack(entry)
- entry.curList = &e.readyList
+ if p.curList == &e.waitingList {
+ e.waitingList.Remove(p)
+ e.readyList.PushBack(p)
+ p.curList = &e.readyList
e.listsMu.Unlock()
e.Notify(waiter.EventIn)
@@ -310,7 +308,7 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) {
// Check if the file happens to already be in a ready state.
ready := f.Readiness(entry.mask) & entry.mask
if ready != 0 {
- (*readyCallback).Callback(nil, &entry.waiter)
+ entry.Callback(&entry.waiter)
}
}
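Aside on the removed call shape: (*readyCallback).Callback(nil, &entry.waiter) was a Go method expression, which converts the method into a plain function taking the receiver as its first argument; since readyCallback was stateless, the nil receiver was never dereferenced. A standalone illustration:

package main

import "fmt"

type stateless struct{}

func (*stateless) hello(name string) { fmt.Println("hello", name) }

func main() {
	f := (*stateless).hello // method expression: receiver is argument 0
	f(nil, "world")         // safe: the method never touches its receiver
}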
@@ -380,10 +378,9 @@ func (e *EventPoll) AddEntry(id FileIdentifier, flags EntryFlags, mask waiter.Ev
userData: data,
epoll: e,
flags: flags,
- waiter: waiter.Entry{Callback: &readyCallback{}},
mask: mask,
}
- entry.waiter.Context = entry
+ entry.waiter.Callback = entry
e.files[id] = entry
entry.file = refs.NewWeakRef(id.File, entry)
@@ -406,7 +403,7 @@ func (e *EventPoll) UpdateEntry(id FileIdentifier, flags EntryFlags, mask waiter
}
// Unregister the old mask and remove entry from the list it's in, so
- // readyCallback is guaranteed to not be called on this entry anymore.
+ // (*pollEntry).Callback is guaranteed to not be called on this entry anymore.
entry.id.File.EventUnregister(&entry.waiter)
// Remove entry from whatever list it's in. This ensures that no other
diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go
index 8e9f200d0..7c61e0258 100644
--- a/pkg/sentry/kernel/epoll/epoll_state.go
+++ b/pkg/sentry/kernel/epoll/epoll_state.go
@@ -21,8 +21,7 @@ import (
// afterLoad is invoked by stateify.
func (p *pollEntry) afterLoad() {
- p.waiter = waiter.Entry{Callback: &readyCallback{}}
- p.waiter.Context = p
+ p.waiter.Callback = p
p.file = refs.NewWeakRef(p.id.File, p)
p.id.File.EventRegister(&p.waiter, p.mask)
}
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index b9126e946..2b3955598 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -11,6 +11,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
index d32c3e90a..153d2cd9b 100644
--- a/pkg/sentry/kernel/fasync/fasync.go
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -20,15 +20,21 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
-// New creates a new FileAsync.
+// New creates a new fs.FileAsync.
func New() fs.FileAsync {
return &FileAsync{}
}
+// NewVFS2 creates a new vfs.FileAsync.
+func NewVFS2() vfs.FileAsync {
+ return &FileAsync{}
+}
+
// FileAsync sends signals when the registered file is ready for IO.
//
// +stateify savable
@@ -170,3 +176,13 @@ func (a *FileAsync) SetOwnerProcessGroup(requester *kernel.Task, recipient *kern
a.recipientTG = nil
a.recipientPG = recipient
}
+
+// ClearOwner unsets the current signal recipient.
+func (a *FileAsync) ClearOwner() {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.requester = nil
+ a.recipientT = nil
+ a.recipientTG = nil
+ a.recipientPG = nil
+}
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index ed40b5303..4b7d234a4 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -29,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// FDFlags define flags for an individual descriptor.
@@ -80,9 +81,6 @@ type FDTable struct {
refs.AtomicRefCount
k *Kernel
- // uid is a unique identifier.
- uid uint64
-
// mu protects below.
mu sync.Mutex `state:"nosave"`
@@ -130,7 +128,7 @@ func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
// drop drops the table reference.
func (f *FDTable) drop(file *fs.File) {
// Release locks.
- file.Dirent.Inode.LockCtx.Posix.UnlockRegion(lock.UniqueID(f.uid), lock.LockRange{0, lock.LockEOF})
+ file.Dirent.Inode.LockCtx.Posix.UnlockRegion(f, lock.LockRange{0, lock.LockEOF})
// Send inotify events.
d := file.Dirent
@@ -151,24 +149,27 @@ func (f *FDTable) drop(file *fs.File) {
// dropVFS2 drops the table reference.
func (f *FDTable) dropVFS2(file *vfs.FileDescription) {
- // TODO(gvisor.dev/issue/1480): Release locks.
- // TODO(gvisor.dev/issue/1479): Send inotify events.
+ // Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the
+ // entire file.
+ err := file.UnlockPOSIX(context.Background(), f, 0, 0, linux.SEEK_SET)
+ if err != nil && err != syserror.ENOLCK {
+ panic(fmt.Sprintf("UnlockPOSIX failed: %v", err))
+ }
- // Drop the table reference.
- file.DecRef()
-}
+ // Generate inotify events.
+ ev := uint32(linux.IN_CLOSE_NOWRITE)
+ if file.IsWritable() {
+ ev = linux.IN_CLOSE_WRITE
+ }
+ file.Dentry().InotifyWithParent(ev, 0, vfs.PathEvent)
-// ID returns a unique identifier for this FDTable.
-func (f *FDTable) ID() uint64 {
- return f.uid
+ // Drop the table's reference.
+ file.DecRef()
}
// NewFDTable allocates a new FDTable that may be used by tasks in k.
func (k *Kernel) NewFDTable() *FDTable {
- f := &FDTable{
- k: k,
- uid: atomic.AddUint64(&k.fdMapUids, 1),
- }
+ f := &FDTable{k: k}
f.init()
return f
}
@@ -463,6 +464,29 @@ func (f *FDTable) SetFlags(fd int32, flags FDFlags) error {
return nil
}
+// SetFlagsVFS2 sets the flags for the given file descriptor.
+//
+// Returns EBADF if the file descriptor is invalid.
+func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return syscall.EBADF
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ file, _, _ := f.getVFS2(fd)
+ if file == nil {
+ // No file found.
+ return syscall.EBADF
+ }
+
+ // Update the flags.
+ f.setVFS2(fd, file, flags)
+ return nil
+}
+
// Get returns a reference to the file and the flags for the FD or nil if no
// file is defined for the given fd.
//
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index 732e66da4..bcc1b29a8 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -717,10 +717,10 @@ func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint3
}
}
-// UnlockPI unlock the futex following the Priority-inheritance futex
-// rules. The address provided must contain the caller's TID. If there are
-// waiters, TID of the next waiter (FIFO) is set to the given address, and the
-// waiter woken up. If there are no waiters, 0 is set to the address.
+// UnlockPI unlocks the futex following the Priority-inheritance futex rules.
+// The address provided must contain the caller's TID. If there are waiters,
+// TID of the next waiter (FIFO) is set to the given address, and the waiter
+// woken up. If there are no waiters, 0 is set to the address.
func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error {
k, err := getKey(t, addr, private)
if err != nil {
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 5efeb3767..15dae0f5b 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -34,7 +34,6 @@ package kernel
import (
"errors"
"fmt"
- "io"
"path/filepath"
"sync/atomic"
"time"
@@ -73,6 +72,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -81,6 +81,10 @@ import (
// easy access everywhere. To be removed once VFS2 becomes the default.
var VFS2Enabled = false
+// FUSEEnabled is set to true when FUSE is enabled. Added as a global to allow
+// easy access everywhere. To be removed once FUSE is completed.
+var FUSEEnabled = false
+
// Kernel represents an emulated Linux kernel. It must be initialized by calling
// Init() or LoadFrom().
//
@@ -194,11 +198,6 @@ type Kernel struct {
// cpuClockTickerSetting is protected by runningTasksMu.
cpuClockTickerSetting ktime.Setting
- // fdMapUids is an ever-increasing counter for generating FDTable uids.
- //
- // fdMapUids is mutable, and is accessed using atomic memory operations.
- fdMapUids uint64
-
// uniqueID is used to generate unique identifiers.
//
// uniqueID is mutable, and is accessed using atomic memory operations.
@@ -422,7 +421,7 @@ func (k *Kernel) Init(args InitKernelArgs) error {
// SaveTo saves the state of k to w.
//
// Preconditions: The kernel must be paused throughout the call to SaveTo.
-func (k *Kernel) SaveTo(w io.Writer) error {
+func (k *Kernel) SaveTo(w wire.Writer) error {
saveStart := time.Now()
ctx := k.SupervisorContext()
@@ -457,9 +456,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
return err
}
- // Ensure that all pending asynchronous work is complete:
- // - inode and mount release
- // - asynchronuous IO
+ // Ensure that all inode and mount release operations have completed.
fs.AsyncBarrier()
// Once all fs work has completed (flushed references have all been released),
@@ -480,18 +477,18 @@ func (k *Kernel) SaveTo(w io.Writer) error {
//
// N.B. This will also be saved along with the full kernel save below.
cpuidStart := time.Now()
- if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil {
+ if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil {
return err
}
log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
// Save the kernel state.
kernelStart := time.Now()
- var stats state.Stats
- if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil {
+ stats, err := state.Save(k.SupervisorContext(), w, k)
+ if err != nil {
return err
}
- log.Infof("Kernel save stats: %s", &stats)
+ log.Infof("Kernel save stats: %s", stats.String())
log.Infof("Kernel save took [%s].", time.Since(kernelStart))
// Save the memory file's state.
@@ -636,7 +633,7 @@ func (ts *TaskSet) unregisterEpollWaiters() {
}
// LoadFrom returns a new Kernel loaded from args.
-func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
+func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
loadStart := time.Now()
initAppCores := k.applicationCores
@@ -647,7 +644,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// don't need to explicitly install it in the Kernel.
cpuidStart := time.Now()
var features cpuid.FeatureSet
- if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil {
+ if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil {
return err
}
log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
@@ -662,11 +659,11 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// Load the kernel state.
kernelStart := time.Now()
- var stats state.Stats
- if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil {
+ stats, err := state.Load(k.SupervisorContext(), r, k)
+ if err != nil {
return err
}
- log.Infof("Kernel load stats: %s", &stats)
+ log.Infof("Kernel load stats: %s", stats.String())
log.Infof("Kernel load took [%s].", time.Since(kernelStart))
// rootNetworkNamespace should be populated after loading the state file.
@@ -897,7 +894,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
if mntnsVFS2 == nil {
// MountNamespaceVFS2 adds a reference to the namespace, which is
// transferred to the new process.
- mntnsVFS2 = k.GlobalInit().Leader().MountNamespaceVFS2()
+ mntnsVFS2 = k.globalInit.Leader().MountNamespaceVFS2()
}
// Get the root directory from the MountNamespace.
root := args.MountNamespaceVFS2.Root()
@@ -1254,13 +1251,15 @@ func (k *Kernel) Kill(es ExitStatus) {
}
// Pause requests that all tasks in k temporarily stop executing, and blocks
-// until all tasks in k have stopped. Multiple calls to Pause nest and require
-// an equal number of calls to Unpause to resume execution.
+// until all tasks and asynchronous I/O operations in k have stopped. Multiple
+// calls to Pause nest and require an equal number of calls to Unpause to
+// resume execution.
func (k *Kernel) Pause() {
k.extMu.Lock()
k.tasks.BeginExternalStop()
k.extMu.Unlock()
k.tasks.runningGoroutines.Wait()
+ k.tasks.aioGoroutines.Wait()
}
// Unpause ends the effect of a previous call to Pause. If Unpause is called
@@ -1470,6 +1469,11 @@ func (k *Kernel) NowMonotonic() int64 {
return now
}
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ return ktime.TcpipAfterFunc(k.realtimeClock, d, f)
+}
+
// SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or
// LoadFrom.
func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) {
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index f29dc0472..449643118 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -8,6 +8,7 @@ go_library(
"device.go",
"node.go",
"pipe.go",
+ "pipe_unsafe.go",
"pipe_util.go",
"reader.go",
"reader_writer.go",
@@ -20,10 +21,12 @@ go_library(
"//pkg/amutex",
"//pkg/buffer",
"//pkg/context",
+ "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 62c8691f1..79645d7d2 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -207,7 +207,10 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
p.mu.Lock()
defer p.mu.Unlock()
+ return p.readLocked(ctx, ops)
+}
+func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) {
// Is the pipe empty?
if p.view.Size() == 0 {
if !p.HasWriters() {
@@ -246,7 +249,10 @@ type writeOps struct {
func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
p.mu.Lock()
defer p.mu.Unlock()
+ return p.writeLocked(ctx, ops)
+}
+func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) {
// Can't write to a pipe with no readers.
if !p.HasReaders() {
return 0, syscall.EPIPE
diff --git a/pkg/sentry/kernel/pipe/pipe_unsafe.go b/pkg/sentry/kernel/pipe/pipe_unsafe.go
new file mode 100644
index 000000000..dd60cba24
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe_unsafe.go
@@ -0,0 +1,35 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "unsafe"
+)
+
+// lockTwoPipes locks both x.mu and y.mu in an order that is guaranteed to be
+// consistent for both lockTwoPipes(x, y) and lockTwoPipes(y, x), such that
+// concurrent calls cannot deadlock.
+//
+// Preconditions: x != y.
+func lockTwoPipes(x, y *Pipe) {
+ // Lock the two pipes in order of increasing address.
+ if uintptr(unsafe.Pointer(x)) < uintptr(unsafe.Pointer(y)) {
+ x.mu.Lock()
+ y.mu.Lock()
+ } else {
+ y.mu.Lock()
+ x.mu.Lock()
+ }
+}
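
The address-ordered locking above generalizes to any pair of mutexes. A minimal,
self-contained sketch (with a local stand-in type, since Pipe's fields are
package-private) shows that both argument orders acquire the locks in the same
sequence, which is what rules out the classic AB/BA deadlock:

    package main

    import (
        "fmt"
        "sync"
        "unsafe"
    )

    type pipe struct{ mu sync.Mutex }

    // lockTwo mirrors lockTwoPipes: acquire in order of increasing address,
    // so lockTwo(a, b) and lockTwo(b, a) take the locks in the same order.
    func lockTwo(x, y *pipe) {
        if uintptr(unsafe.Pointer(x)) < uintptr(unsafe.Pointer(y)) {
            x.mu.Lock()
            y.mu.Lock()
        } else {
            y.mu.Lock()
            x.mu.Lock()
        }
    }

    func main() {
        a, b := new(pipe), new(pipe)
        lockTwo(a, b) // lockTwo(b, a) would acquire in the same order
        b.mu.Unlock()
        a.mu.Unlock()
        fmt.Println("consistent ordering; no deadlock possible")
    }
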
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index b54f08a30..45d4c5fc1 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -16,8 +16,11 @@ package pipe
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -59,11 +62,13 @@ func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe {
//
// Preconditions: statusFlags should not contain an open access mode.
func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
- return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags)
+ // Connected pipes share the same locks.
+ locks := &vfs.FileLocks{}
+ return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks)
}
// Open opens the pipe represented by vp.
-func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, error) {
+func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
vp.mu.Lock()
defer vp.mu.Unlock()
@@ -73,7 +78,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
return nil, syserror.EINVAL
}
- fd := vp.newFD(mnt, vfsd, statusFlags)
+ fd := vp.newFD(mnt, vfsd, statusFlags, locks)
// Named pipes have special blocking semantics during open:
//
@@ -125,10 +130,11 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
}
// Preconditions: vp.mu must be held.
-func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) *vfs.FileDescription {
+func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) *vfs.FileDescription {
fd := &VFSPipeFD{
pipe: &vp.pipe,
}
+ fd.LockFD.Init(locks)
fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{
DenyPRead: true,
DenyPWrite: true,
@@ -150,11 +156,14 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) *
return &fd.vfsfd
}
-// VFSPipeFD implements vfs.FileDescriptionImpl for pipes.
+// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements
+// non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to
+// other FileDescriptions for splice(2) and tee(2).
type VFSPipeFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
pipe *Pipe
}
@@ -191,6 +200,11 @@ func (fd *VFSPipeFD) Readiness(mask waiter.EventMask) waiter.EventMask {
}
}
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *VFSPipeFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.ESPIPE
+}
+
// EventRegister implements waiter.Waitable.EventRegister.
func (fd *VFSPipeFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
fd.pipe.EventRegister(e, mask)
@@ -229,3 +243,226 @@ func (fd *VFSPipeFD) PipeSize() int64 {
func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) {
return fd.pipe.SetFifoSize(size)
}
+
+// IOSequence returns a usermem.IOSequence that reads up to count bytes from,
+// or writes up to count bytes to, fd.
+func (fd *VFSPipeFD) IOSequence(count int64) usermem.IOSequence {
+ return usermem.IOSequence{
+ IO: fd,
+ Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}),
+ }
+}
+
+// CopyIn implements usermem.IO.CopyIn.
+func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) {
+ origCount := int64(len(dst))
+ n, err := fd.pipe.read(ctx, readOps{
+ left: func() int64 {
+ return int64(len(dst))
+ },
+ limit: func(l int64) {
+ dst = dst[:l]
+ },
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadAt(dst, 0)
+ view.TrimFront(int64(n))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventOut)
+ }
+ if err == nil && n != origCount {
+ return int(n), syserror.ErrWouldBlock
+ }
+ return int(n), err
+}
+
+// CopyOut implements usermem.IO.CopyOut.
+func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) {
+ origCount := int64(len(src))
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return int64(len(src))
+ },
+ limit: func(l int64) {
+ src = src[:l]
+ },
+ write: func(view *buffer.View) (int64, error) {
+ view.Append(src)
+ return int64(len(src)), nil
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return int(n), syserror.ErrWouldBlock
+ }
+ return int(n), err
+}
+
+// ZeroOut implements usermem.IO.ZeroOut.
+func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) {
+ origCount := toZero
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return toZero
+ },
+ limit: func(l int64) {
+ toZero = l
+ },
+ write: func(view *buffer.View) (int64, error) {
+ view.Grow(view.Size()+toZero, true /* zero */)
+ return toZero, nil
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// CopyInTo implements usermem.IO.CopyInTo.
+func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) {
+ count := ars.NumBytes()
+ if count == 0 {
+ return 0, nil
+ }
+ origCount := count
+ n, err := fd.pipe.read(ctx, readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadToSafememWriter(dst, uint64(count))
+ view.TrimFront(int64(n))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventOut)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// CopyOutFrom implements usermem.IO.CopyOutFrom.
+func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) {
+ count := ars.NumBytes()
+ if count == 0 {
+ return 0, nil
+ }
+ origCount := count
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(view *buffer.View) (int64, error) {
+ n, err := view.WriteFromSafememReader(src, uint64(count))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// SwapUint32 implements usermem.IO.SwapUint32.
+func (fd *VFSPipeFD) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) {
+ // How did a pipe get passed as the virtual address space to futex(2)?
+ panic("VFSPipeFD.SwapUint32 called unexpectedly")
+}
+
+// CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32.
+func (fd *VFSPipeFD) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) {
+ panic("VFSPipeFD.CompareAndSwapUint32 called unexpectedly")
+}
+
+// LoadUint32 implements usermem.IO.LoadUint32.
+func (fd *VFSPipeFD) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) {
+ panic("VFSPipeFD.LoadUint32 called unexpectedly")
+}
+
+// Splice reads up to count bytes from src and writes them to dst. It returns
+// the number of bytes moved.
+//
+// Preconditions: count > 0.
+func Splice(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) {
+ return spliceOrTee(ctx, dst, src, count, true /* removeFromSrc */)
+}
+
+// Tee reads up to count bytes from src and writes them to dst, without
+// removing the read bytes from src. It returns the number of bytes copied.
+//
+// Preconditions: count > 0.
+func Tee(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) {
+ return spliceOrTee(ctx, dst, src, count, false /* removeFromSrc */)
+}
+
+// Preconditions: count > 0.
+func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFromSrc bool) (int64, error) {
+ if dst.pipe == src.pipe {
+ return 0, syserror.EINVAL
+ }
+
+ lockTwoPipes(dst.pipe, src.pipe)
+ defer dst.pipe.mu.Unlock()
+ defer src.pipe.mu.Unlock()
+
+ n, err := dst.pipe.writeLocked(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(dstView *buffer.View) (int64, error) {
+ return src.pipe.readLocked(ctx, readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(srcView *buffer.View) (int64, error) {
+ n, err := srcView.ReadToSafememWriter(dstView, uint64(count))
+ if n > 0 && removeFromSrc {
+ srcView.TrimFront(int64(n))
+ }
+ return int64(n), err
+ },
+ })
+ },
+ })
+ if n > 0 {
+ dst.pipe.Notify(waiter.EventIn)
+ src.pipe.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *VFSPipeFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *VFSPipeFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
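
spliceOrTee above nests the source pipe's read callback inside the destination
pipe's write callback, so bytes move directly between the two buffer.Views while
both pipes are held by lockTwoPipes. A toy model of that data movement, with
plain byte slices standing in for buffer.View and locking omitted, illustrates
the only difference between splice(2) and tee(2) here: whether the bytes are
consumed from the source:

    package main

    import "fmt"

    // miniPipe stands in for Pipe: just a growable byte buffer.
    type miniPipe struct{ data []byte }

    // moveBytes models the nested write(read(...)) callbacks in spliceOrTee:
    // up to count bytes go straight from src to dst, and removeFromSrc
    // selects splice (consume) versus tee (leave in place).
    func moveBytes(dst, src *miniPipe, count int, removeFromSrc bool) int {
        if count > len(src.data) {
            count = len(src.data)
        }
        dst.data = append(dst.data, src.data[:count]...)
        if removeFromSrc {
            src.data = src.data[count:]
        }
        return count
    }

    func main() {
        src := &miniPipe{data: []byte("hello pipes")}
        dst := &miniPipe{}
        n := moveBytes(dst, src, 5, true /* splice */)
        fmt.Printf("moved %d bytes: dst=%q src=%q\n", n, dst.data, src.data)
    }
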
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index bfd779837..c211fc8d0 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -20,7 +20,6 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
- "//pkg/sentry/platform",
"//pkg/sentry/usage",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index f66cfcc7f..55b4c2cdb 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -45,7 +45,6 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -370,7 +369,7 @@ type Shm struct {
// fr is the offset into mfp.MemoryFile() that backs this contents of this
// segment. Immutable.
- fr platform.FileRange
+ fr memmap.FileRange
// mu protects all fields below.
mu sync.Mutex `state:"nosave"`
diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go
index 4607cde2f..a83ce219c 100644
--- a/pkg/sentry/kernel/syslog.go
+++ b/pkg/sentry/kernel/syslog.go
@@ -98,6 +98,15 @@ func (s *syslog) Log() []byte {
s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, selectMessage()))...)
}
+ if VFS2Enabled {
+ time += rand.Float64() / 2
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up VFS2..."))...)
+ if FUSEEnabled {
+ time += rand.Float64() / 2
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up FUSE..."))...)
+ }
+ }
+
time += rand.Float64() / 2
s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Ready!"))...)
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index f48247c94..c4db05bd8 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -68,6 +68,21 @@ type Task struct {
// runState is exclusive to the task goroutine.
runState taskRunState
+ // taskWorkCount represents the current size of the task work queue. It is
+ // used to avoid acquiring taskWorkMu when the queue is empty.
+ //
+ // Must be accessed with atomic memory operations.
+ taskWorkCount int32
+
+ // taskWorkMu protects taskWork.
+ taskWorkMu sync.Mutex `state:"nosave"`
+
+ // taskWork is a queue of work to be executed before resuming user execution.
+ // It is similar to the task_work mechanism in Linux.
+ //
+ // taskWork is exclusive to the task goroutine.
+ taskWork []TaskWorker
+
// haveSyscallReturn is true if tc.Arch().Return() represents a value
// returned by a syscall (or set by ptrace after a syscall).
//
@@ -550,6 +565,10 @@ type Task struct {
// futexWaiter is exclusive to the task goroutine.
futexWaiter *futex.Waiter `state:"nosave"`
+ // robustList is a pointer to the head of the task's robust futex
+ // list.
+ robustList usermem.Addr
+
// startTime is the real time at which the task started. It is set when
// a Task is created or invokes execve(2).
//
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 00c425cca..7803b98d0 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -198,11 +198,18 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
t.tg.pidns.owner.mu.Unlock()
+ oldFDTable := t.fdTable
+ t.fdTable = t.fdTable.Fork()
+ oldFDTable.DecRef()
+
// Remove FDs with the CloseOnExec flag set.
t.fdTable.RemoveIf(func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool {
return flags.CloseOnExec
})
+ // Handle the robust futex list.
+ t.exitRobustList()
+
// NOTE(b/30815691): We currently do not implement privileged
// executables (set-user/group-ID bits and file capabilities). This
// allows us to unconditionally enable user dumpability on the new mm.
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index c4ade6e8e..231ac548a 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -253,6 +253,9 @@ func (*runExitMain) execute(t *Task) taskRunState {
}
}
+ // Handle the robust futex list.
+ t.exitRobustList()
+
// Deactivate the address space and update max RSS before releasing the
// task's MM.
t.Deactivate()
diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go
index a53e77c9f..4b535c949 100644
--- a/pkg/sentry/kernel/task_futex.go
+++ b/pkg/sentry/kernel/task_futex.go
@@ -15,6 +15,7 @@
package kernel
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -52,3 +53,127 @@ func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) {
func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) {
return t.MemoryManager().GetSharedFutexKey(t, addr)
}
+
+// GetRobustList returns the robust futex list for the task.
+func (t *Task) GetRobustList() usermem.Addr {
+ t.mu.Lock()
+ addr := t.robustList
+ t.mu.Unlock()
+ return addr
+}
+
+// SetRobustList sets the robust futex list for the task.
+func (t *Task) SetRobustList(addr usermem.Addr) {
+ t.mu.Lock()
+ t.robustList = addr
+ t.mu.Unlock()
+}
+
+// exitRobustList walks the robust futex list, marking locks dead and notifying
+// wakers. It corresponds to Linux's exit_robust_list(). Following Linux,
+// errors are silently ignored.
+func (t *Task) exitRobustList() {
+ t.mu.Lock()
+ addr := t.robustList
+ t.robustList = 0
+ t.mu.Unlock()
+
+ if addr == 0 {
+ return
+ }
+
+ var rl linux.RobustListHead
+ if _, err := rl.CopyIn(t, usermem.Addr(addr)); err != nil {
+ return
+ }
+
+ next := rl.List
+ done := 0
+ var pendingLockAddr usermem.Addr
+ if rl.ListOpPending != 0 {
+ pendingLockAddr = usermem.Addr(rl.ListOpPending + rl.FutexOffset)
+ }
+
+ // Wake up normal elements.
+ for usermem.Addr(next) != addr {
+ // We traverse to the next element of the list before we
+ // actually wake anything. This prevents the race where waking
+ // this futex causes a modification of the list.
+ thisLockAddr := usermem.Addr(next + rl.FutexOffset)
+
+ // Try to decode the next element in the list before waking the
+ // current futex. But don't check the error until after we've
+ // woken the current futex. Linux does it in this order too.
+ _, nextErr := t.CopyIn(usermem.Addr(next), &next)
+
+ // Wake up the current futex if it's not pending.
+ if thisLockAddr != pendingLockAddr {
+ t.wakeRobustListOne(thisLockAddr)
+ }
+
+ // If there was an error copying the next futex, we must bail.
+ if nextErr != nil {
+ break
+ }
+
+ // This is a user structure, so it could be a massive list, or
+ // even contain a loop if they are trying to mess with us. We
+ // cap traversal to prevent that.
+ done++
+ if done >= linux.ROBUST_LIST_LIMIT {
+ break
+ }
+ }
+
+ // Is there a pending entry to wake?
+ if pendingLockAddr != 0 {
+ t.wakeRobustListOne(pendingLockAddr)
+ }
+}
+
+// wakeRobustListOne wakes a single futex from the robust list.
+func (t *Task) wakeRobustListOne(addr usermem.Addr) {
+ // Bit 0 in address signals PI futex.
+ pi := addr&1 == 1
+ addr = addr &^ 1
+
+ // Load the futex.
+ f, err := t.LoadUint32(addr)
+ if err != nil {
+ // Can't read this single value? Ignore the problem.
+ // We can wake the other futexes in the list.
+ return
+ }
+
+ tid := uint32(t.ThreadID())
+ for {
+ // Is this held by someone else?
+ if f&linux.FUTEX_TID_MASK != tid {
+ return
+ }
+
+ // This thread is dying and it's holding this futex. We need to
+ // set the owner died bit and wake up any waiters.
+ newF := (f & linux.FUTEX_WAITERS) | linux.FUTEX_OWNER_DIED
+ if curF, err := t.CompareAndSwapUint32(addr, f, newF); err != nil {
+ return
+ } else if curF != f {
+ // Futex changed out from under us. Try again...
+ f = curF
+ continue
+ }
+
+ // Wake waiters if there are any.
+ if f&linux.FUTEX_WAITERS != 0 {
+ private := f&linux.FUTEX_PRIVATE_FLAG != 0
+ if pi {
+ t.Futex().UnlockPI(t, addr, tid, private)
+ return
+ }
+ t.Futex().Wake(t, addr, private, linux.FUTEX_BITSET_MATCH_ANY, 1)
+ }
+
+ // Done.
+ return
+ }
+}
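
The pivotal step in wakeRobustListOne is the compare-and-swap that publishes
the owner's death: the dying thread's TID is cleared from the futex word,
FUTEX_WAITERS is preserved, and FUTEX_OWNER_DIED is set so the next locker
observes the abandonment. A self-contained sketch of just that word update,
using the constant values from Linux's include/uapi/linux/futex.h:

    package main

    import (
        "fmt"
        "sync/atomic"
    )

    const (
        futexTIDMask   = 0x3fffffff // FUTEX_TID_MASK
        futexWaiters   = 0x80000000 // FUTEX_WAITERS
        futexOwnerDied = 0x40000000 // FUTEX_OWNER_DIED
    )

    // markOwnerDied models the CAS loop: if the dying thread owns the futex
    // word, clear its TID, keep the waiters bit, and set the owner-died bit.
    // It reports whether waiters should be woken.
    func markOwnerDied(word *uint32, tid uint32) bool {
        for {
            f := atomic.LoadUint32(word)
            if f&futexTIDMask != tid {
                return false // held by someone else; nothing to do
            }
            newF := (f & futexWaiters) | futexOwnerDied
            if atomic.CompareAndSwapUint32(word, f, newF) {
                return f&futexWaiters != 0
            }
            // Word changed under us; retry, as the kernel loop does.
        }
    }

    func main() {
        word := uint32(42) | futexWaiters // TID 42 holds it; waiters present
        if markOwnerDied(&word, 42) {
            fmt.Printf("futex word now %#x; wake one waiter\n", word)
        }
    }
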
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index d654dd997..7d4f44caf 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -167,7 +167,22 @@ func (app *runApp) execute(t *Task) taskRunState {
return (*runInterrupt)(nil)
}
- // We're about to switch to the application again. If there's still a
+ // Execute any task work callbacks before returning to user space.
+ if atomic.LoadInt32(&t.taskWorkCount) > 0 {
+ t.taskWorkMu.Lock()
+ queue := t.taskWork
+ t.taskWork = nil
+ atomic.StoreInt32(&t.taskWorkCount, 0)
+ t.taskWorkMu.Unlock()
+
+ // Do not hold taskWorkMu while executing task work, which may register
+ // more work.
+ for _, work := range queue {
+ work.TaskWork(t)
+ }
+ }
+
+ // We're about to switch to the application again. If there's still an
// unhandled SyscallRestartErrno that wasn't translated to an EINTR,
// restart the syscall that was interrupted. If there's a saved signal
// mask, restore it. (Note that restoring the saved signal mask may unblock
diff --git a/pkg/sentry/kernel/task_work.go b/pkg/sentry/kernel/task_work.go
new file mode 100644
index 000000000..dda5a433a
--- /dev/null
+++ b/pkg/sentry/kernel/task_work.go
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import "sync/atomic"
+
+// TaskWorker is a deferred task.
+//
+// This must be savable.
+type TaskWorker interface {
+ // TaskWork will be executed prior to returning to user space. Note that
+ // TaskWork may call RegisterWork again, but this will not be executed until
+ // the next return to user space, unlike in Linux. This effectively allows
+ // registration of indefinite user return hooks, but not by default.
+ TaskWork(t *Task)
+}
+
+// RegisterWork can be used to register additional task work that will be
+// performed prior to returning to user space. See TaskWorker.TaskWork for
+// semantics regarding registration.
+func (t *Task) RegisterWork(work TaskWorker) {
+ t.taskWorkMu.Lock()
+ defer t.taskWorkMu.Unlock()
+ atomic.AddInt32(&t.taskWorkCount, 1)
+ t.taskWork = append(t.taskWork, work)
+}
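
A worker registered via RegisterWork runs on the task goroutine immediately
before the next switch to user space, and runApp drains the entire queue before
executing any callback, so re-registration from within TaskWork defers to the
following return. A toy model of that drain-then-run scheme (task and worker
here are stand-ins, not the kernel types):

    package main

    import "fmt"

    type task struct{ work []worker }

    type worker interface{ TaskWork(t *task) }

    func (t *task) RegisterWork(w worker) { t.work = append(t.work, w) }

    // runAppReturn models the drain in runApp.execute: take the whole queue,
    // clear it, then run callbacks without holding the queue lock, so a
    // callback may re-register itself for the *next* return to user space.
    func (t *task) runAppReturn() {
        queue := t.work
        t.work = nil
        for _, w := range queue {
            w.TaskWork(t)
        }
    }

    // logWorker is a hypothetical worker that fires once per registration.
    type logWorker struct{ msg string }

    func (l *logWorker) TaskWork(t *task) { fmt.Println("task work:", l.msg) }

    func main() {
        t := &task{}
        t.RegisterWork(&logWorker{msg: "flush state before user return"})
        t.runAppReturn() // runs once; queue is now empty
    }
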
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 52849f5b3..4dfd2c990 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -366,7 +366,8 @@ func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error {
// terminal is stolen, and all processes that had it as controlling
// terminal lose it." - tty_ioctl(4)
if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session {
- if !auth.CredentialsFromContext(tg.leader).HasCapability(linux.CAP_SYS_ADMIN) || arg != 1 {
+ // Stealing requires CAP_SYS_ADMIN in the root user namespace.
+ if creds := auth.CredentialsFromContext(tg.leader); !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) || arg != 1 {
return syserror.EPERM
}
// Steal the TTY away. Unlike TIOCNOTTY, don't send signals.
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index bf2dabb6e..872e1a82d 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -87,6 +87,13 @@ type TaskSet struct {
// at time of save (but note that this is not necessarily the same thing as
// sync.WaitGroup's zero value).
runningGoroutines sync.WaitGroup `state:"nosave"`
+
+ // aioGoroutines is the number of goroutines running async I/O
+ // callbacks.
+ //
+ // aioGoroutines is not saved but is required to be zero at the time of
+ // save.
+ aioGoroutines sync.WaitGroup `state:"nosave"`
}
// newTaskSet returns a new, empty TaskSet.
diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD
index 7ba7dc50c..2817aa3ba 100644
--- a/pkg/sentry/kernel/time/BUILD
+++ b/pkg/sentry/kernel/time/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "time",
srcs = [
"context.go",
+ "tcpip.go",
"time.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/kernel/time/tcpip.go b/pkg/sentry/kernel/time/tcpip.go
new file mode 100644
index 000000000..c4474c0cf
--- /dev/null
+++ b/pkg/sentry/kernel/time/tcpip.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "sync"
+ "time"
+)
+
+// TcpipAfterFunc waits for duration to elapse according to clock then runs fn.
+// The timer is started immediately and will fire exactly once.
+func TcpipAfterFunc(clock Clock, duration time.Duration, fn func()) *TcpipTimer {
+ timer := &TcpipTimer{
+ clock: clock,
+ }
+ timer.notifier = functionNotifier{
+ fn: func() {
+ // tcpip.Timer.Stop() explicitly states that the function is called in a
+ // separate goroutine that Stop() does not synchronize with.
+ // Timer.Destroy() synchronizes with calls to TimerListener.Notify().
+ // This is semantically meaningful because, in the former case, it's
+ // legal to call tcpip.Timer.Stop() while holding locks that may also be
+ // taken by the function, but this isn't so in the latter case. Most
+ // immediately, Timer calls TimerListener.Notify() while holding
+ // Timer.mu. A deadlock occurs without spawning a goroutine:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify()
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, deadlock!
+ //
+ // Spawning a goroutine avoids the deadlock:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify() <- Launches T2
+ // T2:
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, blocks
+ // T1:
+ // => (returns) <- Timer.mu.Unlock() called
+ // T2:
+ // => (continues) <- No deadlock!
+ go func() {
+ timer.Stop()
+ fn()
+ }()
+ },
+ }
+ timer.Reset(duration)
+ return timer
+}
+
+// TcpipTimer is a resettable timer with variable duration expirations.
+// Implements tcpip.Timer, which does not define a Destroy method; instead, all
+// resources are released after timer expiration and calls to Timer.Stop.
+//
+// Must be created by TcpipAfterFunc.
+type TcpipTimer struct {
+ // clock is the time source. clock is immutable.
+ clock Clock
+
+ // notifier is called when the Timer expires. notifier is immutable.
+ notifier functionNotifier
+
+ // mu protects t.
+ mu sync.Mutex
+
+ // t stores the latest running Timer. This is replaced whenever Reset is
+ // called since Timer cannot be restarted once it has been Destroyed by Stop.
+ //
+ // This field is nil iff Stop has been called.
+ t *Timer
+}
+
+// Stop implements tcpip.Timer.Stop.
+func (r *TcpipTimer) Stop() bool {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ return false
+ }
+ _, lastSetting := r.t.Swap(Setting{})
+ r.t.Destroy()
+ r.t = nil
+ return lastSetting.Enabled
+}
+
+// Reset implements tcpip.Timer.Reset.
+func (r *TcpipTimer) Reset(d time.Duration) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ r.t = NewTimer(r.clock, &r.notifier)
+ }
+
+ r.t.Swap(Setting{
+ Enabled: true,
+ Period: 0,
+ Next: r.clock.Now().Add(d),
+ })
+}
+
+// functionNotifier is a TimerListener that runs a function.
+//
+// functionNotifier cannot be saved or loaded.
+type functionNotifier struct {
+ fn func()
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (f *functionNotifier) Notify(uint64, Setting) (Setting, bool) {
+ f.fn()
+ return Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (f *functionNotifier) Destroy() {}
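
The long comment in TcpipAfterFunc is the heart of this file: user code must
never run while Timer.mu is held, because tcpip callers are permitted to Stop
a timer from inside its own callback. A runnable model of the same pattern
built on the standard library's time.AfterFunc (notifyTimer is a stand-in, not
the sentry type):

    package main

    import (
        "fmt"
        "sync"
        "time"
    )

    // notifyTimer models TcpipTimer: the expiry path runs with mu held (as
    // Timer.Tick holds Timer.mu during Notify), so user code must run in a
    // fresh goroutine before it may call Stop, which also takes mu.
    type notifyTimer struct {
        mu sync.Mutex
        t  *time.Timer
    }

    func afterFunc(d time.Duration, fn func()) *notifyTimer {
        nt := &notifyTimer{}
        nt.mu.Lock() // hold mu until t is assigned, avoiding a race
        nt.t = time.AfterFunc(d, func() {
            nt.mu.Lock() // models the lock held while notifying
            defer nt.mu.Unlock()
            go func() { // spawn, as functionNotifier.fn does, to avoid deadlock
                nt.Stop()
                fn()
            }()
        })
        nt.mu.Unlock()
        return nt
    }

    func (nt *notifyTimer) Stop() bool {
        nt.mu.Lock()
        defer nt.mu.Unlock()
        return nt.t.Stop()
    }

    func main() {
        done := make(chan struct{})
        afterFunc(10*time.Millisecond, func() { close(done) })
        <-done
        fmt.Println("callback stopped its own timer without deadlocking")
    }

Calling nt.Stop() directly inside the expiry callback, instead of from the
spawned goroutine, reproduces exactly the deadlock diagrammed above.
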
diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go
index da0ea7bb5..7c4fefb16 100644
--- a/pkg/sentry/kernel/timekeeper.go
+++ b/pkg/sentry/kernel/timekeeper.go
@@ -21,8 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/log"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -90,7 +90,7 @@ type Timekeeper struct {
// NewTimekeeper does not take ownership of paramPage.
//
// SetClocks must be called on the returned Timekeeper before it is usable.
-func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage platform.FileRange) (*Timekeeper, error) {
+func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) {
return &Timekeeper{
params: NewVDSOParamPage(mfp, paramPage),
}, nil
@@ -186,6 +186,7 @@ func (t *Timekeeper) startUpdater() {
timer := time.NewTicker(sentrytime.ApproxUpdateInterval)
t.wg.Add(1)
go func() { // S/R-SAFE: stopped during save.
+ defer t.wg.Done()
for {
// Start with an update immediately, so the clocks are
// ready ASAP.
@@ -209,9 +210,6 @@ func (t *Timekeeper) startUpdater() {
p.realtimeBaseRef = int64(realtimeParams.BaseRef)
p.realtimeFrequency = realtimeParams.Frequency
}
-
- log.Debugf("Updating VDSO parameters: %+v", p)
-
return p
}); err != nil {
log.Warningf("Unable to update VDSO parameter page: %v", err)
@@ -220,7 +218,6 @@ func (t *Timekeeper) startUpdater() {
select {
case <-timer.C:
case <-t.stop:
- t.wg.Done()
return
}
}
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
index f1b3c212c..290c32466 100644
--- a/pkg/sentry/kernel/vdso.go
+++ b/pkg/sentry/kernel/vdso.go
@@ -19,8 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -58,7 +58,7 @@ type vdsoParams struct {
type VDSOParamPage struct {
// The parameter page is fr, allocated from mfp.MemoryFile().
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
// seq is the current sequence count written to the page.
//
@@ -81,7 +81,7 @@ type VDSOParamPage struct {
// * VDSOParamPage must be the only writer to fr.
//
// * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block.
-func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *VDSOParamPage {
+func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage {
return &VDSOParamPage{mfp: mfp, fr: fr}
}
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
index c6aa65f28..34bdb0b69 100644
--- a/pkg/sentry/loader/BUILD
+++ b/pkg/sentry/loader/BUILD
@@ -30,9 +30,6 @@ go_library(
"//pkg/rand",
"//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/fs",
- "//pkg/sentry/fs/anon",
- "//pkg/sentry/fs/fsutil",
"//pkg/sentry/fsbridge",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/limits",
@@ -45,6 +42,5 @@ go_library(
"//pkg/syserr",
"//pkg/syserror",
"//pkg/usermem",
- "//pkg/waiter",
],
)
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index 616fafa2c..ddeaff3db 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -90,14 +90,23 @@ type elfInfo struct {
sharedObject bool
}
+// fullReader extracts the ReadFull method from fsbridge.File so that client
+// code does not need to implement all of fsbridge.File when only read
+// functionality is needed.
+//
+// TODO(gvisor.dev/issue/1035): Once VFS2 ships, rewrite this to wrap
+// vfs.FileDescription's PRead/Read instead.
+type fullReader interface {
+ // ReadFull is the same as fsbridge.File.ReadFull.
+ ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error)
+}
+
// parseHeader parse the ELF header, verifying that this is a supported ELF
// file and returning the ELF program headers.
//
// This is similar to elf.NewFile, except that it is more strict about what it
// accepts from the ELF, and it doesn't parse unnecessary parts of the file.
-//
-// ctx may be nil if f does not need it.
-func parseHeader(ctx context.Context, f fsbridge.File) (elfInfo, error) {
+func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
// Check ident first; it will tell us the endianness of the rest of the
// structs.
var ident [elf.EI_NIDENT]byte
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 88449fe95..986c7fb4d 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -27,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
@@ -80,22 +79,6 @@ type LoadArgs struct {
Features *cpuid.FeatureSet
}
-// readFull behaves like io.ReadFull for an *fs.File.
-func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- var total int64
- for dst.NumBytes() > 0 {
- n, err := f.Preadv(ctx, dst, offset+total)
- total += n
- if err == io.EOF && total != 0 {
- return total, io.ErrUnexpectedEOF
- } else if err != nil {
- return total, err
- }
- dst = dst.DropFirst64(n)
- }
- return total, nil
-}
-
// openPath opens args.Filename and checks that it is valid for loading.
//
// openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not
@@ -238,14 +221,14 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
// Load the executable itself.
loaded, ac, file, newArgv, err := loadExecutable(ctx, args)
if err != nil {
- return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
}
defer file.DecRef()
// Load the VDSO.
vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded)
if err != nil {
- return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux())
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("error loading VDSO: %v", err), syserr.FromError(err).ToLinux())
}
// Setup the heap. brk starts at the next page after the end of the
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index 00977fc08..05a294fe6 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -26,10 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/fs/anon"
- "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -37,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
)
const vdsoPrelink = 0xffffffffff700000
@@ -55,52 +50,11 @@ func (f *fileContext) Value(key interface{}) interface{} {
}
}
-// byteReader implements fs.FileOperations for reading from a []byte source.
-type byteReader struct {
- fsutil.FileNoFsync `state:"nosave"`
- fsutil.FileNoIoctl `state:"nosave"`
- fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
- fsutil.FileNoopFlush `state:"nosave"`
- fsutil.FileNoopRelease `state:"nosave"`
- fsutil.FileNotDirReaddir `state:"nosave"`
- fsutil.FilePipeSeek `state:"nosave"`
- fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- waiter.AlwaysReady `state:"nosave"`
-
+type byteFullReader struct {
data []byte
}
-var _ fs.FileOperations = (*byteReader)(nil)
-
-// newByteReaderFile creates a fake file to read data from.
-//
-// TODO(gvisor.dev/issue/1623): Convert to VFS2.
-func newByteReaderFile(ctx context.Context, data []byte) *fs.File {
- // Create a fake inode.
- inode := fs.NewInode(
- ctx,
- &fsutil.SimpleFileInode{},
- fs.NewPseudoMountSource(ctx),
- fs.StableAttr{
- Type: fs.Anonymous,
- DeviceID: anon.PseudoDevice.DeviceID(),
- InodeID: anon.PseudoDevice.NextIno(),
- BlockSize: usermem.PageSize,
- })
-
- // Use the fake inode to create a fake dirent.
- dirent := fs.NewTransientDirent(inode)
- defer dirent.DecRef()
-
- // Use the fake dirent to make a fake file.
- flags := fs.FileFlags{Read: true, Pread: true}
- return fs.NewFile(&fileContext{Context: context.Background()}, dirent, flags, &byteReader{
- data: data,
- })
-}
-
-func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+func (b *byteFullReader) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
if offset < 0 {
return 0, syserror.EINVAL
}
@@ -111,10 +65,6 @@ func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequ
return int64(n), err
}
-func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
- panic("Write not supported")
-}
-
// validateVDSO checks that the VDSO can be loaded by loadVDSO.
//
// VDSOs are special (see below). Since we are going to map the VDSO directly
@@ -130,7 +80,7 @@ func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSeq
// * PT_LOAD segments don't extend beyond the end of the file.
//
// ctx may be nil if f does not need it.
-func validateVDSO(ctx context.Context, f fsbridge.File, size uint64) (elfInfo, error) {
+func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, error) {
info, err := parseHeader(ctx, f)
if err != nil {
log.Infof("Unable to parse VDSO header: %v", err)
@@ -248,13 +198,12 @@ func getSymbolValueFromVDSO(symbol string) (uint64, error) {
// PrepareVDSO validates the system VDSO and returns a VDSO, containing the
// param page for updating by the kernel.
-func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
- vdsoFile := fsbridge.NewFSFile(newByteReaderFile(ctx, vdsoBin))
+func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
+ vdsoFile := &byteFullReader{data: vdsoBin}
// First make sure the VDSO is valid. vdsoFile does not use ctx, so a
// nil context can be passed.
info, err := validateVDSO(nil, vdsoFile, uint64(len(vdsoBin)))
- vdsoFile.DecRef()
if err != nil {
return nil, err
}
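
Narrowing fsbridge.File to fullReader is what lets byteFullReader above remain
a one-field type instead of stubbing out a whole file implementation. A
simplified sketch of the same interface-narrowing move; note the real ReadFull
takes a usermem.IOSequence, reduced here to a []byte for self-containment:

    package main

    import (
        "context"
        "fmt"
    )

    // fullReader models the narrowed interface: callers that only read need
    // not implement the rest of fsbridge.File.
    type fullReader interface {
        ReadFull(ctx context.Context, dst []byte, offset int64) (int64, error)
    }

    // byteFullReader serves reads from an in-memory []byte, like the type
    // added to loader/vdso.go above.
    type byteFullReader struct{ data []byte }

    func (b *byteFullReader) ReadFull(ctx context.Context, dst []byte, offset int64) (int64, error) {
        if offset < 0 || offset >= int64(len(b.data)) {
            return 0, fmt.Errorf("offset %d out of range", offset)
        }
        return int64(copy(dst, b.data[offset:])), nil
    }

    func main() {
        var f fullReader = &byteFullReader{data: []byte("\x7fELF")}
        buf := make([]byte, 4)
        n, err := f.ReadFull(context.Background(), buf, 0)
        fmt.Printf("n=%d err=%v ident=%q\n", n, err, buf[:n])
    }
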
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index a98b66de1..2c95669cd 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -28,9 +28,21 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "file_range",
+ out = "file_range.go",
+ package = "memmap",
+ prefix = "File",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uint64",
+ },
+)
+
go_library(
name = "memmap",
srcs = [
+ "file_range.go",
"mappable_range.go",
"mapping_set.go",
"mapping_set_impl.go",
@@ -40,7 +52,7 @@ go_library(
deps = [
"//pkg/context",
"//pkg/log",
- "//pkg/sentry/platform",
+ "//pkg/safemem",
"//pkg/syserror",
"//pkg/usermem",
],
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index c6db9fc8f..c188f6c29 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -19,12 +19,12 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/usermem"
)
// Mappable represents a memory-mappable object, a mutable mapping from uint64
-// offsets to (platform.File, uint64 File offset) pairs.
+// offsets to (File, uint64 File offset) pairs.
//
// See mm/mm.go for Mappable's place in the lock order.
//
@@ -74,7 +74,7 @@ type Mappable interface {
// Translations are valid until invalidated by a callback to
// MappingSpace.Invalidate or until the caller removes its mapping of the
// translated range. Mappable implementations must ensure that at least one
- // reference is held on all pages in a platform.File that may be the result
+ // reference is held on all pages in a File that may be the result
// of a valid Translation.
//
// Preconditions: required.Length() > 0. optional.IsSupersetOf(required).
@@ -100,7 +100,7 @@ type Translation struct {
Source MappableRange
// File is the mapped file.
- File platform.File
+ File File
// Offset is the offset into File at which this Translation begins.
Offset uint64
@@ -110,9 +110,9 @@ type Translation struct {
Perms usermem.AccessType
}
-// FileRange returns the platform.FileRange represented by t.
-func (t Translation) FileRange() platform.FileRange {
- return platform.FileRange{t.Offset, t.Offset + t.Source.Length()}
+// FileRange returns the FileRange represented by t.
+func (t Translation) FileRange() FileRange {
+ return FileRange{t.Offset, t.Offset + t.Source.Length()}
}
// CheckTranslateResult returns an error if (ts, terr) does not satisfy all
@@ -361,3 +361,49 @@ type MMapOpts struct {
// TODO(jamieliu): Replace entirely with MappingIdentity?
Hint string
}
+
+// File represents a host file that may be mapped into a platform.AddressSpace.
+type File interface {
+ // All pages in a File are reference-counted.
+
+ // IncRef increments the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr. (The File
+ // interface does not provide a way to acquire an initial reference;
+ // implementors may define mechanisms for doing so.)
+ IncRef(fr FileRange)
+
+ // DecRef decrements the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr.
+ DecRef(fr FileRange)
+
+ // MapInternal returns a mapping of the given file offsets in the invoking
+ // process' address space for reading and writing.
+ //
+ // Note that fr.Start and fr.End need not be page-aligned.
+ //
+ // Preconditions: fr.Length() > 0. At least one reference must be held on
+ // all pages in fr.
+ //
+ // Postconditions: The returned mapping is valid as long as at least one
+ // reference is held on the mapped pages.
+ MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
+
+ // FD returns the file descriptor represented by the File.
+ //
+ // The only permitted operation on the returned file descriptor is to map
+ // pages from it consistent with the requirements of AddressSpace.MapFile.
+ FD() int
+}
+
+// FileRange represents a range of uint64 offsets into a File.
+//
+// type FileRange <generated using go_generics>
+
+// String implements fmt.Stringer.String.
+func (fr FileRange) String() string {
+ return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
+}
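
The IncRef/DecRef contract above counts references per page rather than per
range, so overlapping ranges compose correctly. A toy implementation keyed by
page index (unlike the real MemoryFile, which tracks usage in a segment set)
makes the bookkeeping concrete:

    package main

    import "fmt"

    const pageSize = 4096

    // fileRange mirrors memmap.FileRange: [Start, End) offsets into a File.
    type fileRange struct{ Start, End uint64 }

    // refFile models the page reference-counting contract of memmap.File.
    type refFile struct {
        refs map[uint64]int // page index -> refcount
    }

    // IncRef increments the count on every page in fr (page-aligned).
    func (f *refFile) IncRef(fr fileRange) {
        for p := fr.Start / pageSize; p < fr.End/pageSize; p++ {
            f.refs[p]++
        }
    }

    // DecRef decrements the count; pages that hit zero may be reclaimed.
    func (f *refFile) DecRef(fr fileRange) {
        for p := fr.Start / pageSize; p < fr.End/pageSize; p++ {
            if f.refs[p]--; f.refs[p] == 0 {
                delete(f.refs, p)
            }
        }
    }

    func main() {
        f := &refFile{refs: map[uint64]int{0: 1, 1: 1}} // initial refs held
        f.IncRef(fileRange{0, 2 * pageSize})
        f.DecRef(fileRange{0, 2 * pageSize})
        f.DecRef(fileRange{0, 2 * pageSize}) // drops the initial references
        fmt.Printf("remaining referenced pages: %d\n", len(f.refs))
    }
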
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index a036ce53c..f9d0837a1 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -7,14 +7,14 @@ go_template_instance(
name = "file_refcount_set",
out = "file_refcount_set.go",
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "mm",
prefix = "fileRefcount",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "int32",
"Functions": "fileRefcountSetFunctions",
},
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 379148903..1999ec706 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -243,7 +242,7 @@ type aioMappable struct {
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
}
var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp())
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 6db7c3d40..3e85964e4 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -25,7 +25,7 @@
// Locks taken by memmap.Mappable.Translate
// mm.privateRefs.mu
// platform.AddressSpace locks
-// platform.File locks
+// memmap.File locks
// mm.aioManager.mu
// mm.AIOContext.mu
//
@@ -396,7 +396,7 @@ type pma struct {
// file is the file mapped by this pma. Only pmas for which file ==
// MemoryManager.mfp.MemoryFile() may be saved. pmas hold a reference to
// the corresponding file range while they exist.
- file platform.File `state:"nosave"`
+ file memmap.File `state:"nosave"`
// off is the offset into file at which this pma begins.
//
@@ -436,7 +436,7 @@ type pma struct {
private bool
// If internalMappings is not empty, it is the cached return value of
- // file.MapInternal for the platform.FileRange mapped by this pma.
+ // file.MapInternal for the memmap.FileRange mapped by this pma.
internalMappings safemem.BlockSeq `state:"nosave"`
}
@@ -469,10 +469,10 @@ func (fileRefcountSetFunctions) MaxKey() uint64 {
func (fileRefcountSetFunctions) ClearValue(_ *int32) {
}
-func (fileRefcountSetFunctions) Merge(_ platform.FileRange, rc1 int32, _ platform.FileRange, rc2 int32) (int32, bool) {
+func (fileRefcountSetFunctions) Merge(_ memmap.FileRange, rc1 int32, _ memmap.FileRange, rc2 int32) (int32, bool) {
return rc1, rc1 == rc2
}
-func (fileRefcountSetFunctions) Split(_ platform.FileRange, rc int32, _ uint64) (int32, int32) {
+func (fileRefcountSetFunctions) Split(_ memmap.FileRange, rc int32, _ uint64) (int32, int32) {
return rc, rc
}
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
index 62e4c20af..930ec895f 100644
--- a/pkg/sentry/mm/pma.go
+++ b/pkg/sentry/mm/pma.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/safecopy"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -604,7 +603,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat
}
}
-// Pin returns the platform.File ranges currently mapped by addresses in ar in
+// Pin returns the memmap.File ranges currently mapped by addresses in ar in
// mm, acquiring a reference on the returned ranges which the caller must
// release by calling Unpin. If not all addresses are mapped, Pin returns a
// non-nil error. Note that Pin may return both a non-empty slice of
@@ -674,15 +673,15 @@ type PinnedRange struct {
Source usermem.AddrRange
// File is the mapped file.
- File platform.File
+ File memmap.File
// Offset is the offset into File at which this PinnedRange begins.
Offset uint64
}
-// FileRange returns the platform.File offsets mapped by pr.
-func (pr PinnedRange) FileRange() platform.FileRange {
- return platform.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}
+// FileRange returns the memmap.File offsets mapped by pr.
+func (pr PinnedRange) FileRange() memmap.FileRange {
+ return memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}
}
// Unpin releases the reference held by prs.
@@ -857,7 +856,7 @@ func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) saf
}
// incPrivateRef acquires a reference on private pages in fr.
-func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) {
+func (mm *MemoryManager) incPrivateRef(fr memmap.FileRange) {
mm.privateRefs.mu.Lock()
defer mm.privateRefs.mu.Unlock()
refSet := &mm.privateRefs.refs
@@ -878,8 +877,8 @@ func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) {
}
// decPrivateRef releases a reference on private pages in fr.
-func (mm *MemoryManager) decPrivateRef(fr platform.FileRange) {
- var freed []platform.FileRange
+func (mm *MemoryManager) decPrivateRef(fr memmap.FileRange) {
+ var freed []memmap.FileRange
mm.privateRefs.mu.Lock()
refSet := &mm.privateRefs.refs
@@ -951,7 +950,7 @@ func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRa
// Discard internal mappings instead of trying to merge them, since merging
// them requires an allocation and getting them again from the
- // platform.File might not.
+ // memmap.File might not.
pma1.internalMappings = safemem.BlockSeq{}
return pma1, true
}
@@ -1012,12 +1011,12 @@ func (pseg pmaIterator) getInternalMappingsLocked() error {
return nil
}
-func (pseg pmaIterator) fileRange() platform.FileRange {
+func (pseg pmaIterator) fileRange() memmap.FileRange {
return pseg.fileRangeOf(pseg.Range())
}
// Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0.
-func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange {
+func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange {
if checkInvariants {
if !pseg.Ok() {
panic("terminal pma iterator")
@@ -1032,5 +1031,5 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange {
pma := pseg.ValuePtr()
pstart := pseg.Start()
- return platform.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)}
+ return memmap.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)}
}
diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go
index 9ad52082d..0e142fb11 100644
--- a/pkg/sentry/mm/special_mappable.go
+++ b/pkg/sentry/mm/special_mappable.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -35,7 +34,7 @@ type SpecialMappable struct {
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
name string
}
@@ -44,7 +43,7 @@ type SpecialMappable struct {
// SpecialMappable will use the given name in /proc/[pid]/maps.
//
// Preconditions: fr.Length() != 0.
-func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *SpecialMappable {
+func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable {
m := SpecialMappable{mfp: mfp, fr: fr, name: name}
m.EnableLeakCheck("mm.SpecialMappable")
return &m
@@ -126,7 +125,7 @@ func (m *SpecialMappable) MemoryFileProvider() pgalloc.MemoryFileProvider {
// FileRange returns the offsets into MemoryFileProvider().MemoryFile() that
// store the SpecialMappable's contents.
-func (m *SpecialMappable) FileRange() platform.FileRange {
+func (m *SpecialMappable) FileRange() memmap.FileRange {
return m.fr
}
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index 1eeb9f317..7a3311a70 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -33,21 +33,42 @@ go_template_instance(
out = "usage_set.go",
consts = {
"minDegree": "10",
+ "trackGaps": "1",
},
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "pgalloc",
prefix = "usage",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "usageInfo",
"Functions": "usageSetFunctions",
},
)
+go_template_instance(
+ name = "reclaim_set",
+ out = "reclaim_set.go",
+ consts = {
+ "minDegree": "10",
+ },
+ imports = {
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
+ },
+ package = "pgalloc",
+ prefix = "reclaim",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "memmap.FileRange",
+ "Value": "reclaimSetValue",
+ "Functions": "reclaimSetFunctions",
+ },
+)
+
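The trackGaps constant added to the usage set instantiation is what enables the allocator rewrite in pgalloc.go below: the generated set maintains gap metadata, so free ranges can be queried directly. A minimal sketch of the two queries this enables, assuming the generated segment-set API:

    // Illustrative only; method names assumed from the generated set.
    gap := f.usage.LastGap()             // highest-offset free range in the set
    gap = gap.PrevLargeEnoughGap(length) // nearest lower gap with >= length bytes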
go_library(
name = "pgalloc",
srcs = [
@@ -56,6 +77,7 @@ go_library(
"evictable_range_set.go",
"pgalloc.go",
"pgalloc_unsafe.go",
+ "reclaim_set.go",
"save_restore.go",
"usage_set.go",
],
@@ -67,9 +89,10 @@ go_library(
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/hostmm",
- "//pkg/sentry/platform",
+ "//pkg/sentry/memmap",
"//pkg/sentry/usage",
"//pkg/state",
+ "//pkg/state/wire",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 2b11ea4ae..3243d7214 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -33,14 +33,14 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/hostmm"
- "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
-// MemoryFile is a platform.File whose pages may be allocated to arbitrary
+// MemoryFile is a memmap.File whose pages may be allocated to arbitrary
// users.
type MemoryFile struct {
// opts holds options passed to NewMemoryFile. opts is immutable.
@@ -108,12 +108,6 @@ type MemoryFile struct {
usageSwapped uint64
usageLast time.Time
- // minUnallocatedPage is the minimum page that may be unallocated.
- // i.e., there are no unallocated pages below minUnallocatedPage.
- //
- // minUnallocatedPage is protected by mu.
- minUnallocatedPage uint64
-
// fileSize is the size of the backing memory file in bytes. fileSize is
// always a power-of-two multiple of chunkSize.
//
@@ -146,11 +140,9 @@ type MemoryFile struct {
// is protected by mu.
reclaimable bool
- // minReclaimablePage is the minimum page that may be reclaimable.
- // i.e., all reclaimable pages are >= minReclaimablePage.
- //
- // minReclaimablePage is protected by mu.
- minReclaimablePage uint64
+ // reclaim is the collection of regions for reclaim. reclaim is protected
+ // by mu.
+ reclaim reclaimSet
// reclaimCond is signaled (with mu locked) when reclaimable or destroyed
// transitions from false to true.
@@ -273,12 +265,10 @@ type evictableMemoryUserInfo struct {
}
const (
- chunkShift = 24
- chunkSize = 1 << chunkShift // 16 MB
+ chunkShift = 30
+ chunkSize = 1 << chunkShift // 1 GB
chunkMask = chunkSize - 1
- initialSize = chunkSize
-
// maxPage is the highest 64-bit page.
maxPage = math.MaxUint64 &^ (usermem.PageSize - 1)
)
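With chunks grown from 16 MB to 1 GB, each file offset still splits into a chunk index and an intra-chunk offset, which is how forEachMappingSlice walks the mappings slice. A minimal sketch of that decomposition, assuming some offset into the file:

    // Illustrative decomposition under chunkShift = 30 (1 GB chunks).
    chunk := int(offset >> chunkShift) // index into the mappings slice
    within := offset & chunkMask       // byte offset inside that chunk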
@@ -302,19 +292,12 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) {
if err := file.Truncate(0); err != nil {
return nil, err
}
- if err := file.Truncate(initialSize); err != nil {
- return nil, err
- }
f := &MemoryFile{
- opts: opts,
- fileSize: initialSize,
- file: file,
- // No pages are reclaimable. DecRef will always be able to
- // decrease minReclaimablePage from this point.
- minReclaimablePage: maxPage,
- evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo),
+ opts: opts,
+ file: file,
+ evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo),
}
- f.mappings.Store(make([]uintptr, initialSize/chunkSize))
+ f.mappings.Store(make([]uintptr, 0))
f.reclaimCond.L = &f.mu
if f.opts.DelayedEviction == DelayedEvictionEnabled && f.opts.UseHostMemcgPressure {
@@ -389,7 +372,7 @@ func (f *MemoryFile) Destroy() {
// to Allocate.
//
// Preconditions: length must be page-aligned and non-zero.
-func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) {
+func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.FileRange, error) {
if length == 0 || length%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid allocation length: %#x", length))
}
@@ -404,46 +387,36 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
alignment = usermem.HugePageSize
}
- start, minUnallocatedPage := findUnallocatedRange(&f.usage, f.minUnallocatedPage, length, alignment)
- end := start + length
- // File offsets are int64s. Since length must be strictly positive, end
- // cannot legitimately be 0.
- if end < start || int64(end) <= 0 {
- return platform.FileRange{}, syserror.ENOMEM
+ // Find a range in the underlying file.
+ fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment)
+ if !ok {
+ return memmap.FileRange{}, syserror.ENOMEM
}
- // Expand the file if needed. Double the file size on each expansion;
- // uncommitted pages have effectively no cost.
- fileSize := f.fileSize
- for int64(end) > fileSize {
- if fileSize >= 2*fileSize {
- // fileSize overflow.
- return platform.FileRange{}, syserror.ENOMEM
+ // Expand the file if needed.
+ if int64(fr.End) > f.fileSize {
+ // Round the new file size up to be chunk-aligned.
+ newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask
+ if err := f.file.Truncate(newFileSize); err != nil {
+ return memmap.FileRange{}, err
}
- fileSize *= 2
- }
- if fileSize > f.fileSize {
- if err := f.file.Truncate(fileSize); err != nil {
- return platform.FileRange{}, err
- }
- f.fileSize = fileSize
+ f.fileSize = newFileSize
f.mappingsMu.Lock()
oldMappings := f.mappings.Load().([]uintptr)
- newMappings := make([]uintptr, fileSize>>chunkShift)
+ newMappings := make([]uintptr, newFileSize>>chunkShift)
copy(newMappings, oldMappings)
f.mappings.Store(newMappings)
f.mappingsMu.Unlock()
}
// Mark selected pages as in use.
- fr := platform.FileRange{start, end}
if f.opts.ManualZeroing {
if err := f.forEachMappingSlice(fr, func(bs []byte) {
for i := range bs {
bs[i] = 0
}
}); err != nil {
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
}
if !f.usage.Add(fr, usageInfo{
@@ -453,49 +426,79 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
panic(fmt.Sprintf("allocating %v: failed to insert into usage set:\n%v", fr, &f.usage))
}
- if minUnallocatedPage < start {
- f.minUnallocatedPage = minUnallocatedPage
- } else {
- // start was the first unallocated page. The next must be
- // somewhere beyond end.
- f.minUnallocatedPage = end
- }
-
return fr, nil
}
-// findUnallocatedRange returns the first unallocated page in usage of the
-// specified length and alignment beginning at page start and the first single
-// unallocated page.
-func findUnallocatedRange(usage *usageSet, start, length, alignment uint64) (uint64, uint64) {
- // Only searched until the first page is found.
- firstPage := start
- foundFirstPage := false
- alignMask := alignment - 1
- for seg := usage.LowerBoundSegment(start); seg.Ok(); seg = seg.NextSegment() {
- r := seg.Range()
-
- if !foundFirstPage && r.Start > firstPage {
- foundFirstPage = true
+// findAvailableRange returns an available range in the usageSet.
+//
+// Note that scanning for available slots takes place from the end of the
+// file backwards, then forwards. This heuristic has important consequences
+// for how sequential mappings can be merged in the host VMAs, given that
+// addresses for both application and sentry mappings are allocated top-down
+// (from higher to lower addresses). The file is also grown exponentially in
+// order to create space for mappings to be allocated downwards.
+//
+// Precondition: alignment must be a power of 2.
+func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (memmap.FileRange, bool) {
+ alignmentMask := alignment - 1
+
+ // Search for space in existing gaps, starting at the current end of the
+ // file and working backward.
+ lastGap := usage.LastGap()
+ gap := lastGap
+ for {
+ end := gap.End()
+ if end > uint64(fileSize) {
+ end = uint64(fileSize)
}
- if start >= r.End {
- // start was rounded up to an alignment boundary from the end
- // of a previous segment and is now beyond r.End.
- continue
+ // Try to allocate from the end of this gap, with the start of the
+ // allocated range aligned down to alignment.
+ unalignedStart := end - length
+ if unalignedStart > end {
+ // Negative overflow: this and all preceding gaps are too small to
+ // accommodate length.
+ break
}
- // This segment represents allocated or reclaimable pages; only the
- // range from start to the segment's beginning is allocatable, and the
- // next allocatable range begins after the segment.
- if r.Start > start && r.Start-start >= length {
+ if start := unalignedStart &^ alignmentMask; start >= gap.Start() {
+ return memmap.FileRange{start, start + length}, true
+ }
+
+ gap = gap.PrevLargeEnoughGap(length)
+ if !gap.Ok() {
break
}
- start = (r.End + alignMask) &^ alignMask
- if !foundFirstPage {
- firstPage = r.End
+ }
+
+ // Check that it's possible to fit this allocation at the end of a file of any size.
+ min := lastGap.Start()
+ min = (min + alignmentMask) &^ alignmentMask
+ if min+length < min {
+ // Overflow: allocation would exceed the range of uint64.
+ return memmap.FileRange{}, false
+ }
+
+ // Determine the minimum file size required to fit this allocation at its end.
+ for {
+ newFileSize := 2 * fileSize
+ if newFileSize <= fileSize {
+ if fileSize != 0 {
+ // Overflow: allocation would exceed the range of int64.
+ return memmap.FileRange{}, false
+ }
+ newFileSize = chunkSize
+ }
+ fileSize = newFileSize
+
+ unalignedStart := uint64(fileSize) - length
+ if unalignedStart > uint64(fileSize) {
+ // Negative overflow: fileSize is still inadequate.
+ continue
+ }
+ if start := unalignedStart &^ alignmentMask; start >= min {
+ return memmap.FileRange{start, start + length}, true
}
}
- return start, firstPage
}
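A concrete walkthrough of the heuristic, mirroring the "Top-down allocation within one gap" test case added below: with pages 1, 4 and 7 in use in an 8-page file, the backward search allocates from the end of the highest gap that fits.

    // Illustrative only; values taken from the new test cases below.
    fr, ok := findAvailableRange(&f.usage, 8*page, page, page)
    // ok == true, fr == memmap.FileRange{6 * page, 7 * page}: the last gap
    // [8*page, ...) is clamped to the file size and cannot fit a page above
    // gap.Start(), so the search falls back to the gap ending at 7*page.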
// AllocateAndFill allocates memory of the given kind and fills it by calling
@@ -505,22 +508,22 @@ func findUnallocatedRange(usage *usageSet, start, length, alignment uint64) (uin
// by r.ReadToBlocks(), it returns that error.
//
// Preconditions: length > 0. length must be page-aligned.
-func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (platform.FileRange, error) {
+func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (memmap.FileRange, error) {
fr, err := f.Allocate(length, kind)
if err != nil {
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
dsts, err := f.MapInternal(fr, usermem.Write)
if err != nil {
f.DecRef(fr)
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
n, err := safemem.ReadFullToBlocks(r, dsts)
un := uint64(usermem.Addr(n).RoundDown())
if un < length {
// Free unused memory and update fr to contain only the memory that is
// still allocated.
- f.DecRef(platform.FileRange{fr.Start + un, fr.End})
+ f.DecRef(memmap.FileRange{fr.Start + un, fr.End})
fr.End = fr.Start + un
}
return fr, err
@@ -537,7 +540,7 @@ const (
// will read zeroes.
//
// Preconditions: fr.Length() > 0.
-func (f *MemoryFile) Decommit(fr platform.FileRange) error {
+func (f *MemoryFile) Decommit(fr memmap.FileRange) error {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -557,7 +560,7 @@ func (f *MemoryFile) Decommit(fr platform.FileRange) error {
return nil
}
-func (f *MemoryFile) markDecommitted(fr platform.FileRange) {
+func (f *MemoryFile) markDecommitted(fr memmap.FileRange) {
f.mu.Lock()
defer f.mu.Unlock()
// Since we're changing the knownCommitted attribute, we need to merge
@@ -578,8 +581,8 @@ func (f *MemoryFile) markDecommitted(fr platform.FileRange) {
f.usage.MergeRange(fr)
}
-// IncRef implements platform.File.IncRef.
-func (f *MemoryFile) IncRef(fr platform.FileRange) {
+// IncRef implements memmap.File.IncRef.
+func (f *MemoryFile) IncRef(fr memmap.FileRange) {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -597,8 +600,8 @@ func (f *MemoryFile) IncRef(fr platform.FileRange) {
f.usage.MergeAdjacent(fr)
}
-// DecRef implements platform.File.DecRef.
-func (f *MemoryFile) DecRef(fr platform.FileRange) {
+// DecRef implements memmap.File.DecRef.
+func (f *MemoryFile) DecRef(fr memmap.FileRange) {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -616,6 +619,7 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) {
}
val.refs--
if val.refs == 0 {
+ f.reclaim.Add(seg.Range(), reclaimSetValue{})
freed = true
// Reclassify memory as System, until it's freed by the reclaim
// goroutine.
@@ -628,17 +632,13 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) {
f.usage.MergeAdjacent(fr)
if freed {
- if fr.Start < f.minReclaimablePage {
- // We've freed at least one lower page.
- f.minReclaimablePage = fr.Start
- }
f.reclaimable = true
f.reclaimCond.Signal()
}
}
-// MapInternal implements platform.File.MapInternal.
-func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// MapInternal implements memmap.File.MapInternal.
+func (f *MemoryFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
if !fr.WellFormed() || fr.Length() == 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -664,7 +664,7 @@ func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (
// forEachMappingSlice invokes fn on a sequence of byte slices that
// collectively map all bytes in fr.
-func (f *MemoryFile) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error {
+func (f *MemoryFile) forEachMappingSlice(fr memmap.FileRange, fn func([]byte)) error {
mappings := f.mappings.Load().([]uintptr)
for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize {
chunk := int(chunkStart >> chunkShift)
@@ -944,7 +944,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
continue
case !populated && populatedRun:
// Finish the run by changing this segment.
- runRange := platform.FileRange{
+ runRange := memmap.FileRange{
Start: r.Start + uint64(populatedRunStart*usermem.PageSize),
End: r.Start + uint64(i*usermem.PageSize),
}
@@ -1009,7 +1009,7 @@ func (f *MemoryFile) File() *os.File {
return f.file
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (f *MemoryFile) FD() int {
return int(f.file.Fd())
}
@@ -1030,6 +1030,7 @@ func (f *MemoryFile) String() string {
// for allocation.
func (f *MemoryFile) runReclaim() {
for {
+ // N.B. We must call f.markReclaimed on the returned FileRange.
fr, ok := f.findReclaimable()
if !ok {
break
@@ -1085,13 +1086,17 @@ func (f *MemoryFile) runReclaim() {
}
}
-func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
+// findReclaimable finds memory that has been marked for reclaim.
+//
+// Note that the returned range will be removed from tracking. It
+// must then be reclaimed (removed from f.usage) by the caller.
+func (f *MemoryFile) findReclaimable() (memmap.FileRange, bool) {
f.mu.Lock()
defer f.mu.Unlock()
for {
for {
if f.destroyed {
- return platform.FileRange{}, false
+ return memmap.FileRange{}, false
}
if f.reclaimable {
break
@@ -1103,27 +1108,24 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
}
f.reclaimCond.Wait()
}
- // Allocate returns the first usable range in offset order and is
- // currently a linear scan, so reclaiming from the beginning of the
- // file minimizes the expected latency of Allocate.
- for seg := f.usage.LowerBoundSegment(f.minReclaimablePage); seg.Ok(); seg = seg.NextSegment() {
- if seg.ValuePtr().refs == 0 {
- f.minReclaimablePage = seg.End()
- return seg.Range(), true
- }
+ // Allocate works from the end of the file backwards, so reclaim
+ // preserves this order to minimize the cost of the search.
+ if seg := f.reclaim.LastSegment(); seg.Ok() {
+ fr := seg.Range()
+ f.reclaim.Remove(seg)
+ return fr, true
}
- // No pages are reclaimable.
+ // Nothing is reclaimable.
f.reclaimable = false
- f.minReclaimablePage = maxPage
}
}
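The reclaim path now consumes that set from the top down, matching allocation order. An illustrative end-to-end flow, normally driven by the reclaim goroutine rather than called directly:

    // Illustrative flow; f is a *MemoryFile.
    fr, _ := f.Allocate(usermem.PageSize, usage.Anonymous)
    f.DecRef(fr)                   // refs hit zero: fr enters f.reclaim, cond signaled.
    if r, ok := f.findReclaimable(); ok {
        f.markReclaimed(r)         // caller must deallocate what it was handed.
    }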
-func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
+func (f *MemoryFile) markReclaimed(fr memmap.FileRange) {
f.mu.Lock()
defer f.mu.Unlock()
seg := f.usage.FindSegment(fr.Start)
- // All of fr should be mapped to a single uncommitted reclaimable segment
- // accounted to System.
+ // All of fr should be mapped to a single uncommitted reclaimable
+ // segment accounted to System.
if !seg.Ok() {
panic(fmt.Sprintf("reclaimed pages %v include unreferenced pages:\n%v", fr, &f.usage))
}
@@ -1137,14 +1139,10 @@ func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
}); got != want {
panic(fmt.Sprintf("reclaimed pages %v in segment %v has incorrect state %v, wanted %v:\n%v", fr, seg.Range(), got, want, &f.usage))
}
- // Deallocate reclaimed pages. Even though all of seg is reclaimable, the
- // caller of markReclaimed may not have decommitted it, so we can only mark
- // fr as reclaimed.
+ // Deallocate reclaimed pages. Even though all of seg is reclaimable,
+ // the caller of markReclaimed may not have decommitted it, so we can
+ // only mark fr as reclaimed.
f.usage.Remove(f.usage.Isolate(seg, fr))
- if fr.Start < f.minUnallocatedPage {
- // We've deallocated at least one lower page.
- f.minUnallocatedPage = fr.Start
- }
}
// StartEvictions requests that f evict all evictable allocations. It does not
@@ -1224,11 +1222,11 @@ func (usageSetFunctions) MaxKey() uint64 {
func (usageSetFunctions) ClearValue(val *usageInfo) {
}
-func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) {
+func (usageSetFunctions) Merge(_ memmap.FileRange, val1 usageInfo, _ memmap.FileRange, val2 usageInfo) (usageInfo, bool) {
return val1, val1 == val2
}
-func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) {
+func (usageSetFunctions) Split(_ memmap.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) {
return val, val
}
@@ -1255,3 +1253,27 @@ func (evictableRangeSetFunctions) Merge(_ EvictableRange, _ evictableRangeSetVal
func (evictableRangeSetFunctions) Split(_ EvictableRange, _ evictableRangeSetValue, _ uint64) (evictableRangeSetValue, evictableRangeSetValue) {
return evictableRangeSetValue{}, evictableRangeSetValue{}
}
+
+// reclaimSetValue is the value type of reclaimSet.
+type reclaimSetValue struct{}
+
+type reclaimSetFunctions struct{}
+
+func (reclaimSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (reclaimSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) {
+}
+
+func (reclaimSetFunctions) Merge(_ memmap.FileRange, _ reclaimSetValue, _ memmap.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) {
+ return reclaimSetValue{}, true
+}
+
+func (reclaimSetFunctions) Split(_ memmap.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) {
+ return reclaimSetValue{}, reclaimSetValue{}
+}
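Because Merge always succeeds and the value type carries no state, adjacent freed ranges coalesce into maximal reclaim segments. A small sketch, assuming the generated set's Add merges where possible:

    // Illustrative: adjacent ranges coalesce into one segment.
    var rs reclaimSet
    rs.Add(memmap.FileRange{0, 0x1000}, reclaimSetValue{})
    rs.Add(memmap.FileRange{0x1000, 0x2000}, reclaimSetValue{})
    // rs now holds the single segment [0, 0x2000), so findReclaimable
    // hands back whole runs rather than page-sized fragments.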
diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go
index 293f22c6b..405db141f 100644
--- a/pkg/sentry/pgalloc/pgalloc_test.go
+++ b/pkg/sentry/pgalloc/pgalloc_test.go
@@ -23,39 +23,49 @@ import (
const (
page = usermem.PageSize
hugepage = usermem.HugePageSize
+ topPage = (1 << 63) - page
)
func TestFindUnallocatedRange(t *testing.T) {
for _, test := range []struct {
- desc string
- usage *usageSegmentDataSlices
- start uint64
- length uint64
- alignment uint64
- unallocated uint64
- minUnallocated uint64
+ desc string
+ usage *usageSegmentDataSlices
+ fileSize int64
+ length uint64
+ alignment uint64
+ start uint64
+ expectFail bool
}{
{
- desc: "Initial allocation succeeds",
- usage: &usageSegmentDataSlices{},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 0,
- minUnallocated: 0,
+ desc: "Initial allocation succeeds",
+ usage: &usageSegmentDataSlices{},
+ length: page,
+ alignment: page,
+ start: chunkSize - page, // Grows by chunkSize, allocate down.
},
{
- desc: "Allocation begins at start of file",
+ desc: "Allocation finds empty space at start of file",
usage: &usageSegmentDataSlices{
Start: []uint64{page},
End: []uint64{2 * page},
Values: []usageInfo{{refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 0,
- minUnallocated: 0,
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: 0,
+ },
+ {
+ desc: "Allocation finds empty space at end of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0},
+ End: []uint64{page},
+ Values: []usageInfo{{refs: 1}},
+ },
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: page,
},
{
desc: "In-use frames are not allocatable",
@@ -64,11 +74,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 2 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: 3 * page, // Double fileSize, allocate top-down.
},
{
desc: "Reclaimable frames are not allocatable",
@@ -77,11 +86,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 2 * page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 0}, {refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 3 * page,
- minUnallocated: 3 * page,
+ fileSize: 3 * page,
+ length: page,
+ alignment: page,
+ start: 5 * page, // Double fileSize, grow down.
},
{
desc: "Gaps between in-use frames are allocatable",
@@ -90,11 +98,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: page,
- minUnallocated: page,
+ fileSize: 3 * page,
+ length: page,
+ alignment: page,
+ start: page,
},
{
desc: "Inadequately-sized gaps are rejected",
@@ -103,14 +110,13 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: 2 * page,
- alignment: page,
- unallocated: 3 * page,
- minUnallocated: page,
+ fileSize: 3 * page,
+ length: 2 * page,
+ alignment: page,
+ start: 4 * page, // Double fileSize, grow down.
},
{
- desc: "Hugepage alignment is honored",
+ desc: "Alignment is honored at end of file",
usage: &usageSegmentDataSlices{
Start: []uint64{0, hugepage + page},
// Hugepage-sized gap here that shouldn't be allocated from
@@ -118,37 +124,103 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, hugepage + 2*page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: hugepage,
- alignment: hugepage,
- unallocated: 2 * hugepage,
- minUnallocated: page,
+ fileSize: hugepage + 2*page,
+ length: hugepage,
+ alignment: hugepage,
+ start: 3 * hugepage, // Double fileSize until alignment is satisfied, grow down.
},
{
- desc: "Pages before start ignored",
+ desc: "Alignment is honored before end of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0, 2*hugepage + page},
+ // Page will need to be shifted down from top.
+ End: []uint64{page, 2*hugepage + 2*page},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: 2*hugepage + 2*page,
+ length: hugepage,
+ alignment: hugepage,
+ start: hugepage,
+ },
+ {
+ desc: "Allocation doubles file size more than once if necessary",
+ usage: &usageSegmentDataSlices{},
+ fileSize: page,
+ length: 4 * page,
+ alignment: page,
+ start: 0,
+ },
+ {
+ desc: "Allocations are compact if possible",
usage: &usageSegmentDataSlices{
Start: []uint64{page, 3 * page},
End: []uint64{2 * page, 4 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: page,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 4 * page,
+ length: page,
+ alignment: page,
+ start: 2 * page,
+ },
+ {
+ desc: "Top-down allocation within one gap",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 4 * page, 7 * page},
+ End: []uint64{2 * page, 5 * page, 8 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}},
+ },
+ fileSize: 8 * page,
+ length: page,
+ alignment: page,
+ start: 6 * page,
+ },
+ {
+ desc: "Top-down allocation between multiple gaps",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 3 * page, 5 * page},
+ End: []uint64{2 * page, 4 * page, 6 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}},
+ },
+ fileSize: 6 * page,
+ length: page,
+ alignment: page,
+ start: 4 * page,
},
{
- desc: "start may be in the middle of segment",
+ desc: "Top-down allocation with large top gap",
usage: &usageSegmentDataSlices{
- Start: []uint64{0, 3 * page},
+ Start: []uint64{page, 3 * page},
End: []uint64{2 * page, 4 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: page,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 8 * page,
+ length: page,
+ alignment: page,
+ start: 7 * page,
+ },
+ {
+ desc: "Gaps found with possible overflow",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, topPage - page},
+ End: []uint64{2 * page, topPage},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: topPage,
+ length: page,
+ alignment: page,
+ start: topPage - 2*page,
+ },
+ {
+ desc: "Overflow detected",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page},
+ End: []uint64{topPage},
+ Values: []usageInfo{{refs: 1}},
+ },
+ fileSize: topPage,
+ length: 2 * page,
+ alignment: page,
+ expectFail: true,
},
} {
t.Run(test.desc, func(t *testing.T) {
@@ -156,12 +228,18 @@ func TestFindUnallocatedRange(t *testing.T) {
if err := usage.ImportSortedSlices(test.usage); err != nil {
t.Fatalf("Failed to initialize usage from %v: %v", test.usage, err)
}
- unallocated, minUnallocated := findUnallocatedRange(&usage, test.start, test.length, test.alignment)
- if unallocated != test.unallocated {
- t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got unallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, unallocated, test.unallocated)
+ fr, ok := findAvailableRange(&usage, test.fileSize, test.length, test.alignment)
+ if !test.expectFail && !ok {
+ t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, false wanted %x, true", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
+ }
+ if test.expectFail && ok {
+ t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, true wanted %x, false", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
+ }
+ if ok && fr.Start != test.start {
+ t.Errorf("findAvailableRange(%v, %x, %x, %x): got start=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
}
- if minUnallocated != test.minUnallocated {
- t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got minUnallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, minUnallocated, test.minUnallocated)
+ if ok && fr.End != test.start+test.length {
+ t.Errorf("findAvailableRange(%v, %x, %x, %x): got end=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.End, test.start+test.length)
}
})
}
diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go
index f8385c146..78317fa35 100644
--- a/pkg/sentry/pgalloc/save_restore.go
+++ b/pkg/sentry/pgalloc/save_restore.go
@@ -26,11 +26,12 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
"gvisor.dev/gvisor/pkg/usermem"
)
// SaveTo writes f's state to the given stream.
-func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error {
+func (f *MemoryFile) SaveTo(ctx context.Context, w wire.Writer) error {
// Wait for reclaim.
f.mu.Lock()
defer f.mu.Unlock()
@@ -79,10 +80,10 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error {
}
// Save metadata.
- if err := state.Save(ctx, w, &f.fileSize, nil); err != nil {
+ if _, err := state.Save(ctx, w, &f.fileSize); err != nil {
return err
}
- if err := state.Save(ctx, w, &f.usage, nil); err != nil {
+ if _, err := state.Save(ctx, w, &f.usage); err != nil {
return err
}
@@ -115,9 +116,9 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error {
}
// LoadFrom loads MemoryFile state from the given stream.
-func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error {
+func (f *MemoryFile) LoadFrom(ctx context.Context, r wire.Reader) error {
// Load metadata.
- if err := state.Load(ctx, r, &f.fileSize, nil); err != nil {
+ if _, err := state.Load(ctx, r, &f.fileSize); err != nil {
return err
}
if err := f.file.Truncate(f.fileSize); err != nil {
@@ -125,7 +126,7 @@ func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error {
}
newMappings := make([]uintptr, f.fileSize>>chunkShift)
f.mappings.Store(newMappings)
- if err := state.Load(ctx, r, &f.usage, nil); err != nil {
+ if _, err := state.Load(ctx, r, &f.usage); err != nil {
return err
}
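The switch away from io.Reader/io.Writer changes what callers must pass in. A hypothetical round trip, assuming wire.Writer and wire.Reader are the minimal byte-stream interfaces in pkg/state/wire (bufio types would satisfy them):

    // Hypothetical caller; the wire interface shape is assumed, not shown here.
    var buf bytes.Buffer
    w := bufio.NewWriter(&buf)
    if err := f.SaveTo(ctx, w); err != nil {
        return err
    }
    if err := w.Flush(); err != nil {
        return err
    }
    return f.LoadFrom(ctx, bufio.NewReader(&buf))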
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
index 453241eca..209b28053 100644
--- a/pkg/sentry/platform/BUILD
+++ b/pkg/sentry/platform/BUILD
@@ -1,39 +1,21 @@
load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-go_template_instance(
- name = "file_range",
- out = "file_range.go",
- package = "platform",
- prefix = "File",
- template = "//pkg/segment:generic_range",
- types = {
- "T": "uint64",
- },
-)
-
go_library(
name = "platform",
srcs = [
"context.go",
- "file_range.go",
"mmap_min_addr.go",
"platform.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/atomicbitops",
"//pkg/context",
- "//pkg/log",
- "//pkg/safecopy",
- "//pkg/safemem",
"//pkg/seccomp",
"//pkg/sentry/arch",
- "//pkg/sentry/usage",
- "//pkg/syserror",
+ "//pkg/sentry/memmap",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 159f7eafd..b5d27a72a 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -6,8 +6,8 @@ go_library(
name = "kvm",
srcs = [
"address_space.go",
- "allocator.go",
"bluepill.go",
+ "bluepill_allocator.go",
"bluepill_amd64.go",
"bluepill_amd64.s",
"bluepill_amd64_unsafe.go",
@@ -47,6 +47,7 @@ go_library(
"//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sentry/platform/ring0",
@@ -60,6 +61,7 @@ go_library(
go_test(
name = "kvm_test",
srcs = [
+ "kvm_amd64_test.go",
"kvm_test.go",
"virtual_map_test.go",
],
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index be213bfe8..98a3e539d 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sync"
@@ -26,16 +27,15 @@ import (
// dirtySet tracks vCPUs for invalidation.
type dirtySet struct {
- vCPUs []uint64
+ vCPUMasks []uint64
}
// forEach iterates over all CPUs in the dirty set.
+//
+//go:nosplit
func (ds *dirtySet) forEach(m *machine, fn func(c *vCPU)) {
- m.mu.RLock()
- defer m.mu.RUnlock()
-
- for index := range ds.vCPUs {
- mask := atomic.SwapUint64(&ds.vCPUs[index], 0)
+ for index := range ds.vCPUMasks {
+ mask := atomic.SwapUint64(&ds.vCPUMasks[index], 0)
if mask != 0 {
for bit := 0; bit < 64; bit++ {
if mask&(1<<uint64(bit)) == 0 {
@@ -54,7 +54,7 @@ func (ds *dirtySet) mark(c *vCPU) bool {
index := uint64(c.id) / 64
bit := uint64(1) << uint(c.id%64)
- oldValue := atomic.LoadUint64(&ds.vCPUs[index])
+ oldValue := atomic.LoadUint64(&ds.vCPUMasks[index])
if oldValue&bit != 0 {
return false // Not clean.
}
@@ -62,7 +62,7 @@ func (ds *dirtySet) mark(c *vCPU) bool {
// Set the bit unilaterally, and ensure that a flush takes place. Note
// that it's possible for races to occur here, but since the flush is
// taking place long after these lines there's no race in practice.
- atomicbitops.OrUint64(&ds.vCPUs[index], bit)
+ atomicbitops.OrUint64(&ds.vCPUMasks[index], bit)
return true // Previously clean.
}
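forEach drains each 64-bit mask with an atomic swap, so bits set concurrently by mark are either observed in this pass or left intact for the next one. The pattern, in isolation:

    // Illustrative drain of one mask word, per the forEach loop above.
    mask := atomic.SwapUint64(&ds.vCPUMasks[index], 0) // claim every set bit at once
    for bit := 0; bit < 64; bit++ {
        if mask&(1<<uint64(bit)) != 0 {
            // vCPU id == 64*index + bit; mark() set it, this pass services it.
        }
    }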
@@ -113,7 +113,12 @@ type hostMapEntry struct {
length uintptr
}
-func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) {
+// mapLocked maps the given host entry.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) {
for m.length > 0 {
physical, length, ok := translateToPhysical(m.addr)
if !ok {
@@ -133,18 +138,10 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac
// important; if the pagetable mappings were installed before
// ensuring the physical pages were available, then some other
// thread could theoretically access them.
- //
- // Due to the way KVM's shadow paging implementation works,
- // modifications to the page tables while in host mode may not
- // be trapped, leading to the shadow pages being out of sync.
- // Therefore, we need to ensure that we are in guest mode for
- // page table modifications. See the call to bluepill, below.
- as.machine.retryInGuest(func() {
- inv = as.pageTables.Map(addr, length, pagetables.MapOpts{
- AccessType: at,
- User: true,
- }, physical) || inv
- })
+ inv = as.pageTables.Map(addr, length, pagetables.MapOpts{
+ AccessType: at,
+ User: true,
+ }, physical) || inv
m.addr += length
m.length -= length
addr += usermem.Addr(length)
@@ -154,7 +151,7 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac
}
// MapFile implements platform.AddressSpace.MapFile.
-func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error {
as.mu.Lock()
defer as.mu.Unlock()
@@ -176,6 +173,10 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
return err
}
+ // See block in mapLocked.
+ as.pageTables.Allocator.(*allocator).cpu = as.machine.Get()
+ defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu)
+
// Map the mappings in the sentry's address space (guest physical memory)
// into the application's address space (guest virtual memory).
inv := false
@@ -190,7 +191,12 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
_ = s[i] // Touch to commit.
}
}
- prev := as.mapHost(addr, hostMapEntry{
+
+ // See bluepill_allocator.go.
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+
+ // Perform the mapping.
+ prev := as.mapLocked(addr, hostMapEntry{
addr: b.Addr(),
length: uintptr(b.Len()),
}, at)
@@ -204,17 +210,27 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
return nil
}
+// unmapLocked is an escape-checked wrapper around Unmap.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (as *addressSpace) unmapLocked(addr usermem.Addr, length uint64) bool {
+ return as.pageTables.Unmap(addr, uintptr(length))
+}
+
// Unmap unmaps the given range by calling pagetables.PageTables.Unmap.
func (as *addressSpace) Unmap(addr usermem.Addr, length uint64) {
as.mu.Lock()
defer as.mu.Unlock()
- // See above re: retryInGuest.
- var prev bool
- as.machine.retryInGuest(func() {
- prev = as.pageTables.Unmap(addr, uintptr(length)) || prev
- })
- if prev {
+ // See above & bluepill_allocator.go.
+ as.pageTables.Allocator.(*allocator).cpu = as.machine.Get()
+ defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu)
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+
+ if prev := as.unmapLocked(addr, length); prev {
+ // Invalidate all active vCPUs.
as.invalidate()
// Recycle any freed intermediate pages.
@@ -227,7 +243,7 @@ func (as *addressSpace) Release() {
as.Unmap(0, ^uint64(0))
// Free all pages from the allocator.
- as.pageTables.Allocator.(allocator).base.Drain()
+ as.pageTables.Allocator.(*allocator).base.Drain()
// Drop all cached machine references.
as.machine.dropPageTables(as.pageTables)
diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/bluepill_allocator.go
index 3f35414bb..9485e1301 100644
--- a/pkg/sentry/platform/kvm/allocator.go
+++ b/pkg/sentry/platform/kvm/bluepill_allocator.go
@@ -21,56 +21,80 @@ import (
)
type allocator struct {
- base *pagetables.RuntimeAllocator
+ base pagetables.RuntimeAllocator
+
+ // cpu must be set prior to any pagetable operation.
+ //
+ // Due to the way KVM's shadow paging implementation works,
+ // modifications to the page tables while in host mode may not be
+ // trapped, leading to the shadow pages being out of sync. Therefore,
+ // we need to ensure that we are in guest mode for page table
+ // modifications. See the call to bluepill, below.
+ cpu *vCPU
}
// newAllocator is used to define the allocator.
-func newAllocator() allocator {
- return allocator{
- base: pagetables.NewRuntimeAllocator(),
- }
+func newAllocator() *allocator {
+ a := new(allocator)
+ a.base.Init()
+ return a
}
// NewPTEs implements pagetables.Allocator.NewPTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) NewPTEs() *pagetables.PTEs {
- return a.base.NewPTEs()
+func (a *allocator) NewPTEs() *pagetables.PTEs {
+ ptes := a.base.NewPTEs() // escapes: bluepill below.
+ if a.cpu != nil {
+ bluepill(a.cpu)
+ }
+ return ptes
}
// PhysicalFor returns the physical address for a set of PTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
+func (a *allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
virtual := a.base.PhysicalFor(ptes)
physical, _, ok := translateToPhysical(virtual)
if !ok {
- panic(fmt.Sprintf("PhysicalFor failed for %p", ptes))
+ panic(fmt.Sprintf("PhysicalFor failed for %p", ptes)) // escapes: panic.
}
return physical
}
// LookupPTEs implements pagetables.Allocator.LookupPTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
+func (a *allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions)
if !ok {
- panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical))
+ panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) // escapes: panic.
}
return a.base.LookupPTEs(virtualStart + (physical - physicalStart))
}
// FreePTEs implements pagetables.Allocator.FreePTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) FreePTEs(ptes *pagetables.PTEs) {
- a.base.FreePTEs(ptes)
+func (a *allocator) FreePTEs(ptes *pagetables.PTEs) {
+ a.base.FreePTEs(ptes) // escapes: bluepill below.
+ if a.cpu != nil {
+ bluepill(a.cpu)
+ }
}
// Recycle implements pagetables.Allocator.Recycle.
//
//go:nosplit
-func (a allocator) Recycle() {
+func (a *allocator) Recycle() {
a.base.Recycle()
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
index 133c2203d..ddc1554d5 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -63,6 +63,8 @@ func bluepillArchEnter(context *arch.SignalContext64) *vCPU {
// KernelSyscall handles kernel syscalls.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelSyscall() {
regs := c.Registers()
@@ -72,13 +74,15 @@ func (c *vCPU) KernelSyscall() {
// We only trigger a bluepill entry in the bluepill function, and can
// therefore be guaranteed that there is no floating point state to be
// loaded on resuming from halt. We only worry about saving on exit.
- ring0.SaveFloatingPoint((*byte)(c.floatingPointState))
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
ring0.Halt()
- ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment.
+ ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no, reload host segment.
}
// KernelException handles kernel exceptions.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelException(vector ring0.Vector) {
regs := c.Registers()
@@ -89,9 +93,9 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
regs.Rip = 0
}
// See above.
- ring0.SaveFloatingPoint((*byte)(c.floatingPointState))
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
ring0.Halt()
- ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment.
+ ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment.
}
// bluepillArchExit is called during bluepillEnter.
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index 99cac665d..03a98512e 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -17,6 +17,7 @@
package kvm
import (
+ "syscall"
"unsafe"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -53,3 +54,34 @@ func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
context.Rbx = uint64(uintptr(unsafe.Pointer(c)))
context.Rip = uint64(dieTrampolineAddr)
}
+
+// getHypercallID returns hypercall ID.
+//
+//go:nosplit
+func getHypercallID(addr uintptr) int {
+ return _KVM_HYPERCALL_MAX
+}
+
+// bluepillStopGuest is responsible for injecting an interrupt.
+//
+//go:nosplit
+func bluepillStopGuest(c *vCPU) {
+ // Interrupt: we must have requested an interrupt
+ // window; set the interrupt line.
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_INTERRUPT,
+ uintptr(unsafe.Pointer(&bounce))); errno != 0 {
+ throw("interrupt injection failed")
+ }
+ // Clear previous injection request.
+ c.runData.requestInterruptWindow = 0
+}
+
+// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection.
+//
+//go:nosplit
+func bluepillReadyStopGuest(c *vCPU) bool {
+ return c.runData.readyForInterruptInjection != 0
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index c215d443c..dba563160 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -26,6 +26,17 @@ import (
var (
// The action for bluepillSignal is changed by sigaction().
bluepillSignal = syscall.SIGILL
+
+ // vcpuSErr is the system error (SError) event used for injection.
+ vcpuSErr = kvmVcpuEvents{
+ exception: exception{
+ sErrPending: 1,
+ sErrHasEsr: 0,
+ pad: [6]uint8{0, 0, 0, 0, 0, 0},
+ sErrEsr: 1,
+ },
+ rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ }
)
// bluepillArchEnter is called during bluepillEnter.
@@ -66,6 +77,8 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
// KernelSyscall handles kernel syscalls.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelSyscall() {
regs := c.Registers()
@@ -88,6 +101,8 @@ func (c *vCPU) KernelSyscall() {
// KernelException handles kernel exceptions.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelException(vector ring0.Vector) {
regs := c.Registers()
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index 4ca2b7717..8b64f3a1e 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -17,6 +17,7 @@
package kvm
import (
+ "syscall"
"unsafe"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -61,3 +62,36 @@ func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
func bluepillArchFpContext(context unsafe.Pointer) *arch.FpsimdContext {
return &((*arch.SignalContext64)(context).Fpsimd64)
}
+
+// getHypercallID returns hypercall ID.
+//
+// On Arm64, the MMIO address should be 64-bit aligned.
+//
+//go:nosplit
+func getHypercallID(addr uintptr) int {
+ if addr < arm64HypercallMMIOBase || addr >= (arm64HypercallMMIOBase+_AARCH64_HYPERCALL_MMIO_SIZE) {
+ return _KVM_HYPERCALL_MAX
+ } else {
+ return int(((addr) - arm64HypercallMMIOBase) >> 3)
+ }
+}
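With 8-byte slots, the hypercall ID is simply the slot index of the faulting MMIO address. A worked example under the constants in this change:

    // Worked example; only one slot exists today.
    id := getHypercallID(arm64HypercallMMIOBase)    // 0, i.e. _KVM_HYPERCALL_VMEXIT
    id = getHypercallID(arm64HypercallMMIOBase + 8) // past the window: _KVM_HYPERCALL_MAX
    _ = id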
+
+// bluepillStopGuest is responsible for injecting an sError.
+//
+//go:nosplit
+func bluepillStopGuest(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_VCPU_EVENTS,
+ uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 {
+ throw("sErr injection failed")
+ }
+}
+
+// bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection.
+//
+//go:nosplit
+func bluepillReadyStopGuest(c *vCPU) bool {
+ return true
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 9add7c944..bf357de1a 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.15
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
@@ -58,12 +58,32 @@ func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
return &((*arch.UContext64)(context).MContext)
}
+// bluepillGuestExit is responsible for handling a VM exit.
+//
+//go:nosplit
+func bluepillGuestExit(c *vCPU, context unsafe.Pointer) {
+ // Copy out registers.
+ bluepillArchExit(c, bluepillArchContext(context))
+
+ // Return to the vCPUReady state; notify any waiters.
+ user := atomic.LoadUint32(&c.state) & vCPUUser
+ switch atomic.SwapUint32(&c.state, user) {
+ case user | vCPUGuest: // Expected case.
+ case user | vCPUGuest | vCPUWaiter:
+ c.notify()
+ default:
+ throw("invalid state")
+ }
+}
+
// bluepillHandler is called from the signal stub.
//
// The world may be stopped while this is executing, and it executes on the
// signal stack. It should only execute raw system calls and functions that are
// explicitly marked go:nosplit.
//
+// +checkescape:all
+//
//go:nosplit
func bluepillHandler(context unsafe.Pointer) {
// Sanitize the registers; interrupts must always be disabled.
@@ -82,7 +102,8 @@ func bluepillHandler(context unsafe.Pointer) {
}
for {
- switch _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0); errno {
+ _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0) // escapes: no.
+ switch errno {
case 0: // Expected case.
case syscall.EINTR:
// First, we process whatever pending signal
@@ -90,7 +111,7 @@ func bluepillHandler(context unsafe.Pointer) {
// currently, all signals are masked and the signal
// must have been delivered directly to this thread.
timeout := syscall.Timespec{}
- sig, _, errno := syscall.RawSyscall6(
+ sig, _, errno := syscall.RawSyscall6( // escapes: no.
syscall.SYS_RT_SIGTIMEDWAIT,
uintptr(unsafe.Pointer(&bounceSignalMask)),
0, // siginfo.
@@ -112,12 +133,12 @@ func bluepillHandler(context unsafe.Pointer) {
// PIC, we can't inject an interrupt while they are
// masked. We need to request a window if it's not
// ready.
- if c.runData.readyForInterruptInjection == 0 {
- c.runData.requestInterruptWindow = 1
- continue // Rerun vCPU.
- } else {
+ if bluepillReadyStopGuest(c) {
// Force injection below; the vCPU is ready.
c.runData.exitReason = _KVM_EXIT_IRQ_WINDOW_OPEN
+ } else {
+ c.runData.requestInterruptWindow = 1
+ continue // Rerun vCPU.
}
case syscall.EFAULT:
// If a fault is not serviceable due to the host
@@ -125,7 +146,7 @@ func bluepillHandler(context unsafe.Pointer) {
// MMIO exit we receive EFAULT from the run ioctl. We
// always inject an NMI here since we may be in kernel
// mode and have interrupts disabled.
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_NMI, 0); errno != 0 {
@@ -156,25 +177,20 @@ func bluepillHandler(context unsafe.Pointer) {
c.die(bluepillArchContext(context), "debug")
return
case _KVM_EXIT_HLT:
- // Copy out registers.
- bluepillArchExit(c, bluepillArchContext(context))
-
- // Return to the vCPUReady state; notify any waiters.
- user := atomic.LoadUint32(&c.state) & vCPUUser
- switch atomic.SwapUint32(&c.state, user) {
- case user | vCPUGuest: // Expected case.
- case user | vCPUGuest | vCPUWaiter:
- c.notify()
- default:
- throw("invalid state")
- }
+ bluepillGuestExit(c, context)
return
case _KVM_EXIT_MMIO:
+ physical := uintptr(c.runData.data[0])
+ if getHypercallID(physical) == _KVM_HYPERCALL_VMEXIT {
+ bluepillGuestExit(c, context)
+ return
+ }
+
// Increment the fault count.
atomic.AddUint32(&c.faults, 1)
// For MMIO, the physical address is the first data item.
- physical := uintptr(c.runData.data[0])
+ physical = uintptr(c.runData.data[0])
virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE)
if !ok {
c.die(bluepillArchContext(context), "invalid physical address")
@@ -201,17 +217,7 @@ func bluepillHandler(context unsafe.Pointer) {
}
}
case _KVM_EXIT_IRQ_WINDOW_OPEN:
- // Interrupt: we must have requested an interrupt
- // window; set the interrupt line.
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_INTERRUPT,
- uintptr(unsafe.Pointer(&bounce))); errno != 0 {
- throw("interrupt injection failed")
- }
- // Clear previous injection request.
- c.runData.requestInterruptWindow = 0
+ bluepillStopGuest(c)
case _KVM_EXIT_SHUTDOWN:
c.die(bluepillArchContext(context), "shutdown")
return
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
new file mode 100644
index 000000000..c0b4fd374
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -0,0 +1,51 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+)
+
+func TestSegments(t *testing.T) {
+ applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ testutil.SetTestSegments(regs)
+ for {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ FullRestore: true,
+ }, &si); err == platform.ErrContextInterrupt {
+ continue // Retry.
+ } else if err != nil {
+ t.Errorf("application segment check with full restore got unexpected error: %v", err)
+ }
+ if err := testutil.CheckTestSegments(regs); err != nil {
+ t.Errorf("application segment check with full restore failed: %v", err)
+ }
+ break // Done.
+ }
+ return false
+ })
+}
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
index 29d457a7e..0b06a923a 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -26,6 +26,9 @@ type kvmOneReg struct {
addr uint64
}
+// arm64HypercallMMIOBase is the MMIO base address used to dispatch hypercalls.
+var arm64HypercallMMIOBase uintptr
+
const KVM_NR_SPSR = 5
type userFpsimdState struct {
@@ -43,6 +46,18 @@ type userRegs struct {
fpRegs userFpsimdState
}
+type exception struct {
+ sErrPending uint8
+ sErrHasEsr uint8
+ pad [6]uint8
+ sErrEsr uint64
+}
+
+type kvmVcpuEvents struct {
+ exception
+ rsvd [12]uint32
+}
+
// updateGlobalOnce does global initialization. It has to be called only once.
func updateGlobalOnce(fd int) error {
physicalInit()
diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
index 6531bae1d..48ccf8474 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
@@ -22,7 +22,8 @@ import (
)
var (
- runDataSize int
+ runDataSize int
+ hasGuestPCID bool
)
func updateSystemValues(fd int) error {
@@ -33,6 +34,7 @@ func updateSystemValues(fd int) error {
}
// Save the data.
runDataSize = int(sz)
+ hasGuestPCID = true
// Success.
return nil
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 1d5c77ff4..3bf918446 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -35,6 +35,8 @@ const (
_KVM_GET_SUPPORTED_CPUID = 0xc008ae05
_KVM_SET_CPUID2 = 0x4008ae90
_KVM_SET_SIGNAL_MASK = 0x4004ae8b
+ _KVM_GET_VCPU_EVENTS = 0x8040ae9f
+ _KVM_SET_VCPU_EVENTS = 0x4040aea0
)
// KVM exit reasons.
@@ -54,8 +56,10 @@ const (
// KVM capability options.
const (
- _KVM_CAP_MAX_VCPUS = 0x42
- _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5
+ _KVM_CAP_MAX_VCPUS = 0x42
+ _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5
+ _KVM_CAP_VCPU_EVENTS = 0x29
+ _KVM_CAP_ARM_INJECT_SERROR_ESR = 0x9e
)
// KVM limits.
@@ -71,3 +75,13 @@ const (
_KVM_MEM_READONLY = uint32(1) << 1
_KVM_MEM_FLAGS_NONE = 0
)
+
+// KVM hypercall list.
+// Canonical list of supported hypercalls.
+const (
+ // On amd64, the sentry uses 'HLT' to leave the guest.
+ // Unlike amd64, arm64 can only use mmio_exit/psci to leave the guest.
+ // _KVM_HYPERCALL_VMEXIT is only used on Arm64 for now.
+ _KVM_HYPERCALL_VMEXIT int = iota
+ _KVM_HYPERCALL_MAX
+)
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
index 531ae8b1e..fdc599477 100644
--- a/pkg/sentry/platform/kvm/kvm_const_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -116,6 +116,17 @@ const (
// Arm64: Exception Syndrome Register EL1.
const (
+ _ESR_ELx_EC_SHIFT = 26
+ _ESR_ELx_EC_MASK = 0x3F << _ESR_ELx_EC_SHIFT
+
+ _ESR_ELx_EC_IMP_DEF = 0x1f
+ _ESR_ELx_EC_IABT_LOW = 0x20
+ _ESR_ELx_EC_IABT_CUR = 0x21
+ _ESR_ELx_EC_PC_ALIGN = 0x22
+
+ _ESR_ELx_CM = 1 << 8
+ _ESR_ELx_WNR = 1 << 6
+
_ESR_ELx_FSC = 0x3F
_ESR_SEGV_MAPERR_L0 = 0x4
@@ -131,3 +142,10 @@ const (
_ESR_SEGV_PEMERR_L2 = 0xe
_ESR_SEGV_PEMERR_L3 = 0xf
)
+
+// Arm64: MMIO base address used to dispatch hypercalls.
+const (
+ // On Arm64, the MMIO address must be 64-bit aligned.
+ // Currently, we only need 1 hypercall: hypercall_vmexit.
+ _AARCH64_HYPERCALL_MMIO_SIZE = 1 << 3
+)
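
Together with arm64HypercallMMIOBase in kvm_arm64.go above, this constant bounds a small read-only window in the guest physical address space. A sketch of mapping an MMIO-exit address back to a hypercall ID (the helper name and the byte-offset encoding are assumptions consistent with the comments above):

    // mmioHypercallID is illustrative: given the faulting guest physical
    // address of an MMIO exit, report which hypercall (if any) the guest
    // issued by writing into the hypercall window.
    func mmioHypercallID(ipa uintptr) (int, bool) {
    	if ipa < arm64HypercallMMIOBase || ipa >= arm64HypercallMMIOBase+_AARCH64_HYPERCALL_MMIO_SIZE {
    		return 0, false // Not a hypercall access.
    	}
    	return int(ipa - arm64HypercallMMIOBase), true
    }

With _KVM_HYPERCALL_VMEXIT == 0, the write at offset zero performed by the mmio_exit path in entry_arm64.s below decodes to the VMEXIT hypercall.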
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index 6c8f4fa28..45b3180f1 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -262,30 +262,6 @@ func TestRegistersFault(t *testing.T) {
})
}
-func TestSegments(t *testing.T) {
- applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- testutil.SetTestSegments(regs)
- for {
- var si arch.SignalInfo
- if _, err := c.SwitchToUser(ring0.SwitchOpts{
- Registers: regs,
- FloatingPointState: dummyFPState,
- PageTables: pt,
- FullRestore: true,
- }, &si); err == platform.ErrContextInterrupt {
- continue // Retry.
- } else if err != nil {
- t.Errorf("application segment check with full restore got unexpected error: %v", err)
- }
- if err := testutil.CheckTestSegments(regs); err != nil {
- t.Errorf("application segment check with full restore failed: %v", err)
- }
- break // Done.
- }
- return false
- })
-}
-
func TestBounce(t *testing.T) {
applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
go func() {
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index f1afc74dc..6c54712d1 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -52,16 +52,19 @@ type machine struct {
// available is notified when vCPUs are available.
available sync.Cond
- // vCPUs are the machine vCPUs.
+ // vCPUsByTID are the machine vCPUs.
//
// These are populated dynamically.
- vCPUs map[uint64]*vCPU
+ vCPUsByTID map[uint64]*vCPU
// vCPUsByID are the machine vCPUs, can be indexed by the vCPU's ID.
- vCPUsByID map[int]*vCPU
+ vCPUsByID []*vCPU
// maxVCPUs is the maximum number of vCPUs supported by the machine.
maxVCPUs int
+
+ // nextID is the next vCPU ID.
+ nextID uint32
}
const (
@@ -137,9 +140,8 @@ type dieState struct {
//
// Precondition: mu must be held.
func (m *machine) newVCPU() *vCPU {
- id := len(m.vCPUs)
-
// Create the vCPU.
+ id := int(atomic.AddUint32(&m.nextID, 1) - 1)
fd, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CREATE_VCPU, uintptr(id))
if errno != 0 {
panic(fmt.Sprintf("error creating new vCPU: %v", errno))
@@ -176,11 +178,7 @@ func (m *machine) newVCPU() *vCPU {
// newMachine returns a new VM context.
func newMachine(vm int) (*machine, error) {
// Create the machine.
- m := &machine{
- fd: vm,
- vCPUs: make(map[uint64]*vCPU),
- vCPUsByID: make(map[int]*vCPU),
- }
+ m := &machine{fd: vm}
m.available.L = &m.mu
m.kernel.Init(ring0.KernelOpts{
PageTables: pagetables.New(newAllocator()),
@@ -194,6 +192,10 @@ func newMachine(vm int) (*machine, error) {
}
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
+ // Create the vCPUs map/slices.
+ m.vCPUsByTID = make(map[uint64]*vCPU)
+ m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
+
// Apply the physical mappings. Note that these mappings may point to
// guest physical addresses that are not actually available. These
// physical pages are mapped on demand, see kernel_unsafe.go.
@@ -274,6 +276,8 @@ func newMachine(vm int) (*machine, error) {
// not available. This attempts to be efficient for calls in the hot path.
//
// This panics on error.
+//
+//go:nosplit
func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) {
for end := physical + length; physical < end; {
_, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions)
@@ -304,7 +308,11 @@ func (m *machine) Destroy() {
runtime.SetFinalizer(m, nil)
// Destroy vCPUs.
- for _, c := range m.vCPUs {
+ for _, c := range m.vCPUsByID {
+ if c == nil {
+ continue
+ }
+
// Ensure the vCPU is not still running in guest mode. This is
// possible iff teardown has been done by other threads, and
// somehow a single thread has not executed any system calls.
@@ -337,7 +345,7 @@ func (m *machine) Get() *vCPU {
tid := procid.Current()
// Check for an exact match.
- if c := m.vCPUs[tid]; c != nil {
+ if c := m.vCPUsByTID[tid]; c != nil {
c.lock()
m.mu.RUnlock()
return c
@@ -356,7 +364,7 @@ func (m *machine) Get() *vCPU {
tid = procid.Current()
// Recheck for an exact match.
- if c := m.vCPUs[tid]; c != nil {
+ if c := m.vCPUsByTID[tid]; c != nil {
c.lock()
m.mu.Unlock()
return c
@@ -364,10 +372,10 @@ func (m *machine) Get() *vCPU {
for {
// Scan for an available vCPU.
- for origTID, c := range m.vCPUs {
+ for origTID, c := range m.vCPUsByTID {
if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) {
- delete(m.vCPUs, origTID)
- m.vCPUs[tid] = c
+ delete(m.vCPUsByTID, origTID)
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
@@ -375,17 +383,17 @@ func (m *machine) Get() *vCPU {
}
// Create a new vCPU (maybe).
- if len(m.vCPUs) < m.maxVCPUs {
+ if int(m.nextID) < m.maxVCPUs {
c := m.newVCPU()
c.lock()
- m.vCPUs[tid] = c
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
}
// Scan for something not in user mode.
- for origTID, c := range m.vCPUs {
+ for origTID, c := range m.vCPUsByTID {
if !atomic.CompareAndSwapUint32(&c.state, vCPUGuest, vCPUGuest|vCPUWaiter) {
continue
}
@@ -403,8 +411,8 @@ func (m *machine) Get() *vCPU {
}
// Steal the vCPU.
- delete(m.vCPUs, origTID)
- m.vCPUs[tid] = c
+ delete(m.vCPUsByTID, origTID)
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
@@ -431,7 +439,7 @@ func (m *machine) Put(c *vCPU) {
// newDirtySet returns a new dirty set.
func (m *machine) newDirtySet() *dirtySet {
return &dirtySet{
- vCPUs: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64),
+ vCPUMasks: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64),
}
}
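
The rename from vCPUs to vCPUMasks makes the dirty set's representation explicit: one bit per vCPU ID, packed into uint64 words, hence the (maxVCPUs+63)/64 sizing. A sketch of the implied bit manipulation (the method name is an assumption; assumes "sync/atomic" is imported):

    // mark sets the dirty bit for the given vCPU ID and reports whether
    // it was newly set. Illustrative sketch only.
    func (ds *dirtySet) mark(id int) bool {
    	word, bit := id/64, uint(id%64)
    	for {
    		old := atomic.LoadUint64(&ds.vCPUMasks[word])
    		if old&(1<<bit) != 0 {
    			return false // Already marked.
    		}
    		if atomic.CompareAndSwapUint64(&ds.vCPUMasks[word], old, old|(1<<bit)) {
    			return true // Newly marked.
    		}
    	}
    }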
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index 923ce3909..acc823ba6 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -51,9 +51,10 @@ func (m *machine) initArchState() error {
recover()
debug.SetPanicOnFault(old)
}()
- m.retryInGuest(func() {
- ring0.SetCPUIDFaulting(true)
- })
+ c := m.Get()
+ defer m.Put(c)
+ bluepill(c)
+ ring0.SetCPUIDFaulting(true)
return nil
}
@@ -89,8 +90,8 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) {
defer m.mu.Unlock()
// Clear from all PCIDs.
- for _, c := range m.vCPUs {
- if c.PCIDs != nil {
+ for _, c := range m.vCPUsByID {
+ if c != nil && c.PCIDs != nil {
c.PCIDs.Drop(pt)
}
}
@@ -335,29 +336,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
}
}
-// retryInGuest runs the given function in guest mode.
-//
-// If the function does not complete in guest mode (due to execution of a
-// system call due to a GC stall, for example), then it will be retried. The
-// given function must be idempotent as a result of the retry mechanism.
-func (m *machine) retryInGuest(fn func()) {
- c := m.Get()
- defer m.Put(c)
- for {
- c.ClearErrorCode() // See below.
- bluepill(c) // Force guest mode.
- fn() // Execute the given function.
- _, user := c.ErrorCode()
- if user {
- // If user is set, then we haven't bailed back to host
- // mode via a kernel exception or system call. We
- // consider the full function to have executed in guest
- // mode and we can return.
- break
- }
- }
-}
-
// On x86 platform, the flags for "setMemoryRegion" can always be set as 0.
// There is no need to return read-only physicalRegions.
func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index 7156c245f..290f035dd 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -154,7 +154,7 @@ func (c *vCPU) setUserRegisters(uregs *userRegs) error {
//
//go:nosplit
func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_GET_REGS,
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index e42505542..9db171af9 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -60,6 +60,12 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
if !vr.accessType.Write && vr.accessType.Read {
rdonlyRegions = append(rdonlyRegions, vr.region)
}
+
+ // TODO(gvisor.dev/issue/2686): PROT_NONE should be handled specially.
+ // Workaround: treat it as read-only temporarily.
+ if !vr.accessType.Write && !vr.accessType.Read && !vr.accessType.Execute {
+ rdonlyRegions = append(rdonlyRegions, vr.region)
+ }
})
for _, r := range rdonlyRegions {
@@ -100,7 +106,7 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) {
defer m.mu.Unlock()
// Clear from all PCIDs.
- for _, c := range m.vCPUs {
+ for _, c := range m.vCPUsByID {
if c.PCIDs != nil {
c.PCIDs.Drop(pt)
}
@@ -119,71 +125,59 @@ func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.Acc
return usermem.NoAccess, platform.ErrContextSignal
}
+// isInstructionAbort returns true if it is an instruction abort.
+//
+//go:nosplit
+func isInstructionAbort(code uint64) bool {
+ value := (code & _ESR_ELx_EC_MASK) >> _ESR_ELx_EC_SHIFT
+ return value == _ESR_ELx_EC_IABT_LOW
+}
+
+// isWriteFault returns whether it is a write fault.
+//
+//go:nosplit
+func isWriteFault(code uint64) bool {
+ if isInstructionAbort(code) {
+ return false
+ }
+
+ return (code & _ESR_ELx_WNR) != 0
+}
+
// fault generates an appropriate fault return.
//
//go:nosplit
func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+ bluepill(c) // Probably no-op, but may not be.
faultAddr := c.GetFaultAddr()
code, user := c.ErrorCode()
+ if !user {
+ // The last fault serviced by this CPU was not a user
+ // fault, so we can't reliably trust the faultAddr or
+ // the code provided here. We need to re-execute.
+ return usermem.NoAccess, platform.ErrContextInterrupt
+ }
+
// Reset the pointed SignalInfo.
*info = arch.SignalInfo{Signo: signal}
info.SetAddr(uint64(faultAddr))
- read := true
- write := false
- execute := true
-
ret := code & _ESR_ELx_FSC
switch ret {
case _ESR_SEGV_MAPERR_L0, _ESR_SEGV_MAPERR_L1, _ESR_SEGV_MAPERR_L2, _ESR_SEGV_MAPERR_L3:
info.Code = 1 //SEGV_MAPERR
- read = false
- write = true
- execute = false
case _ESR_SEGV_ACCERR_L1, _ESR_SEGV_ACCERR_L2, _ESR_SEGV_ACCERR_L3, _ESR_SEGV_PEMERR_L1, _ESR_SEGV_PEMERR_L2, _ESR_SEGV_PEMERR_L3:
info.Code = 2 // SEGV_ACCERR.
- read = true
- write = false
- execute = false
default:
info.Code = 2
}
- if !user {
- read = true
- write = false
- execute = true
-
- }
accessType := usermem.AccessType{
- Read: read,
- Write: write,
- Execute: execute,
+ Read: !isWriteFault(uint64(code)),
+ Write: isWriteFault(uint64(code)),
+ Execute: isInstructionAbort(uint64(code)),
}
return accessType, platform.ErrContextSignal
}
-
-// retryInGuest runs the given function in guest mode.
-//
-// If the function does not complete in guest mode (due to execution of a
-// system call due to a GC stall, for example), then it will be retried. The
-// given function must be idempotent as a result of the retry mechanism.
-func (m *machine) retryInGuest(fn func()) {
- c := m.Get()
- defer m.Put(c)
- for {
- c.ClearErrorCode() // See below.
- bluepill(c) // Force guest mode.
- fn() // Execute the given function.
- _, user := c.ErrorCode()
- if user {
- // If user is set, then we haven't bailed back to host
- // mode via a kernel exception or system call. We
- // consider the full function to have executed in guest
- // mode and we can return.
- break
- }
- }
-}
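
In the decoding above, the exception class lives in ESR bits [31:26] and the WnR (write-not-read) bit is bit 6; instruction aborts never count as writes. A small test-style sketch of the helpers on two representative syndrome values:

    func TestESRDecoding(t *testing.T) {
    	// An instruction abort from EL0 has EC == _ESR_ELx_EC_IABT_LOW.
    	iabt := uint64(_ESR_ELx_EC_IABT_LOW) << _ESR_ELx_EC_SHIFT
    	if !isInstructionAbort(iabt) || isWriteFault(iabt) {
    		t.Errorf("instruction abort decoded incorrectly")
    	}
    	// A data abort with the WnR bit set is a write fault.
    	dabt := uint64(_ESR_ELx_WNR)
    	if !isWriteFault(dabt) || isInstructionAbort(dabt) {
    		t.Errorf("write data abort decoded incorrectly")
    	}
    }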
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 3c02cef7c..ff8c068c0 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -78,19 +79,6 @@ func (c *vCPU) initArchState() error {
return err
}
- // sctlr_el1
- regGet.id = _KVM_ARM64_REGS_SCTLR_EL1
- if err := c.getOneRegister(&regGet); err != nil {
- return err
- }
-
- dataGet |= (_SCTLR_M | _SCTLR_C | _SCTLR_I)
- data = dataGet
- reg.id = _KVM_ARM64_REGS_SCTLR_EL1
- if err := c.setOneRegister(&reg); err != nil {
- return err
- }
-
// tcr_el1
data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS
reg.id = _KVM_ARM64_REGS_TCR_EL1
@@ -159,12 +147,24 @@ func (c *vCPU) initArchState() error {
return err
}
+ // Use the address of the exception vector table as
+ // the MMIO address base.
+ arm64HypercallMMIOBase = toLocation
+
data = ring0.PsrDefaultSet | ring0.KernelFlagsSet
reg.id = _KVM_ARM64_REGS_PSTATE
if err := c.setOneRegister(&reg); err != nil {
return err
}
+ // Initialize the PCID database.
+ if hasGuestPCID {
+ // Note that NewPCIDs may return a nil table here, in which
+ // case we simply don't use PCID support (see below). In
+ // practice, this should not happen, however.
+ c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs)
+ }
+
c.floatingPointState = arch.NewFloatingPointData()
return nil
}
@@ -243,6 +243,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info)
}
+ // Assign PCIDs.
+ if c.PCIDs != nil {
+ var requireFlushPCID bool // Force a flush?
+ switchOpts.UserASID, requireFlushPCID = c.PCIDs.Assign(switchOpts.PageTables)
+ switchOpts.Flush = switchOpts.Flush || requireFlushPCID
+ }
+
var vector ring0.Vector
ttbr0App := switchOpts.PageTables.TTBR0_EL1(false, 0)
c.SetTtbr0App(uintptr(ttbr0App))
@@ -269,8 +276,8 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
case ring0.PageFault:
return c.fault(int32(syscall.SIGSEGV), info)
- case 0xaa:
- return usermem.NoAccess, nil
+ case ring0.Vector(bounce): // ring0.VirtualizationException
+ return usermem.NoAccess, platform.ErrContextInterrupt
default:
return usermem.NoAccess, platform.ErrContextSignal
}
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index f04be2ab5..9f86f6a7a 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.15
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
@@ -115,7 +115,7 @@ func (a *atomicAddressSpace) get() *addressSpace {
//
//go:nosplit
func (c *vCPU) notify() {
- _, _, errno := syscall.RawSyscall6(
+ _, _, errno := syscall.RawSyscall6( // escapes: no.
syscall.SYS_FUTEX,
uintptr(unsafe.Pointer(&c.state)),
linux.FUTEX_WAKE|linux.FUTEX_PRIVATE_FLAG,
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
index ca902c8c1..4dad877ba 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
@@ -56,5 +56,9 @@ func CheckTestRegs(regs *arch.Registers, full bool) (err error) {
err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need)
}
}
+ // Check tls.
+ if need := ^uint64(11); regs.TPIDR_EL0 != need {
+ err = addRegisterMismatch(err, "tpidr_el0", regs.TPIDR_EL0, need)
+ }
return
}
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
index 0bebee852..6caf7282d 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
@@ -52,6 +52,8 @@ start:
TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8
NO_LOCAL_POINTERS
+ // The GC may touch the fpsimd registers, so we should test them here
+ // (it does so in runtime.deductSweepCredit, for example).
FMOVD $(9.9), F0
MOVD $SYS_GETPID, R8 // getpid
SVC
@@ -102,5 +104,15 @@ isNaN:
TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0
TWIDDLE_REGS()
+ MSR R10, TPIDR_EL0
+ // Trapped in el0_svc.
SVC
RET // never reached
+
+TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0
+ TWIDDLE_REGS()
+ MSR R10, TPIDR_EL0
+ // Trapped in el0_ia.
+ // BR (branch to register) jumps unconditionally to the address in <Rn>.
+ JMP (R6) // <=> br x6, must fault
+ RET // never reached
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
index 171513f3f..4b13eec30 100644
--- a/pkg/sentry/platform/platform.go
+++ b/pkg/sentry/platform/platform.go
@@ -22,9 +22,9 @@ import (
"os"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -207,7 +207,7 @@ type AddressSpace interface {
// Preconditions: addr and fr must be page-aligned. fr.Length() > 0.
// at.Any() == true. At least one reference must be held on all pages in
// fr, and must continue to be held as long as pages are mapped.
- MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, precommit bool) error
+ MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error
// Unmap unmaps the given range.
//
@@ -310,52 +310,6 @@ func (f SegmentationFault) Error() string {
return fmt.Sprintf("segmentation fault at %#x", f.Addr)
}
-// File represents a host file that may be mapped into an AddressSpace.
-type File interface {
- // All pages in a File are reference-counted.
-
- // IncRef increments the reference count on all pages in fr.
- //
- // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
- // 0. At least one reference must be held on all pages in fr. (The File
- // interface does not provide a way to acquire an initial reference;
- // implementors may define mechanisms for doing so.)
- IncRef(fr FileRange)
-
- // DecRef decrements the reference count on all pages in fr.
- //
- // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
- // 0. At least one reference must be held on all pages in fr.
- DecRef(fr FileRange)
-
- // MapInternal returns a mapping of the given file offsets in the invoking
- // process' address space for reading and writing.
- //
- // Note that fr.Start and fr.End need not be page-aligned.
- //
- // Preconditions: fr.Length() > 0. At least one reference must be held on
- // all pages in fr.
- //
- // Postconditions: The returned mapping is valid as long as at least one
- // reference is held on the mapped pages.
- MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
-
- // FD returns the file descriptor represented by the File.
- //
- // The only permitted operation on the returned file descriptor is to map
- // pages from it consistent with the requirements of AddressSpace.MapFile.
- FD() int
-}
-
-// FileRange represents a range of uint64 offsets into a File.
-//
-// type FileRange <generated using go_generics>
-
-// String implements fmt.Stringer.String.
-func (fr FileRange) String() string {
- return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
-}
-
// Requirements is used to specify platform specific requirements.
type Requirements struct {
// RequiresCurrentPIDNS indicates that the sandbox has to be started in the
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index 30402c2df..29fd23cc3 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/hostcpu",
+ "//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sync",
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 2389423b0..c990f3454 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -616,7 +617,7 @@ func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintp
}
// MapFile implements platform.AddressSpace.MapFile.
-func (s *subprocess) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error {
var flags int
if precommit {
flags |= syscall.MAP_POPULATE
diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
index 2ae6b9f9d..0bee995e4 100644
--- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go
+++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.15
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index db6465663..9fd02d628 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -40,6 +40,20 @@
#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT)
+// sctlr_el1: system control register el1.
+#define SCTLR_M 1 << 0
+#define SCTLR_C 1 << 2
+#define SCTLR_I 1 << 12
+#define SCTLR_UCT 1 << 15
+
+#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT)
+
+// cntkctl_el1: counter-timer kernel control register el1.
+#define CNTKCTL_EL0PCTEN 1 << 0
+#define CNTKCTL_EL0VCTEN 1 << 1
+
+#define CNTKCTL_EL1_DEFAULT (CNTKCTL_EL0PCTEN | CNTKCTL_EL0VCTEN)
+
// Saves a register set.
//
// This is a macro because it may need to be executed in contexts where a stack is
@@ -362,9 +376,17 @@ mmio_exit:
MOVD R1, CPU_LAZY_VFP(RSV_REG)
VFP_DISABLE
- // MMIO_EXIT.
- MOVD $0, R9
- MOVD R0, 0xffff000000001000(R9)
+ // Trigger MMIO_EXIT/_KVM_HYPERCALL_VMEXIT.
+ //
+ // To keep it simple, we use the address of the exception vector table
+ // as the MMIO base address, so that an MMIO-EXIT can be triggered by
+ // forcibly writing to this read-only space.
+ // The region is long enough to cover the supported hypercall IDs, so
+ // in host user space the faulting address can be used to work out
+ // which hypercall was issued.
+ MRS VBAR_EL1, R9
+ MOVD R0, 0x0(R9)
+
RET
// HaltAndResume halts execution and points the pointer to the resume function.
@@ -488,6 +510,14 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
// Start is the CPU entrypoint.
TEXT ·Start(SB),NOSPLIT,$0
IRQ_DISABLE
+
+ // Init.
+ MOVD $SCTLR_EL1_DEFAULT, R1
+ MSR R1, SCTLR_EL1
+
+ MOVD $CNTKCTL_EL1_DEFAULT, R1
+ MSR R1, CNTKCTL_EL1
+
MOVD R8, RSV_REG
ORR $0xffff000000000000, RSV_REG, RSV_REG
WORD $0xd518d092 //MSR R18, TPIDR_EL1
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
index 900c0bba7..021693791 100644
--- a/pkg/sentry/platform/ring0/kernel.go
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -31,23 +31,39 @@ type defaultHooks struct{}
// KernelSyscall implements Hooks.KernelSyscall.
//
+// +checkescape:all
+//
//go:nosplit
-func (defaultHooks) KernelSyscall() { Halt() }
+func (defaultHooks) KernelSyscall() {
+ Halt()
+}
// KernelException implements Hooks.KernelException.
//
+// +checkescape:all
+//
//go:nosplit
-func (defaultHooks) KernelException(Vector) { Halt() }
+func (defaultHooks) KernelException(Vector) {
+ Halt()
+}
// kernelSyscall is a trampoline.
//
+// +checkescape:hard,stack
+//
//go:nosplit
-func kernelSyscall(c *CPU) { c.hooks.KernelSyscall() }
+func kernelSyscall(c *CPU) {
+ c.hooks.KernelSyscall()
+}
// kernelException is a trampoline.
//
+// +checkescape:hard,stack
+//
//go:nosplit
-func kernelException(c *CPU, vector Vector) { c.hooks.KernelException(vector) }
+func kernelException(c *CPU, vector Vector) {
+ c.hooks.KernelException(vector)
+}
// Init initializes a new CPU.
//
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
index 0feff8778..d37981dbf 100644
--- a/pkg/sentry/platform/ring0/kernel_amd64.go
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -178,6 +178,8 @@ func IsCanonical(addr uint64) bool {
//
// Precondition: the Rip, Rsp, Fs and Gs registers must be canonical.
//
+// +checkescape:all
+//
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID)
@@ -192,9 +194,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
// Perform the switch.
swapgs() // GS will be swapped on return.
- WriteFS(uintptr(regs.Fs_base)) // Set application FS.
- WriteGS(uintptr(regs.Gs_base)) // Set application GS.
- LoadFloatingPoint(switchOpts.FloatingPointState) // Copy in floating point.
+ WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
+ WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
+ LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point.
jumpToKernel() // Switch to upper half.
writeCR3(uintptr(userCR3)) // Change to user address space.
if switchOpts.FullRestore {
@@ -204,8 +206,8 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
}
writeCR3(uintptr(kernelCR3)) // Return to kernel address space.
jumpToUser() // Return to lower half.
- SaveFloatingPoint(switchOpts.FloatingPointState) // Copy out floating point.
- WriteFS(uintptr(c.registers.Fs_base)) // Restore kernel FS.
+ SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point.
+ WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
return
}
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index ccacaea6b..d483ff03c 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -58,7 +58,15 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Pstate &= ^uint64(UserFlagsClear)
regs.Pstate |= UserFlagsSet
+
+ LoadFloatingPoint(switchOpts.FloatingPointState)
+ SetTLS(regs.TPIDR_EL0)
+
kernelExitToEl0()
+
+ regs.TPIDR_EL0 = GetTLS()
+ SaveFloatingPoint(switchOpts.FloatingPointState)
+
vector = c.vecCode
// Perform the switch.
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index a6345010d..00e52c8af 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -38,6 +38,12 @@ func SaveVRegs(*byte)
// LoadVRegs loads V0-V31 registers.
func LoadVRegs(*byte)
+// LoadFloatingPoint loads floating point state.
+func LoadFloatingPoint(*byte)
+
+// SaveFloatingPoint saves floating point state.
+func SaveFloatingPoint(*byte)
+
// GetTLS returns the value of TPIDR_EL0 register.
func GetTLS() (value uint64)
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
index b63e14b41..86bfbe46f 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.s
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -129,3 +129,89 @@ TEXT ·LoadVRegs(SB),NOSPLIT,$0-8
ISB $15
RET
+
+TEXT ·LoadFloatingPoint(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ MOVD 0(R0), R1
+ MOVD R1, FPSR
+ MOVD 8(R0), R1
+ MOVD R1, NZCV
+
+ FMOVD 16*1(R0), F0
+ FMOVD 16*2(R0), F1
+ FMOVD 16*3(R0), F2
+ FMOVD 16*4(R0), F3
+ FMOVD 16*5(R0), F4
+ FMOVD 16*6(R0), F5
+ FMOVD 16*7(R0), F6
+ FMOVD 16*8(R0), F7
+ FMOVD 16*9(R0), F8
+ FMOVD 16*10(R0), F9
+ FMOVD 16*11(R0), F10
+ FMOVD 16*12(R0), F11
+ FMOVD 16*13(R0), F12
+ FMOVD 16*14(R0), F13
+ FMOVD 16*15(R0), F14
+ FMOVD 16*16(R0), F15
+ FMOVD 16*17(R0), F16
+ FMOVD 16*18(R0), F17
+ FMOVD 16*19(R0), F18
+ FMOVD 16*20(R0), F19
+ FMOVD 16*21(R0), F20
+ FMOVD 16*22(R0), F21
+ FMOVD 16*23(R0), F22
+ FMOVD 16*24(R0), F23
+ FMOVD 16*25(R0), F24
+ FMOVD 16*26(R0), F25
+ FMOVD 16*27(R0), F26
+ FMOVD 16*28(R0), F27
+ FMOVD 16*29(R0), F28
+ FMOVD 16*30(R0), F29
+ FMOVD 16*31(R0), F30
+ FMOVD 16*32(R0), F31
+
+ RET
+
+TEXT ·SaveFloatingPoint(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ MOVD FPSR, R1
+ MOVD R1, 0(R0)
+ MOVD NZCV, R1
+ MOVD R1, 8(R0)
+
+ FMOVD F0, 16*1(R0)
+ FMOVD F1, 16*2(R0)
+ FMOVD F2, 16*3(R0)
+ FMOVD F3, 16*4(R0)
+ FMOVD F4, 16*5(R0)
+ FMOVD F5, 16*6(R0)
+ FMOVD F6, 16*7(R0)
+ FMOVD F7, 16*8(R0)
+ FMOVD F8, 16*9(R0)
+ FMOVD F9, 16*10(R0)
+ FMOVD F10, 16*11(R0)
+ FMOVD F11, 16*12(R0)
+ FMOVD F12, 16*13(R0)
+ FMOVD F13, 16*14(R0)
+ FMOVD F14, 16*15(R0)
+ FMOVD F15, 16*16(R0)
+ FMOVD F16, 16*17(R0)
+ FMOVD F17, 16*18(R0)
+ FMOVD F18, 16*19(R0)
+ FMOVD F19, 16*20(R0)
+ FMOVD F20, 16*21(R0)
+ FMOVD F21, 16*22(R0)
+ FMOVD F22, 16*23(R0)
+ FMOVD F23, 16*24(R0)
+ FMOVD F24, 16*25(R0)
+ FMOVD F25, 16*26(R0)
+ FMOVD F26, 16*27(R0)
+ FMOVD F27, 16*28(R0)
+ FMOVD F28, 16*29(R0)
+ FMOVD F29, 16*30(R0)
+ FMOVD F30, 16*31(R0)
+ FMOVD F31, 16*32(R0)
+
+ RET
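
The save/restore pair above fixes a simple layout for the buffer passed from Go: two 8-byte status words followed by the 32 vector registers at 16-byte strides. Sketched as Go constants (the names are illustrative):

    // Buffer layout implied by LoadFloatingPoint/SaveFloatingPoint above.
    const (
    	fpsimdOffsetFPSR = 0      // FPSR at offset 0.
    	fpsimdOffsetNZCV = 8      // NZCV at offset 8.
    	fpsimdOffsetV0   = 16 * 1 // F0..F31 at offsets 16*1 .. 16*32.
    )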
diff --git a/pkg/sentry/platform/ring0/pagetables/allocator.go b/pkg/sentry/platform/ring0/pagetables/allocator.go
index 23fd5c352..8d75b7599 100644
--- a/pkg/sentry/platform/ring0/pagetables/allocator.go
+++ b/pkg/sentry/platform/ring0/pagetables/allocator.go
@@ -53,9 +53,14 @@ type RuntimeAllocator struct {
// NewRuntimeAllocator returns an allocator that uses runtime allocation.
func NewRuntimeAllocator() *RuntimeAllocator {
- return &RuntimeAllocator{
- used: make(map[*PTEs]struct{}),
- }
+ r := new(RuntimeAllocator)
+ r.Init()
+ return r
+}
+
+// Init initializes a RuntimeAllocator.
+func (r *RuntimeAllocator) Init() {
+ r.used = make(map[*PTEs]struct{})
}
// Recycle returns freed pages to the pool.
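
Exposing Init separately from NewRuntimeAllocator lets a containing struct initialize an embedded allocator in place rather than allocating it on the heap. A sketch of the pattern (the embedding struct is illustrative):

    // An embedding struct can now initialize the allocator by value.
    type archState struct {
    	allocator RuntimeAllocator // Embedded by value (illustrative).
    }

    func (s *archState) init() {
    	s.allocator.Init() // No separate allocation required.
    }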
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
index 87e88e97d..7f18ac296 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -86,6 +86,8 @@ func (*mapVisitor) requiresSplit() bool { return true }
//
// Precondition: addr & length must be page-aligned, their sum must not overflow.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool {
if !opts.AccessType.Any() {
@@ -128,6 +130,8 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) {
//
// Precondition: addr & length must be page-aligned.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool {
w := unmapWalker{
@@ -162,6 +166,8 @@ func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) {
//
// Precondition: addr & length must be page-aligned.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool {
w := emptyWalker{
@@ -197,6 +203,8 @@ func (*lookupVisitor) requiresSplit() bool { return false }
// Lookup returns the physical address for the given virtual address.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) {
mask := uintptr(usermem.PageSize - 1)
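
The +checkescape annotations added throughout this file are consumed by gVisor's static escape-analysis tooling; paired with go:nosplit, they assert that these hot-path walkers neither allocate on the heap nor grow the stack while running in guest mode (the exact category semantics are the tool's; this summary is an assumption). A new walker would follow the same pattern, sketched:

    // Illustrative: a new lookup helper carrying the same annotations.
    //
    // +checkescape:hard,stack
    //
    //go:nosplit
    func (p *PageTables) lookupPhysical(addr usermem.Addr) uintptr {
    	physical, _ := p.Lookup(addr)
    	return physical
    }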
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index c40c6d673..c0fd3425b 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -20,5 +20,6 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index e82d6cd1e..e76e498de 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -26,6 +26,7 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/hostfd",
"//pkg/sentry/inet",
@@ -39,6 +40,8 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index c11e82c10..532a1ea5d 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -36,6 +36,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const (
@@ -319,12 +321,12 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
- // Whitelist options and constrain option length.
+ // Only allow known and safe options.
optlen := getSockOptLen(t, level, name)
switch level {
case linux.SOL_IP:
@@ -364,12 +366,13 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
if err != nil {
return nil, syserr.FromError(err)
}
- return opt, nil
+ optP := primitive.ByteSlice(opt)
+ return &optP, nil
}
// SetSockOpt implements socket.Socket.SetSockOpt.
func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
- // Whitelist options and constrain option length.
+ // Only allow known and safe options.
optlen := setSockOptLen(t, level, name)
switch level {
case linux.SOL_IP:
@@ -415,7 +418,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
// RecvMsg implements socket.Socket.RecvMsg.
func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
- // Whitelist flags.
+ // Only allow known and safe flags.
//
// FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
// messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the
@@ -537,7 +540,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// SendMsg implements socket.Socket.SendMsg.
func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
- // Whitelist flags.
+ // Only allow known and safe flags.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
return 0, syserr.ErrInvalidArgument
}
@@ -708,6 +711,6 @@ func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int
func init() {
for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
socket.RegisterProvider(family, &socketProvider{family})
- socket.RegisterProviderVFS2(family, &socketProviderVFS2{})
+ socket.RegisterProviderVFS2(family, &socketProviderVFS2{family})
}
}
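
In the GetSockOpt change above, the raw option bytes are wrapped so the return value satisfies the new marshal.Marshallable interface; primitive.ByteSlice is a named []byte whose pointer carries the marshalling methods. A minimal sketch of the adaptation:

    // Wrapping raw bytes as a Marshallable value (illustrative).
    opt := []byte{0, 0, 0, 1}          // e.g. an int option read from the host.
    optP := primitive.ByteSlice(opt)   // Type conversion; no copy.
    var _ marshal.Marshallable = &optP // *ByteSlice implements Marshallable.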
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 677743113..8a1d52ebf 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -35,6 +36,7 @@ import (
type socketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
// We store metadata for hostinet sockets internally. Technically, we should
// access metadata (e.g. through stat, chmod) on the host for correctness,
@@ -59,6 +61,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in
fd: fd,
},
}
+ s.LockFD.Init(&vfs.FileLocks{})
if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
return nil, syserr.FromError(err)
}
@@ -68,6 +71,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in
DenyPWrite: true,
UseDentryMetadata: true,
}); err != nil {
+ fdnotifier.RemoveFD(int32(s.fd))
return nil, syserr.FromError(err)
}
return vfsfd, nil
@@ -93,7 +97,12 @@ func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal
return ioctl(ctx, s.fd, uio, args)
}
-// PRead implements vfs.FileDescriptionImpl.
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.ENODEV
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
return 0, syserror.ESPIPE
}
@@ -131,6 +140,16 @@ func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
return int64(n), err
}
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
+
type socketProviderVFS2 struct {
family int
}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 789bb94c8..a9f0604ae 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -41,19 +41,6 @@ const errorTargetName = "ERROR"
// change the destination port/destination IP for packets.
const redirectTargetName = "REDIRECT"
-// Metadata is used to verify that we are correctly serializing and
-// deserializing iptables into structs consumable by the iptables tool. We save
-// a metadata struct when the tables are written, and when they are read out we
-// verify that certain fields are the same.
-//
-// metadata is used by this serialization/deserializing code, not netstack.
-type metadata struct {
- HookEntry [linux.NF_INET_NUMHOOKS]uint32
- Underflow [linux.NF_INET_NUMHOOKS]uint32
- NumEntries uint32
- Size uint32
-}
-
// enableLogging controls whether to log the (de)serialization of netfilter
// structs between userspace and netstack. These logs are useful when
// developing iptables, but can pollute sentry logs otherwise.
@@ -64,6 +51,8 @@ const enableLogging = false
var emptyFilter = stack.IPHeaderFilter{
Dst: "\x00\x00\x00\x00",
DstMask: "\x00\x00\x00\x00",
+ Src: "\x00\x00\x00\x00",
+ SrcMask: "\x00\x00\x00\x00",
}
// nflog logs messages related to the writing and reading of iptables.
@@ -77,33 +66,17 @@ func nflog(format string, args ...interface{}) {
func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
var info linux.IPTGetinfo
- if _, err := t.CopyIn(outPtr, &info); err != nil {
+ if _, err := info.CopyIn(t, outPtr); err != nil {
return linux.IPTGetinfo{}, syserr.FromError(err)
}
- // Find the appropriate table.
- table, err := findTable(stack, info.Name)
+ _, info, err := convertNetstackToBinary(stack, info.Name)
if err != nil {
- nflog("%v", err)
+ nflog("couldn't convert iptables: %v", err)
return linux.IPTGetinfo{}, syserr.ErrInvalidArgument
}
- // Get the hooks that apply to this table.
- info.ValidHooks = table.ValidHooks()
-
- // Grab the metadata struct, which is used to store information (e.g.
- // the number of entries) that applies to the user's encoding of
- // iptables, but not netstack's.
- metadata := table.Metadata().(metadata)
-
- // Set values from metadata.
- info.HookEntry = metadata.HookEntry
- info.Underflow = metadata.Underflow
- info.NumEntries = metadata.NumEntries
- info.Size = metadata.Size
-
nflog("returning info: %+v", info)
-
return info, nil
}
@@ -111,28 +84,18 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT
func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
// Read in the struct and table name.
var userEntries linux.IPTGetEntries
- if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
+ if _, err := userEntries.CopyIn(t, outPtr); err != nil {
nflog("couldn't copy in entries %q", userEntries.Name)
return linux.KernelIPTGetEntries{}, syserr.FromError(err)
}
- // Find the appropriate table.
- table, err := findTable(stack, userEntries.Name)
- if err != nil {
- nflog("%v", err)
- return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
- }
-
// Convert netstack's iptables rules to something that the iptables
// tool can understand.
- entries, meta, err := convertNetstackToBinary(userEntries.Name.String(), table)
+ entries, _, err := convertNetstackToBinary(stack, userEntries.Name)
if err != nil {
nflog("couldn't read entries: %v", err)
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
- if meta != table.Metadata().(metadata) {
- panic(fmt.Sprintf("Table %q metadata changed between writing and reading. Was saved as %+v, but is now %+v", userEntries.Name.String(), table.Metadata().(metadata), meta))
- }
if binary.Size(entries) > uintptr(outLen) {
nflog("insufficient GetEntries output size: %d", uintptr(outLen))
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
@@ -141,48 +104,26 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
return entries, nil
}
-func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) {
- ipt := stk.IPTables()
- table, ok := ipt.Tables[tablename.String()]
- if !ok {
- return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename)
- }
- return table, nil
-}
-
-// FillDefaultIPTables sets stack's IPTables to the default tables and
-// populates them with metadata.
-func FillDefaultIPTables(stk *stack.Stack) {
- ipt := stack.DefaultTables()
-
- // In order to fill in the metadata, we have to translate ipt from its
- // netstack format to Linux's giant-binary-blob format.
- for name, table := range ipt.Tables {
- _, metadata, err := convertNetstackToBinary(name, table)
- if err != nil {
- panic(fmt.Errorf("Unable to set default IP tables: %v", err))
- }
- table.SetMetadata(metadata)
- ipt.Tables[name] = table
- }
-
- stk.SetIPTables(ipt)
-}
-
// convertNetstackToBinary converts the iptables as stored in netstack to the
// format expected by the iptables tool. Linux stores each table as a binary
// blob that can only be traversed by parsing a bit, reading some offsets,
// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) {
- // Return values.
+func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) {
+ table, ok := stack.IPTables().GetTable(tablename.String())
+ if !ok {
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
+ }
+
var entries linux.KernelIPTGetEntries
- var meta metadata
+ var info linux.IPTGetinfo
+ info.ValidHooks = table.ValidHooks()
// The table name has to fit in the struct.
if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
- return linux.KernelIPTGetEntries{}, metadata{}, fmt.Errorf("table name %q too long.", tablename)
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
}
- copy(entries.Name[:], tablename)
+ copy(info.Name[:], tablename[:])
+ copy(entries.Name[:], tablename[:])
for ruleIdx, rule := range table.Rules {
nflog("convert to binary: current offset: %d", entries.Size)
@@ -191,20 +132,20 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI
for hook, hookRuleIdx := range table.BuiltinChains {
if hookRuleIdx == ruleIdx {
nflog("convert to binary: found hook %d at offset %d", hook, entries.Size)
- meta.HookEntry[hook] = entries.Size
+ info.HookEntry[hook] = entries.Size
}
}
// Is this a chain underflow point?
for underflow, underflowRuleIdx := range table.Underflows {
if underflowRuleIdx == ruleIdx {
nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size)
- meta.Underflow[underflow] = entries.Size
+ info.Underflow[underflow] = entries.Size
}
}
// Each rule corresponds to an entry.
entry := linux.KernelIPTEntry{
- IPTEntry: linux.IPTEntry{
+ Entry: linux.IPTEntry{
IP: linux.IPTIP{
Protocol: uint16(rule.Filter.Protocol),
},
@@ -212,15 +153,20 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI
TargetOffset: linux.SizeOfIPTEntry,
},
}
- copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst)
- copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask)
- copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface)
- copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ copy(entry.Entry.IP.Dst[:], rule.Filter.Dst)
+ copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask)
+ copy(entry.Entry.IP.Src[:], rule.Filter.Src)
+ copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask)
+ copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
if rule.Filter.DstInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP
+ }
+ if rule.Filter.SrcInvert {
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP
}
if rule.Filter.OutputInterfaceInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
}
for _, matcher := range rule.Matchers {
@@ -232,8 +178,8 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI
panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
}
entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
- entry.TargetOffset += uint16(len(serialized))
+ entry.Entry.NextOffset += uint16(len(serialized))
+ entry.Entry.TargetOffset += uint16(len(serialized))
}
// Serialize and append the target.
@@ -242,18 +188,18 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI
panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
}
entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
+ entry.Entry.NextOffset += uint16(len(serialized))
nflog("convert to binary: adding entry: %+v", entry)
- entries.Size += uint32(entry.NextOffset)
+ entries.Size += uint32(entry.Entry.NextOffset)
entries.Entrytable = append(entries.Entrytable, entry)
- meta.NumEntries++
+ info.NumEntries++
}
- nflog("convert to binary: finished with an marshalled size of %d", meta.Size)
- meta.Size = entries.Size
- return entries, meta, nil
+ nflog("convert to binary: finished with an marshalled size of %d", info.Size)
+ info.Size = entries.Size
+ return entries, info, nil
}
func marshalTarget(target stack.Target) []byte {
@@ -396,10 +342,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
switch replace.Name.String() {
- case stack.TablenameFilter:
+ case stack.FilterTable:
table = stack.EmptyFilterTable()
- case stack.TablenameNat:
- table = stack.EmptyNatTable()
+ case stack.NATTable:
+ table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
return syserr.ErrInvalidArgument
@@ -485,6 +431,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
for hook, _ := range replace.HookEntry {
if table.ValidHooks()&(1<<hook) != 0 {
hk := hookFromLinux(hook)
+ table.BuiltinChains[hk] = stack.HookUnset
+ table.Underflows[hk] = stack.HookUnset
for offset, ruleIdx := range offsets {
if offset == replace.HookEntry[hook] {
table.BuiltinChains[hk] = ruleIdx
@@ -510,8 +458,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Add the user chains.
for ruleIdx, rule := range table.Rules {
- target, ok := rule.Target.(stack.UserChainTarget)
- if !ok {
+ if _, ok := rule.Target.(stack.UserChainTarget); !ok {
continue
}
@@ -527,7 +474,6 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
nflog("user chain's first node must have no matchers")
return syserr.ErrInvalidArgument
}
- table.UserChains[target.Name] = ruleIdx + 1
}
// Set each jump to point to the appropriate rule. Right now they hold byte
@@ -553,7 +499,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
// make sure all other chains point to ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook == stack.Forward || hook == stack.Postrouting {
+ if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting {
+ if ruleIdx == stack.HookUnset {
+ continue
+ }
if !isUnconditionalAccept(table.Rules[ruleIdx]) {
nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
@@ -566,17 +515,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- ipt := stk.IPTables()
- table.SetMetadata(metadata{
- HookEntry: replace.HookEntry,
- Underflow: replace.Underflow,
- NumEntries: replace.NumEntries,
- Size: replace.Size,
- })
- ipt.Tables[replace.Name.String()] = table
- stk.SetIPTables(ipt)
-
- return nil
+ return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table))
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
@@ -737,6 +676,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize {
return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
}
+ if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
+ }
n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
if n == -1 {
@@ -755,6 +697,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
Dst: tcpip.Address(iptip.Dst[:]),
DstMask: tcpip.Address(iptip.DstMask[:]),
DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
+ Src: tcpip.Address(iptip.Src[:]),
+ SrcMask: tcpip.Address(iptip.SrcMask[:]),
+ SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0,
OutputInterface: ifname,
OutputInterfaceMask: ifnameMask,
OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0,
@@ -765,15 +710,13 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool {
// The following features are supported:
// - Protocol
// - Dst and DstMask
+ // - Src and SrcMask
// - The inverse destination IP check flag
// - OutputInterface, OutputInterfaceMask and its inverse.
- var emptyInetAddr = linux.InetAddr{}
var emptyInterface = [linux.IFNAMSIZ]byte{}
// Disable any supported inverse flags.
- inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_VIA_OUT)
- return iptip.Src != emptyInetAddr ||
- iptip.SrcMask != emptyInetAddr ||
- iptip.InputInterface != emptyInterface ||
+ inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT)
+ return iptip.InputInterface != emptyInterface ||
iptip.InputInterfaceMask != emptyInterface ||
iptip.Flags != 0 ||
iptip.InverseFlags&^inverseMask != 0
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
index 3863293c7..1b4e0ad79 100644
--- a/pkg/sentry/socket/netfilter/owner_matcher.go
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -111,7 +111,7 @@ func (*OwnerMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
// Support only for OUTPUT chain.
// TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
if hook != stack.Output {
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 84abe8d29..b91ba3ab3 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -30,6 +30,6 @@ type JumpTarget struct {
}
// Action implements stack.Target.Action.
-func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrackTable, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
return stack.RuleJump, jt.RuleNum
}
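
For context, RuleJump redirects rule evaluation rather than terminating it. A hypothetical sketch of how an evaluation loop could consume the (verdict, RuleNum) pair returned by Action (the surrounding names are assumptions, not part of this change):

    verdict, jumpTo := rule.Target.Action(pkt, &connTrack, hook, gso, route, address)
    if verdict == stack.RuleJump {
    	ruleIdx = jumpTo // Continue evaluating at the jump target.
    }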
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 57a1e1c12..4f98ee2d5 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
if netHeader.TransportProtocol() != header.TCPProtocolNumber {
@@ -111,36 +111,10 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
return false, false
}
- // Now we need the transport header. However, this may not have been set
- // yet.
- // TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the stack.Check codepath as matchers are
- // added.
- var tcpHeader header.TCP
- if pkt.TransportHeader != nil {
- tcpHeader = header.TCP(pkt.TransportHeader)
- } else {
- var length int
- if hook == stack.Prerouting {
- // The network header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- // There's no valid TCP header here, so we hotdrop the
- // packet.
- return false, true
- }
- h := header.IPv4(hdr)
- pkt.NetworkHeader = hdr
- length = int(h.HeaderLength())
- }
- // The TCP header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(length + header.TCPMinimumSize)
- if !ok {
- // There's no valid TCP header here, so we hotdrop the
- // packet.
- return false, true
- }
- tcpHeader = header.TCP(hdr[length:])
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if len(tcpHeader) < header.TCPMinimumSize {
+ // There's no valid TCP header here, so we drop the packet immediately.
+ return false, true
}
// Check whether the source and destination ports are within the
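
Note: the matcher now assumes the stack has already attempted to parse the transport header, so the old inline PullUp logic reduces to a length check; the udp_matcher.go change below applies the same pattern with header.UDP and header.UDPMinimumSize. A sketch of the guard in isolation:

    // If parsing failed upstream, pkt.TransportHeader is short or empty
    // and the packet is hot-dropped.
    tcpHeader := header.TCP(pkt.TransportHeader)
    if len(tcpHeader) < header.TCPMinimumSize {
        return false, true // Hot drop.
    }
    // Safe to read the ports now.
    srcPort, dstPort := tcpHeader.SourcePort(), tcpHeader.DestinationPort()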
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index cfa9e621d..3f20fc891 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
@@ -110,36 +110,10 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
return false, false
}
- // Now we need the transport header. However, this may not have been set
- // yet.
- // TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the stack.Check codepath as matchers are
- // added.
- var udpHeader header.UDP
- if pkt.TransportHeader != nil {
- udpHeader = header.UDP(pkt.TransportHeader)
- } else {
- var length int
- if hook == stack.Prerouting {
- // The network header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- // There's no valid UDP header here, so we hotdrop the
- // packet.
- return false, true
- }
- h := header.IPv4(hdr)
- pkt.NetworkHeader = hdr
- length = int(h.HeaderLength())
- }
- // The UDP header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(length + header.UDPMinimumSize)
- if !ok {
- // There's no valid UDP header here, so we hotdrop the
- // packet.
- return false, true
- }
- udpHeader = header.UDP(hdr[length:])
+ udpHeader := header.UDP(pkt.TransportHeader)
+ if len(udpHeader) < header.UDPMinimumSize {
+ // There's no valid UDP header here, so we drop the packet immediately.
+ return false, true
}
// Check whether the source and destination ports are within the
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 7212d8644..0546801bf 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -20,6 +20,7 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
@@ -35,6 +36,8 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 81f34c5a2..98ca7add0 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -38,6 +38,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const sizeOfInt32 int = 4
@@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
}
s.mu.Lock()
defer s.mu.Unlock()
- return int32(s.sendBufferSize), nil
+ sendBufferSizeP := primitive.Int32(s.sendBufferSize)
+ return &sendBufferSizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
// We don't have a limit on the receive size.
- return int32(math.MaxInt32), nil
+ recvBufferSizeP := primitive.Int32(math.MaxInt32)
+ return &recvBufferSizeP, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var passcred int32
+ var passcred primitive.Int32
if s.Passcred() {
passcred = 1
}
- return passcred, nil
+ return &passcred, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
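
Note: GetSockOpt implementations now return marshal.Marshallable instead of interface{}, so the syscall layer can copy results to user memory without binary reflection. Scalars are wrapped in go_marshal primitive types, whose pointers implement Marshallable. A minimal sketch with a hypothetical helper:

    // getBoolOpt is hypothetical; it shows the wrap-and-return-pointer
    // convention used throughout this change.
    func getBoolOpt(v bool) (marshal.Marshallable, *syserr.Error) {
        val := primitive.Int32(0)
        if v {
            val = 1
        }
        // The syscall layer copies this out via val.CopyOut(task, outPtr).
        return &val, nil
    }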
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
index b854bf990..dbcd8b49a 100644
--- a/pkg/sentry/socket/netlink/socket_vfs2.go
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
@@ -40,6 +41,7 @@ type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
socketOpsCommon
}
@@ -66,7 +68,7 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV
return nil, err
}
- return &SocketVFS2{
+ fd := &SocketVFS2{
socketOpsCommon: socketOpsCommon{
ports: t.Kernel().NetlinkPorts(),
protocol: protocol,
@@ -75,7 +77,9 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV
connection: connection,
sendBufferSize: defaultSendBufferSize,
},
- }, nil
+ }
+ fd.LockFD.Init(&vfs.FileLocks{})
+ return fd, nil
}
// Readiness implements waiter.Waitable.Readiness.
@@ -136,3 +140,13 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
return int64(n), err.ToError()
}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
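
Note: this is the per-FD POSIX lock wiring that recurs for the netstack and unix sockets later in this change: embed vfs.LockFD, initialize it with a vfs.FileLocks at construction, and forward LockPOSIX/UnlockPOSIX to the shared implementation, as above. A condensed sketch with a hypothetical FD type:

    type socketFD struct {
        vfs.FileDescriptionDefaultImpl
        vfs.LockFD // Provides Locks().
    }

    func newSocketFD() *socketFD {
        fd := &socketFD{}
        fd.LockFD.Init(&vfs.FileLocks{}) // Per-file lock state.
        return fd
    }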
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 333e0042e..1fb777a6c 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -28,6 +28,7 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
@@ -50,5 +51,8 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 60df51dae..31a168f7e 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -26,6 +26,7 @@ package netstack
import (
"bytes"
+ "fmt"
"io"
"math"
"reflect"
@@ -33,6 +34,7 @@ import (
"syscall"
"time"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/binary"
@@ -60,6 +62,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
func mustCreateMetric(name, description string) *tcpip.StatCounter {
@@ -190,6 +194,8 @@ var Metrics = tcpip.Stats{
MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."),
PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."),
PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."),
+ ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."),
+ InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."),
},
}
@@ -294,8 +300,9 @@ type socketOpsCommon struct {
readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
- sender tcpip.FullAddress
+ readCM tcpip.ControlMessages
+ sender tcpip.FullAddress
+ linkPacketInfo tcpip.LinkPacketInfo
// sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
// of returned messages can be returned via control messages. When
@@ -416,7 +423,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- // TODO(b/129292371): Return protocol too.
+ // TODO(gvisor.dev/issue/173): Return protocol too.
return tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
@@ -444,8 +451,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
s.readView = nil
s.sender = tcpip.FullAddress{}
+ s.linkPacketInfo = tcpip.LinkPacketInfo{}
- v, cms, err := s.Endpoint.Read(&s.sender)
+ var v buffer.View
+ var cms tcpip.ControlMessages
+ var err *tcpip.Error
+
+ switch e := s.Endpoint.(type) {
+ // The ordering of these interfaces matters. The most specific
+ // interfaces must be specified before the more generic Endpoint
+ // interface.
+ case tcpip.PacketEndpoint:
+ v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo)
+ case tcpip.Endpoint:
+ v, cms, err = e.Read(&s.sender)
+ }
if err != nil {
atomic.StoreUint32(&s.readViewHasData, 0)
return syserr.TranslateNetstackError(err)
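
Note: the case order in the type switch above is load-bearing: Go selects the first matching case, and a packet endpoint implements both interfaces, so listing Endpoint first would shadow ReadPacket entirely. A self-contained illustration of the rule:

    type Reader interface{ Read() }

    type PacketReader interface {
        Reader
        ReadPacket()
    }

    func dispatch(r Reader) {
        switch v := r.(type) {
        case PacketReader: // More specific; must come first.
            v.ReadPacket()
        case Reader: // Matches everything that reaches the switch.
            v.Read()
        }
    }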
@@ -719,6 +739,14 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
defer s.EventUnregister(&e)
if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting {
+ if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM {
+ // Unlike UDP, TCP returns EADDRNOTAVAIL when it can't
+ // find an available local ephemeral port.
+ if err == tcpip.ErrNoPortAvailable {
+ return syserr.ErrAddressNotAvailable
+ }
+ }
+
return syserr.TranslateNetstackError(err)
}
@@ -884,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -894,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -930,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -945,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
@@ -955,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -988,7 +1016,7 @@ func boolToInt32(v bool) int32 {
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -999,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// Get the last error and convert it.
err := ep.GetSockOpt(tcpip.ErrorOption{})
if err == nil {
- return int32(0), nil
+ optP := primitive.Int32(0)
+ return &optP, nil
}
- return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil
+
+ optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number())
+ return &optP, nil
case linux.SO_PEERCRED:
if family != linux.AF_UNIX || outLen < syscall.SizeofUcred {
@@ -1009,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
tcred := t.Credentials()
- return syscall.Ucred{
- Pid: int32(t.ThreadGroup().ID()),
- Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
- Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
- }, nil
+ creds := linux.ControlMessageCredentials{
+ PID: int32(t.ThreadGroup().ID()),
+ UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
+ GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
+ }
+ return &creds, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
@@ -1024,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1040,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
@@ -1056,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_REUSEADDR:
if outLen < sizeOfInt32 {
@@ -1067,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
@@ -1078,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1086,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
if v == 0 {
- return []byte{}, nil
+ var b primitive.ByteSlice
+ return &b, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
@@ -1101,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// interface was removed.
return nil, syserr.ErrUnknownDevice
}
- return append([]byte(nic.Name), 0), nil
+
+ name := primitive.ByteSlice(append([]byte(nic.Name), 0))
+ return &name, nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1112,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
@@ -1123,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
return nil, syserr.ErrInvalidArgument
}
- return linux.Linger{}, nil
+
+ linger := linux.Linger{}
+ return &linger, nil
case linux.SO_SNDTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1137,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.SendTimeout()), nil
+ sendTimeout := linux.NsecToTimeval(s.SendTimeout())
+ return &sendTimeout, nil
case linux.SO_RCVTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1145,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.RecvTimeout()), nil
+ recvTimeout := linux.NsecToTimeval(s.RecvTimeout())
+ return &recvTimeout, nil
case linux.SO_OOBINLINE:
if outLen < sizeOfInt32 {
@@ -1157,7 +1207,20 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
+
+ case linux.SO_NO_CHECK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.NoChecksumOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
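
Note: SO_NO_CHECK (get above, set further down in this file) maps to tcpip.NoChecksumOption and disables UDP transmit checksums. A guest-side usage sketch via golang.org/x/sys/unix:

    fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
    if err != nil {
        return err
    }
    defer unix.Close(fd)
    // Send datagrams with a zero UDP checksum.
    if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_NO_CHECK, 1); err != nil {
        return err
    }
    v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_NO_CHECK) // v == 1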
@@ -1166,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
@@ -1177,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(!v), nil
+
+ vP := primitive.Int32(boolToInt32(!v))
+ return &vP, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
@@ -1188,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
@@ -1199,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
@@ -1210,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_KEEPIDLE:
if outLen < sizeOfInt32 {
@@ -1222,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveIdle, nil
case linux.TCP_KEEPINTVL:
if outLen < sizeOfInt32 {
@@ -1234,8 +1303,20 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
+ keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveInterval, nil
- return int32(time.Duration(v) / time.Second), nil
+ case linux.TCP_KEEPCNT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.KeepaliveCountOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_USER_TIMEOUT:
if outLen < sizeOfInt32 {
@@ -1246,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Millisecond), nil
+ tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond)
+ return &tcpUserTimeout, nil
case linux.TCP_INFO:
var v tcpip.TCPInfoOption
@@ -1260,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
info := linux.TCPInfo{}
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &info)
- if len(ib) > outLen {
- ib = ib[:outLen]
+ buf := t.CopyScratchBuffer(info.SizeBytes())
+ info.MarshalUnsafe(buf)
+ if len(buf) > outLen {
+ buf = buf[:outLen]
}
-
- return ib, nil
+ bufP := primitive.ByteSlice(buf)
+ return &bufP, nil
case linux.TCP_CC_INFO,
linux.TCP_NOTSENT_LOWAT,
@@ -1295,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
b := make([]byte, toCopy)
copy(b, v)
- return b, nil
+
+ bP := primitive.ByteSlice(b)
+ return &bP, nil
case linux.TCP_LINGER2:
if outLen < sizeOfInt32 {
@@ -1307,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ lingerTimeout := primitive.Int32(time.Duration(v) / time.Second)
+ return &lingerTimeout, nil
case linux.TCP_DEFER_ACCEPT:
if outLen < sizeOfInt32 {
@@ -1319,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second)
+ return &tcpDeferAccept, nil
case linux.TCP_SYNCNT:
if outLen < sizeOfInt32 {
@@ -1330,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_WINDOW_CLAMP:
if outLen < sizeOfInt32 {
@@ -1342,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1351,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
-func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
@@ -1362,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1370,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
case linux.IPV6_TCLASS:
// Length handling for parity with Linux.
if outLen == 0 {
- return make([]byte, 0), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- uintv := uint32(v)
+ uintv := primitive.Uint32(v)
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &uintv)
+ ib := t.CopyScratchBuffer(uintv.SizeBytes())
+ uintv.MarshalUnsafe(ib)
// Handle cases where outLen is less than sizeOfInt32.
if len(ib) > outLen {
ib = ib[:outLen]
}
- return ib, nil
+ ibP := primitive.ByteSlice(ib)
+ return &ibP, nil
case linux.IPV6_RECVTCLASS:
if outLen < sizeOfInt32 {
@@ -1395,7 +1486,13 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
+
+ case linux.SO_ORIGINAL_DST:
+ // TODO(gvisor.dev/issue/170): ip6tables.
+ return nil, syserr.ErrInvalidArgument
default:
emitUnimplementedEventIPv6(t, name)
@@ -1404,7 +1501,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1417,11 +1514,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
// Fill in the default value, if needed.
- if v == 0 {
- v = DefaultTTL
+ vP := primitive.Int32(v)
+ if vP == 0 {
+ vP = DefaultTTL
}
- return int32(v), nil
+ return &vP, nil
case linux.IP_MULTICAST_TTL:
if outLen < sizeOfInt32 {
@@ -1433,7 +1531,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_MULTICAST_IF:
if outLen < len(linux.InetAddr{}) {
@@ -1447,7 +1546,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
- return a.(*linux.SockAddrInet).Addr, nil
+ return &a.(*linux.SockAddrInet).Addr, nil
case linux.IP_MULTICAST_LOOP:
if outLen < sizeOfInt32 {
@@ -1458,21 +1557,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
if outLen == 0 {
- return []byte(nil), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
if outLen < sizeOfInt32 {
- return uint8(v), nil
+ vP := primitive.Uint8(v)
+ return &vP, nil
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_RECVTOS:
if outLen < sizeOfInt32 {
@@ -1483,7 +1587,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
@@ -1494,7 +1600,22 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
+
+ case linux.SO_ORIGINAL_DST:
+ if outLen < int(binary.Size(linux.SockAddrInet{})) {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.OriginalDestinationOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
+ return a.(*linux.SockAddrInet), nil
default:
emitUnimplementedEventIP(t, name)
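
Note: SO_ORIGINAL_DST is how transparent proxies recover the pre-REDIRECT destination of a connection. A common guest-side idiom via golang.org/x/sys/unix (assumes encoding/binary and net are imported); the IPv6Mreq struct is borrowed only because it is large enough to hold a sockaddr_in, a widely used trick rather than an official API:

    mreq, err := unix.GetsockoptIPv6Mreq(fd, unix.IPPROTO_IP, unix.SO_ORIGINAL_DST)
    if err != nil {
        return err
    }
    // sockaddr_in layout: family(2) | port(2, big-endian) | addr(4) | pad.
    port := binary.BigEndian.Uint16(mreq.Multiaddr[2:4])
    ip := net.IPv4(mreq.Multiaddr[4], mreq.Multiaddr[5], mreq.Multiaddr[6], mreq.Multiaddr[7])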
@@ -1698,6 +1819,14 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v)))
+ case linux.SO_NO_CHECK:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0))
+
case linux.SO_LINGER:
if len(optVal) < linux.SizeOfLinger {
return syserr.ErrInvalidArgument
@@ -1712,6 +1841,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return nil
+ case linux.SO_DETACH_FILTER:
+ // optval is ignored.
+ var v tcpip.SocketDetachFilterOption
+ return syserr.TranslateNetstackError(ep.SetSockOpt(v))
+
default:
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
@@ -1777,6 +1911,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+ case linux.TCP_KEEPCNT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPCNT {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.KeepaliveCountOption, int(v)))
+
case linux.TCP_USER_TIMEOUT:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
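
Note: TCP_KEEPCNT caps the number of unacknowledged keepalive probes sent before the connection is dropped; values outside [1, MAX_TCP_KEEPCNT] (127 on Linux) are rejected with EINVAL. A guest-side usage sketch:

    // Drop the connection after 5 failed keepalive probes.
    if err := unix.SetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_KEEPCNT, 5); err != nil {
        return err
    }
    n, err := unix.GetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_KEEPCNT) // n == 5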
@@ -2060,13 +2205,22 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
}
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+ case linux.IP_HDRINCL:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
linux.IP_CHECKSUM,
linux.IP_DROP_SOURCE_MEMBERSHIP,
linux.IP_FREEBIND,
- linux.IP_HDRINCL,
linux.IP_IPSEC_POLICY,
linux.IP_MINTTL,
linux.IP_MSFILTER,
@@ -2106,30 +2260,20 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) {
switch name {
case linux.TCP_CONGESTION,
linux.TCP_CORK,
- linux.TCP_DEFER_ACCEPT,
linux.TCP_FASTOPEN,
linux.TCP_FASTOPEN_CONNECT,
linux.TCP_FASTOPEN_KEY,
linux.TCP_FASTOPEN_NO_COOKIE,
- linux.TCP_KEEPCNT,
- linux.TCP_KEEPIDLE,
- linux.TCP_KEEPINTVL,
- linux.TCP_LINGER2,
- linux.TCP_MAXSEG,
linux.TCP_QUEUE_SEQ,
- linux.TCP_QUICKACK,
linux.TCP_REPAIR,
linux.TCP_REPAIR_QUEUE,
linux.TCP_REPAIR_WINDOW,
linux.TCP_SAVED_SYN,
linux.TCP_SAVE_SYN,
- linux.TCP_SYNCNT,
linux.TCP_THIN_DUPACK,
linux.TCP_THIN_LINEAR_TIMEOUTS,
linux.TCP_TIMESTAMP,
- linux.TCP_ULP,
- linux.TCP_USER_TIMEOUT,
- linux.TCP_WINDOW_CLAMP:
+ linux.TCP_ULP:
t.Kernel().EmitUnimplementedEvent(t)
}
@@ -2291,7 +2435,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
return &out, uint32(sockAddrInet6Size)
case linux.AF_PACKET:
- // TODO(b/129292371): Return protocol too.
+ // TODO(gvisor.dev/issue/173): Return protocol too.
var out linux.SockAddrLink
out.Family = linux.AF_PACKET
out.InterfaceIndex = int32(addr.NIC)
@@ -2397,6 +2541,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
}
+func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
+ switch pktType {
+ case tcpip.PacketHost:
+ return linux.PACKET_HOST
+ case tcpip.PacketOtherHost:
+ return linux.PACKET_OTHERHOST
+ case tcpip.PacketOutgoing:
+ return linux.PACKET_OUTGOING
+ case tcpip.PacketBroadcast:
+ return linux.PACKET_BROADCAST
+ case tcpip.PacketMulticast:
+ return linux.PACKET_MULTICAST
+ default:
+ panic(fmt.Sprintf("unknown packet type: %d", pktType))
+ }
+}
+
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
@@ -2452,6 +2613,11 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
var addrLen uint32
if isPacket && senderRequested {
addr, addrLen = ConvertAddress(s.family, s.sender)
+ switch v := addr.(type) {
+ case *linux.SockAddrLink:
+ v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
+ }
}
if peek {
@@ -2686,11 +2852,16 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
}
func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("ioctl(2) may only be called from a task goroutine")
+ }
+
// SIOCGSTAMP is implemented by netstack rather than all commonEndpoint
// sockets.
// TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
switch args[1].Int() {
- case syscall.SIOCGSTAMP:
+ case linux.SIOCGSTAMP:
s.readMu.Lock()
defer s.readMu.Unlock()
if !s.timestampValid {
@@ -2698,9 +2869,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
}
tv := linux.NsecToTimeval(s.timestampNS)
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := tv.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCINQ:
@@ -2719,9 +2888,8 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
}
// Copy result to userspace.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
}
@@ -2730,52 +2898,49 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
// Ioctl performs a socket ioctl.
func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("ioctl(2) may only be called from a task goroutine")
+ }
+
switch arg := int(args[1].Int()); arg {
- case syscall.SIOCGIFFLAGS,
- syscall.SIOCGIFADDR,
- syscall.SIOCGIFBRDADDR,
- syscall.SIOCGIFDSTADDR,
- syscall.SIOCGIFHWADDR,
- syscall.SIOCGIFINDEX,
- syscall.SIOCGIFMAP,
- syscall.SIOCGIFMETRIC,
- syscall.SIOCGIFMTU,
- syscall.SIOCGIFNAME,
- syscall.SIOCGIFNETMASK,
- syscall.SIOCGIFTXQLEN:
+ case linux.SIOCGIFFLAGS,
+ linux.SIOCGIFADDR,
+ linux.SIOCGIFBRDADDR,
+ linux.SIOCGIFDSTADDR,
+ linux.SIOCGIFHWADDR,
+ linux.SIOCGIFINDEX,
+ linux.SIOCGIFMAP,
+ linux.SIOCGIFMETRIC,
+ linux.SIOCGIFMTU,
+ linux.SIOCGIFNAME,
+ linux.SIOCGIFNETMASK,
+ linux.SIOCGIFTXQLEN,
+ linux.SIOCETHTOOL:
var ifr linux.IFReq
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifr.CopyIn(t, args[2].Pointer()); err != nil {
return 0, err
}
if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil {
return 0, err.ToError()
}
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := ifr.CopyOut(t, args[2].Pointer())
return 0, err
- case syscall.SIOCGIFCONF:
+ case linux.SIOCGIFCONF:
// Return a list of interface addresses or the buffer size
// necessary to hold the list.
var ifc linux.IFConf
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifc.CopyIn(t, args[2].Pointer()); err != nil {
return 0, err
}
- if err := ifconfIoctl(ctx, io, &ifc); err != nil {
+ if err := ifconfIoctl(ctx, t, io, &ifc); err != nil {
return 0, err
}
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-
+ _, err := ifc.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCINQ:
@@ -2788,9 +2953,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
v = math.MaxInt32
}
// Copy result to userspace.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCOUTQ:
@@ -2804,9 +2968,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
}
// Copy result to userspace.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG:
@@ -2832,7 +2995,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to
// identify a device.
- if arg == syscall.SIOCGIFNAME {
+ if arg == linux.SIOCGIFNAME {
// Gets the name of the interface given the interface index
// stored in ifr_ifindex.
index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4]))
@@ -2855,21 +3018,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
switch arg {
- case syscall.SIOCGIFINDEX:
+ case linux.SIOCGIFINDEX:
// Copy out the index to the data.
usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index))
- case syscall.SIOCGIFHWADDR:
+ case linux.SIOCGIFHWADDR:
// Copy the hardware address out.
- ifr.Data[0] = 6 // IEEE802.2 arp type.
- ifr.Data[1] = 0
+ //
+ // See netdevice(7): https://linux.die.net/man/7/netdevice
+ // SIOCGIFHWADDR, SIOCSIFHWADDR
+ //
+ // Get or set the hardware address of a device using
+ // ifr_hwaddr. The hardware address is specified in a struct
+ // sockaddr. sa_family contains the ARPHRD_* device type,
+ // sa_data the L2 hardware address starting from byte 0. Setting
+ // the hardware address is a privileged operation.
+ usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType)
n := copy(ifr.Data[2:], iface.Addr)
for i := 2 + n; i < len(ifr.Data); i++ {
ifr.Data[i] = 0 // Clear padding.
}
- usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n))
- case syscall.SIOCGIFFLAGS:
+ case linux.SIOCGIFFLAGS:
f, err := interfaceStatusFlags(stack, iface.Name)
if err != nil {
return err
@@ -2878,7 +3048,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// matches Linux behavior.
usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f))
- case syscall.SIOCGIFADDR:
+ case linux.SIOCGIFADDR:
// Copy the IPv4 address out.
for _, addr := range stack.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
@@ -2889,32 +3059,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
break
}
- case syscall.SIOCGIFMETRIC:
+ case linux.SIOCGIFMETRIC:
// Gets the metric of the device. As per netdevice(7), this
// always just sets ifr_metric to 0.
usermem.ByteOrder.PutUint32(ifr.Data[:4], 0)
- case syscall.SIOCGIFMTU:
+ case linux.SIOCGIFMTU:
// Gets the MTU of the device.
usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU)
- case syscall.SIOCGIFMAP:
+ case linux.SIOCGIFMAP:
// Gets the hardware parameters of the device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFTXQLEN:
+ case linux.SIOCGIFTXQLEN:
// Gets the transmit queue length of the device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFDSTADDR:
+ case linux.SIOCGIFDSTADDR:
// Gets the destination address of a point-to-point device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFBRDADDR:
+ case linux.SIOCGIFBRDADDR:
// Gets the broadcast address of a device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFNETMASK:
+ case linux.SIOCGIFNETMASK:
// Gets the network mask of a device.
for _, addr := range stack.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
@@ -2931,6 +3101,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
break
}
+ case linux.SIOCETHTOOL:
+ // Stubbed out for now. Ideally we should implement the required
+ // ETHTOOL sub-commands.
+ //
+ // See:
+ // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c
+ return syserr.ErrEndpointOperation
+
default:
// Not a valid call.
return syserr.ErrInvalidArgument
@@ -2940,7 +3118,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
-func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
+func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error {
// If Ptr is NULL, return the necessary buffer size via Len.
// Otherwise, write up to Len bytes starting at Ptr containing ifreq
// structs.
@@ -2977,9 +3155,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
// Copy the ifr to userspace.
dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
ifc.Len += int32(linux.SizeOfIFReq)
- if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil {
return err
}
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index fcd8013c0..a9025b0ec 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -30,6 +31,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// SocketVFS2 encapsulates all the state needed to represent a network stack
@@ -38,6 +41,7 @@ type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
socketOpsCommon
}
@@ -64,6 +68,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu
protocol: protocol,
},
}
+ s.LockFD.Init(&vfs.FileLocks{})
vfsfd := &s.vfsfd
if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
DenyPRead: true,
@@ -197,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketVFS2 rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -207,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -243,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -258,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
@@ -315,3 +320,13 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by
return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index f5fa18136..67737ae87 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -15,10 +15,11 @@
package netstack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
- "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -41,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool {
return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber)
}
+// toLinuxARPHardwareType converts a Netstack ARPHardwareType to the equivalent Linux constant.
+func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 {
+ switch t {
+ case header.ARPHardwareNone:
+ return linux.ARPHRD_NONE
+ case header.ARPHardwareLoopback:
+ return linux.ARPHRD_LOOPBACK
+ case header.ARPHardwareEther:
+ return linux.ARPHRD_ETHER
+ default:
+ panic(fmt.Sprintf("unknown ARPHRD type: %d", t))
+ }
+}
+
// Interfaces implements inet.Stack.Interfaces.
func (s *Stack) Interfaces() map[int32]inet.Interface {
is := make(map[int32]inet.Interface)
for id, ni := range s.Stack.NICInfo() {
- var devType uint16
- if ni.Flags.Loopback {
- devType = linux.ARPHRD_LOOPBACK
- }
is[int32(id)] = inet.Interface{
Name: ni.Name,
Addr: []byte(ni.LinkAddress),
Flags: uint32(nicStateFlagsToLinux(ni.Flags)),
- DeviceType: devType,
+ DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType),
MTU: ni.MTU,
}
}
@@ -314,7 +325,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
udp.PacketsSent.Value(), // OutDatagrams.
udp.ReceiveBufferErrors.Value(), // RcvbufErrors.
0, // Udp/SndbufErrors.
- 0, // Udp/InCsumErrors.
+ udp.ChecksumErrors.Value(), // Udp/InCsumErrors.
0, // Udp/IgnoredMulti.
}
default:
@@ -362,16 +373,10 @@ func (s *Stack) RouteTable() []inet.Route {
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() (stack.IPTables, error) {
+func (s *Stack) IPTables() (*stack.IPTables, error) {
return s.Stack.IPTables(), nil
}
-// FillDefaultIPTables sets the stack's iptables to the default tables, which
-// allow and do not modify all traffic.
-func (s *Stack) FillDefaultIPTables() {
- netfilter.FillDefaultIPTables(s.Stack)
-}
-
// Resume implements inet.Stack.Resume.
func (s *Stack) Resume() {
s.Stack.Resume()
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 6580bd6e9..d112757fb 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// ControlMessages represents the union of unix control messages and tcpip
@@ -86,7 +87,7 @@ type SocketOps interface {
Shutdown(t *kernel.Task, how int) *syserr.Error
// GetSockOpt implements the getsockopt(2) linux syscall.
- GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error)
+ GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error)
// SetSockOpt implements the setsockopt(2) linux syscall.
SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
@@ -407,7 +408,6 @@ func emitUnimplementedEvent(t *kernel.Task, name int) {
linux.SO_MARK,
linux.SO_MAX_PACING_RATE,
linux.SO_NOFCS,
- linux.SO_NO_CHECK,
linux.SO_OOBINLINE,
linux.SO_PASSCRED,
linux.SO_PASSSEC,
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index de2cc4bdf..061a689a9 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -21,6 +21,7 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
@@ -34,5 +35,6 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index ce5b94ee7..a1e49cc57 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -252,7 +252,7 @@ func (e *connectionedEndpoint) Close() {
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
if ce.Type() != e.stype {
- return syserr.ErrConnectionRefused
+ return syserr.ErrWrongProtocolForSocket
}
// Check if ce is e to avoid a deadlock.
@@ -476,6 +476,9 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
// State implements socket.Socket.State.
func (e *connectionedEndpoint) State() uint32 {
+ e.Lock()
+ defer e.Unlock()
+
if e.Connected() {
return linux.SS_CONNECTED
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 5b29e9d7f..0482d33cf 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -40,6 +40,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketOperations is a Unix socket. It is similar to a netstack socket,
@@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
@@ -417,7 +418,18 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
defer ep.Release()
// Connect the server endpoint.
- return s.ep.Connect(t, ep)
+ err = s.ep.Connect(t, ep)
+
+ if err == syserr.ErrWrongProtocolForSocket {
+ // For abstract sockets, Linux returns ErrConnectionRefused
+ // instead of ErrWrongProtocolForSocket.
+ path, _ := extractPath(sockaddr)
+ if len(path) > 0 && path[0] == 0 {
+ err = syserr.ErrConnectionRefused
+ }
+ }
+
+ return err
}
// Write implements fs.FileOperations.Write.
@@ -448,15 +460,25 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
To: nil,
}
if len(to) > 0 {
- ep, err := extractEndpoint(t, to)
- if err != nil {
- return 0, err
- }
- defer ep.Release()
- w.To = ep
+ switch s.stype {
+ case linux.SOCK_SEQPACKET:
+ to = nil
+ case linux.SOCK_STREAM:
+ if s.State() == linux.SS_CONNECTED {
+ return 0, syserr.ErrAlreadyConnected
+ }
+ return 0, syserr.ErrNotSupported
+ default:
+ ep, err := extractEndpoint(t, to)
+ if err != nil {
+ return 0, err
+ }
+ defer ep.Release()
+ w.To = ep
- if ep.Passcred() && w.Control.Credentials == nil {
- w.Control.Credentials = control.MakeCreds(t)
+ if ep.Passcred() && w.Control.Credentials == nil {
+ w.Control.Credentials = control.MakeCreds(t)
+ }
}
}
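
Note: the guest-visible effect of the SendMsg change above: supplying a destination address to a connection-oriented unix socket no longer goes through extractEndpoint — SOCK_SEQPACKET ignores the address, while SOCK_STREAM fails with EISCONN when connected (EOPNOTSUPP otherwise), matching Linux. A behavioral sketch:

    // On a connected SOCK_STREAM unix socket:
    err := unix.Sendto(fd, []byte("hi"), 0, &unix.SockaddrUnix{Name: "/tmp/other"})
    // err == unix.EISCONN; on SOCK_SEQPACKET the address would be ignored.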
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 45e109361..05c16fcfe 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
@@ -31,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketVFS2 implements socket.SocketVFS2 (and by extension,
@@ -39,6 +41,7 @@ type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
socketOpsCommon
}
@@ -51,7 +54,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType)
mnt := t.Kernel().SocketMount()
d := sockfs.NewDentry(t.Credentials(), mnt)
- fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d)
+ fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{})
if err != nil {
return nil, syserr.FromError(err)
}
@@ -60,7 +63,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType)
// NewFileDescription creates and returns a socket file description
// corresponding to the given mount and dentry.
-func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry) (*vfs.FileDescription, error) {
+func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
// You can create AF_UNIX, SOCK_RAW sockets. They're the same as
// SOCK_DGRAM and don't require CAP_NET_RAW.
if stype == linux.SOCK_RAW {
@@ -73,6 +76,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
stype: stype,
},
}
+ sock.LockFD.Init(locks)
vfsfd := &sock.vfsfd
if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
DenyPRead: true,
@@ -86,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
@@ -297,6 +301,16 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by
return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
}
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
+
// providerVFS2 is a unix domain socket provider for VFS2.
type providerVFS2 struct{}
diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go
index a6e48b836..5d51a7792 100644
--- a/pkg/sentry/strace/epoll.go
+++ b/pkg/sentry/strace/epoll.go
@@ -50,10 +50,10 @@ func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes ui
sb.WriteString("...")
break
}
- if _, ok := addr.AddLength(uint64(linux.SizeOfEpollEvent)); !ok {
- fmt.Fprintf(&sb, "{error reading event at %#x: EFAULT}", addr)
- continue
- }
+ // Allowing addr to overflow is consistent with Linux, and harmless; if
+ // this isn't the last iteration of the loop, the next call to CopyIn
+ // will just fail with EFAULT.
+ addr, _ = addr.AddLength(uint64(linux.SizeOfEpollEvent))
}
sb.WriteString("}")
return sb.String()
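The comment in this hunk leans on the semantics of usermem.Addr.AddLength: wrap-around is reported, not fatal. A sketch of that contract (the helper below is an assumption mirroring AddLength, not the real implementation):

    package main

    import "fmt"

    // addLength mimics what usermem.Addr.AddLength is assumed to do: return
    // ok=false when addr+length wraps around the 64-bit address space.
    func addLength(addr, length uint64) (end uint64, ok bool) {
        end = addr + length
        return end, end >= addr
    }

    func main() {
        if _, ok := addLength(^uint64(0)-8, 16); !ok {
            // The strace code now ignores this case; if the loop continues,
            // the next CopyIn simply fails with EFAULT, as on Linux.
            fmt.Println("address computation wrapped")
        }
    }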
@@ -75,7 +75,7 @@ var epollEventEvents = abi.FlagSet{
{Flag: linux.EPOLLPRI, Name: "EPOLLPRI"},
{Flag: linux.EPOLLOUT, Name: "EPOLLOUT"},
{Flag: linux.EPOLLERR, Name: "EPOLLERR"},
- {Flag: linux.EPOLLHUP, Name: "EPULLHUP"},
+ {Flag: linux.EPOLLHUP, Name: "EPOLLHUP"},
{Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"},
{Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"},
{Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"},
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index c0512de89..b51c4c941 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -521,6 +521,7 @@ var sockOptNames = map[uint64]abi.ValueSet{
linux.IP_ROUTER_ALERT: "IP_ROUTER_ALERT",
linux.IP_PKTOPTIONS: "IP_PKTOPTIONS",
linux.IP_MTU: "IP_MTU",
+ linux.SO_ORIGINAL_DST: "SO_ORIGINAL_DST",
},
linux.SOL_SOCKET: {
linux.SO_ERROR: "SO_ERROR",
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 217fcfef2..4a9b04fd0 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -99,5 +99,7 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index ea4f9b1a7..80c65164a 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -325,8 +325,8 @@ var AMD64 = &kernel.SyscallTable{
270: syscalls.Supported("pselect", Pselect),
271: syscalls.Supported("ppoll", Ppoll),
272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
- 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 273: syscalls.Supported("set_robust_list", SetRobustList),
+ 274: syscalls.Supported("get_robust_list", GetRobustList),
275: syscalls.Supported("splice", Splice),
276: syscalls.Supported("tee", Tee),
277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index d781d6a04..ba2557c52 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -15,8 +15,8 @@
package linux
import (
- "encoding/binary"
-
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -27,59 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// I/O commands.
-const (
- _IOCB_CMD_PREAD = 0
- _IOCB_CMD_PWRITE = 1
- _IOCB_CMD_FSYNC = 2
- _IOCB_CMD_FDSYNC = 3
- _IOCB_CMD_NOOP = 6
- _IOCB_CMD_PREADV = 7
- _IOCB_CMD_PWRITEV = 8
-)
-
-// I/O flags.
-const (
- _IOCB_FLAG_RESFD = 1
-)
-
-// ioCallback describes an I/O request.
-//
-// The priority field is currently ignored in the implementation below. Also
-// note that the IOCB_FLAG_RESFD feature is not supported.
-type ioCallback struct {
- Data uint64
- Key uint32
- Reserved1 uint32
-
- OpCode uint16
- ReqPrio int16
- FD int32
-
- Buf uint64
- Bytes uint64
- Offset int64
-
- Reserved2 uint64
- Flags uint32
-
- // eventfd to signal if IOCB_FLAG_RESFD is set in flags.
- ResFD int32
-}
-
-// ioEvent describes an I/O result.
-//
-// +stateify savable
-type ioEvent struct {
- Data uint64
- Obj uint64
- Result int64
- Result2 int64
-}
-
-// ioEventSize is the size of an ioEvent encoded.
-var ioEventSize = binary.Size(ioEvent{})
-
// IoSetup implements linux syscall io_setup(2).
func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
nrEvents := args[0].Int()
@@ -192,7 +139,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
}
}
- ev := v.(*ioEvent)
+ ev := v.(*linux.IOEvent)
// Copy out the result.
if _, err := t.CopyOut(eventsAddr, ev); err != nil {
@@ -204,7 +151,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
}
// Keep rolling.
- eventsAddr += usermem.Addr(ioEventSize)
+ eventsAddr += usermem.Addr(linux.IOEventSize)
}
// Everything finished.
@@ -231,7 +178,7 @@ func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadl
}
// memoryFor returns appropriate memory for the given callback.
-func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) {
+func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) {
bytes := int(cb.Bytes)
if bytes < 0 {
// Linux also requires that this field fit in ssize_t.
@@ -242,17 +189,17 @@ func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) {
// we have no guarantee that t's AddressSpace will be active during the
// I/O.
switch cb.OpCode {
- case _IOCB_CMD_PREAD, _IOCB_CMD_PWRITE:
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE:
return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
AddressSpaceActive: false,
})
- case _IOCB_CMD_PREADV, _IOCB_CMD_PWRITEV:
+ case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV:
return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
AddressSpaceActive: false,
})
- case _IOCB_CMD_FSYNC, _IOCB_CMD_FDSYNC, _IOCB_CMD_NOOP:
+ case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC, linux.IOCB_CMD_NOOP:
return usermem.IOSequence{}, nil
default:
@@ -261,54 +208,62 @@ func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) {
}
}
-func performCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *ioCallback, ioseq usermem.IOSequence, ctx *mm.AIOContext, eventFile *fs.File) {
- if ctx.Dead() {
- ctx.CancelPendingRequest()
- return
- }
- ev := &ioEvent{
- Data: cb.Data,
- Obj: uint64(cbAddr),
- }
+// IoCancel implements linux syscall io_cancel(2).
+//
+// It is not presently supported (ENOSYS indicates no support on this
+// architecture).
+func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, syserror.ENOSYS
+}
- // Construct a context.Context that will not be interrupted if t is
- // interrupted.
- c := t.AsyncContext()
+// LINT.IfChange
- var err error
- switch cb.OpCode {
- case _IOCB_CMD_PREAD, _IOCB_CMD_PREADV:
- ev.Result, err = file.Preadv(c, ioseq, cb.Offset)
- case _IOCB_CMD_PWRITE, _IOCB_CMD_PWRITEV:
- ev.Result, err = file.Pwritev(c, ioseq, cb.Offset)
- case _IOCB_CMD_FSYNC:
- err = file.Fsync(c, 0, fs.FileMaxOffset, fs.SyncAll)
- case _IOCB_CMD_FDSYNC:
- err = file.Fsync(c, 0, fs.FileMaxOffset, fs.SyncData)
- }
+func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, actx *mm.AIOContext, eventFile *fs.File) kernel.AIOCallback {
+ return func(ctx context.Context) {
+ if actx.Dead() {
+ actx.CancelPendingRequest()
+ return
+ }
+ ev := &linux.IOEvent{
+ Data: cb.Data,
+ Obj: uint64(cbAddr),
+ }
- // Update the result.
- if err != nil {
- err = handleIOError(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", file)
- ev.Result = -int64(kernel.ExtractErrno(err, 0))
- }
+ var err error
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV:
+ ev.Result, err = file.Preadv(ctx, ioseq, cb.Offset)
+ case linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
+ ev.Result, err = file.Pwritev(ctx, ioseq, cb.Offset)
+ case linux.IOCB_CMD_FSYNC:
+ err = file.Fsync(ctx, 0, fs.FileMaxOffset, fs.SyncAll)
+ case linux.IOCB_CMD_FDSYNC:
+ err = file.Fsync(ctx, 0, fs.FileMaxOffset, fs.SyncData)
+ }
+
+ // Update the result.
+ if err != nil {
+ err = handleIOError(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", file)
+ ev.Result = -int64(kernel.ExtractErrno(err, 0))
+ }
- file.DecRef()
+ file.DecRef()
- // Queue the result for delivery.
- ctx.FinishRequest(ev)
+ // Queue the result for delivery.
+ actx.FinishRequest(ev)
- // Notify the event file if one was specified. This needs to happen
- // *after* queueing the result to avoid racing with the thread we may
- // wake up.
- if eventFile != nil {
- eventFile.FileOperations.(*eventfd.EventOperations).Signal(1)
- eventFile.DecRef()
+ // Notify the event file if one was specified. This needs to happen
+ // *after* queueing the result to avoid racing with the thread we may
+ // wake up.
+ if eventFile != nil {
+ eventFile.FileOperations.(*eventfd.EventOperations).Signal(1)
+ eventFile.DecRef()
+ }
}
}
// submitCallback processes a single callback.
-func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Addr) error {
+func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error {
file := t.GetFile(cb.FD)
if file == nil {
// File not found.
@@ -318,7 +273,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Ad
// Was there an eventFD? Extract it.
var eventFile *fs.File
- if cb.Flags&_IOCB_FLAG_RESFD != 0 {
+ if cb.Flags&linux.IOCB_FLAG_RESFD != 0 {
eventFile = t.GetFile(cb.ResFD)
if eventFile == nil {
// Bad FD.
@@ -340,7 +295,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Ad
// Check offset for reads/writes.
switch cb.OpCode {
- case _IOCB_CMD_PREAD, _IOCB_CMD_PREADV, _IOCB_CMD_PWRITE, _IOCB_CMD_PWRITEV:
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
if cb.Offset < 0 {
return syserror.EINVAL
}
@@ -366,7 +321,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Ad
// Perform the request asynchronously.
file.IncRef()
- fs.Async(func() { performCallback(t, file, cbAddr, cb, ioseq, ctx, eventFile) })
+ t.QueueAIO(getAIOCallback(t, file, cbAddr, cb, ioseq, ctx, eventFile))
// All set.
return nil
@@ -395,7 +350,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
// Copy in this callback.
- var cb ioCallback
+ var cb linux.IOCallback
cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative))
if _, err := t.CopyIn(cbAddr, &cb); err != nil {
@@ -424,10 +379,4 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return uintptr(nrEvents), nil, nil
}
-// IoCancel implements linux syscall io_cancel(2).
-//
-// It is not presently supported (ENOSYS indicates no support on this
-// architecture).
-func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- return 0, nil, syserror.ENOSYS
-}
+// LINT.ThenChange(vfs2/aio.go)
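The restructuring above turns the old synchronous performCallback into a closure built by getAIOCallback and queued with t.QueueAIO, so the submitting task returns immediately and the I/O runs off its goroutine. A stripped-down sketch of that queueing pattern (all names illustrative):

    package main

    import (
        "context"
        "fmt"
        "sync"
    )

    // aioCallback mirrors kernel.AIOCallback: the request captured as a closure.
    type aioCallback func(ctx context.Context)

    // queue is an illustrative stand-in for the task's AIO queue.
    type queue struct {
        wg sync.WaitGroup
    }

    // QueueAIO runs the callback asynchronously, like t.QueueAIO above.
    func (q *queue) QueueAIO(cb aioCallback) {
        q.wg.Add(1)
        go func() {
            defer q.wg.Done()
            cb(context.Background())
        }()
    }

    func main() {
        var q queue
        q.QueueAIO(func(ctx context.Context) {
            fmt.Println("request executed off the submitting goroutine")
        })
        q.wg.Wait()
    }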
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 35a98212a..8cf6401e7 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -900,14 +900,20 @@ func fGetOwn(t *kernel.Task, file *fs.File) int32 {
//
// If who is positive, it represents a PID. If negative, it represents a PGID.
// If the PID or PGID is invalid, the owner is silently unset.
-func fSetOwn(t *kernel.Task, file *fs.File, who int32) {
+func fSetOwn(t *kernel.Task, file *fs.File, who int32) error {
a := file.Async(fasync.New).(*fasync.FileAsync)
if who < 0 {
+ // Check for overflow before flipping the sign.
+ if who-1 > who {
+ return syserror.EINVAL
+ }
pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(-who))
a.SetOwnerProcessGroup(t, pg)
+ } else {
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(who))
+ a.SetOwnerThreadGroup(t, tg)
}
- tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(who))
- a.SetOwnerThreadGroup(t, tg)
+ return nil
}
// Fcntl implements linux syscall fcntl(2).
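The new overflow check in fSetOwn matters only for math.MinInt32, the one negative int32 whose negation is unrepresentable; the test who-1 > who fires exactly there because the subtraction wraps. A quick demonstration:

    package main

    import (
        "fmt"
        "math"
    )

    func main() {
        // who-1 wraps to MaxInt32 only at math.MinInt32, so the check
        // flags exactly the value for which -who overflows.
        for _, who := range []int32{-1, math.MinInt32} {
            fmt.Printf("who=%d overflows=%v\n", who, who-1 > who)
        }
    }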
@@ -935,10 +941,10 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(flags.ToLinuxFDFlags()), nil, nil
case linux.F_SETFD:
flags := args[2].Uint()
- t.FDTable().SetFlags(fd, kernel.FDFlags{
+ err := t.FDTable().SetFlags(fd, kernel.FDFlags{
CloseOnExec: flags&linux.FD_CLOEXEC != 0,
})
- return 0, nil, nil
+ return 0, nil, err
case linux.F_GETFL:
return uintptr(file.Flags().ToLinux()), nil, nil
case linux.F_SETFL:
@@ -998,9 +1004,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
- // The lock uid is that of the Task's FDTable.
- lockUniqueID := lock.UniqueID(t.FDTable().ID())
-
// These locks don't block; execute the non-blocking operation using the inode's lock
// context directly.
switch flock.Type {
@@ -1010,12 +1013,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
if cmd == linux.F_SETLK {
// Non-blocking lock, provide a nil lock.Blocker.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.ReadLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, nil) {
return 0, nil, syserror.EAGAIN
}
} else {
// Blocking lock, pass in the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.ReadLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
@@ -1026,18 +1029,18 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
if cmd == linux.F_SETLK {
// Non-blocking lock, provide a nil lock.Blocker.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.WriteLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, nil) {
return 0, nil, syserror.EAGAIN
}
} else {
// Blocking lock, pass in the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.WriteLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
return 0, nil, nil
case linux.F_UNLCK:
- file.Dirent.Inode.LockCtx.Posix.UnlockRegion(lockUniqueID, rng)
+ file.Dirent.Inode.LockCtx.Posix.UnlockRegion(t.FDTable(), rng)
return 0, nil, nil
default:
return 0, nil, syserror.EINVAL
@@ -1045,8 +1048,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_GETOWN:
return uintptr(fGetOwn(t, file)), nil, nil
case linux.F_SETOWN:
- fSetOwn(t, file, args[2].Int())
- return 0, nil, nil
+ return 0, nil, fSetOwn(t, file, args[2].Int())
case linux.F_GETOWN_EX:
addr := args[2].Pointer()
owner := fGetOwnEx(t, file)
@@ -1055,7 +1057,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_SETOWN_EX:
addr := args[2].Pointer()
var owner linux.FOwnerEx
- n, err := t.CopyIn(addr, &owner)
+ _, err := t.CopyIn(addr, &owner)
if err != nil {
return 0, nil, err
}
@@ -1067,21 +1069,21 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, syserror.ESRCH
}
a.SetOwnerTask(t, task)
- return uintptr(n), nil, nil
+ return 0, nil, nil
case linux.F_OWNER_PID:
tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(owner.PID))
if tg == nil {
return 0, nil, syserror.ESRCH
}
a.SetOwnerThreadGroup(t, tg)
- return uintptr(n), nil, nil
+ return 0, nil, nil
case linux.F_OWNER_PGRP:
pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(owner.PID))
if pg == nil {
return 0, nil, syserror.ESRCH
}
a.SetOwnerProcessGroup(t, pg)
- return uintptr(n), nil, nil
+ return 0, nil, nil
default:
return 0, nil, syserror.EINVAL
}
@@ -1114,17 +1116,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
-// LINT.ThenChange(vfs2/fd.go)
-
-const (
- _FADV_NORMAL = 0
- _FADV_RANDOM = 1
- _FADV_SEQUENTIAL = 2
- _FADV_WILLNEED = 3
- _FADV_DONTNEED = 4
- _FADV_NOREUSE = 5
-)
-
// Fadvise64 implements linux syscall fadvise64(2).
// This implementation currently ignores the provided advice.
func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
@@ -1149,12 +1140,12 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
switch advice {
- case _FADV_NORMAL:
- case _FADV_RANDOM:
- case _FADV_SEQUENTIAL:
- case _FADV_WILLNEED:
- case _FADV_DONTNEED:
- case _FADV_NOREUSE:
+ case linux.POSIX_FADV_NORMAL:
+ case linux.POSIX_FADV_RANDOM:
+ case linux.POSIX_FADV_SEQUENTIAL:
+ case linux.POSIX_FADV_WILLNEED:
+ case linux.POSIX_FADV_DONTNEED:
+ case linux.POSIX_FADV_NOREUSE:
default:
return 0, nil, syserror.EINVAL
}
@@ -1163,8 +1154,6 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
-// LINT.IfChange
-
func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error {
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -2157,22 +2146,6 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
nonblocking := operation&linux.LOCK_NB != 0
operation &^= linux.LOCK_NB
- // flock(2):
- // Locks created by flock() are associated with an open file table entry. This means that
- // duplicate file descriptors (created by, for example, fork(2) or dup(2)) refer to the
- // same lock, and this lock may be modified or released using any of these descriptors. Furthermore,
- // the lock is released either by an explicit LOCK_UN operation on any of these duplicate
- // descriptors, or when all such descriptors have been closed.
- //
- // If a process uses open(2) (or similar) to obtain more than one descriptor for the same file,
- // these descriptors are treated independently by flock(). An attempt to lock the file using
- // one of these file descriptors may be denied by a lock that the calling process has already placed via
- // another descriptor.
- //
- // We use the File UniqueID as the lock UniqueID because it needs to reference the same lock across dup(2)
- // and fork(2).
- lockUniqueID := lock.UniqueID(file.UniqueID)
-
// A BSD style lock spans the entire file.
rng := lock.LockRange{
Start: 0,
@@ -2183,29 +2156,29 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.LOCK_EX:
if nonblocking {
// Since we're nonblocking we pass a nil lock.Blocker implementation.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.WriteLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, nil) {
return 0, nil, syserror.EWOULDBLOCK
}
} else {
// Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.WriteLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
case linux.LOCK_SH:
if nonblocking {
// Since we're nonblocking we pass a nil lock.Blocker implementation.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.ReadLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, nil) {
return 0, nil, syserror.EWOULDBLOCK
}
} else {
// Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.ReadLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
case linux.LOCK_UN:
- file.Dirent.Inode.LockCtx.BSD.UnlockRegion(lockUniqueID, rng)
+ file.Dirent.Inode.LockCtx.BSD.UnlockRegion(file, rng)
default:
// flock(2): EINVAL operation is invalid.
return 0, nil, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index b68261f72..f04d78856 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -198,7 +198,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
switch cmd {
case linux.FUTEX_WAIT:
// WAIT uses a relative timeout.
- mask = ^uint32(0)
+ mask = linux.FUTEX_BITSET_MATCH_ANY
var timeoutDur time.Duration
if !forever {
timeoutDur = time.Duration(timespec.ToNsecCapped()) * time.Nanosecond
@@ -286,3 +286,49 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, syserror.ENOSYS
}
}
+
+// SetRobustList implements linux syscall set_robust_list(2).
+func SetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ head := args[0].Pointer()
+ length := args[1].SizeT()
+
+ if length != uint(linux.SizeOfRobustListHead) {
+ return 0, nil, syserror.EINVAL
+ }
+ t.SetRobustList(head)
+ return 0, nil, nil
+}
+
+// GetRobustList implements linux syscall get_robust_list(2).
+func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // Despite the syscall using the name 'pid' for this variable, it is
+ // very much a tid.
+ tid := args[0].Int()
+ head := args[1].Pointer()
+ size := args[2].Pointer()
+
+ if tid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ot := t
+ if tid != 0 {
+ if ot = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid)); ot == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ // Copy out head pointer.
+ if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil {
+ return 0, nil, err
+ }
+
+ // Copy out size, which is a constant.
+ if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
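For reference, the userspace contract implemented here: set_robust_list rejects any length other than sizeof(struct robust_list_head), and get_robust_list copies back the registered head plus that constant size. A raw-syscall sketch (the struct layout below is an assumption for linux/amd64):

    package main

    import (
        "fmt"
        "unsafe"

        "golang.org/x/sys/unix"
    )

    // robustListHead assumes the linux/amd64 layout of struct
    // robust_list_head (24 bytes, matching linux.SizeOfRobustListHead).
    type robustListHead struct {
        List          uintptr
        FutexOffset   int64
        ListOpPending uintptr
    }

    func main() {
        var head robustListHead
        // Any size other than sizeof(head) would yield EINVAL.
        _, _, errno := unix.Syscall(unix.SYS_SET_ROBUST_LIST,
            uintptr(unsafe.Pointer(&head)), unsafe.Sizeof(head), 0)
        fmt.Println("set_robust_list:", errno)
    }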
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 0760af77b..414fce8e3 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -29,6 +29,8 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
@@ -474,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if v != nil {
- if _, err := t.CopyOut(optValAddr, v); err != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
return 0, nil, err
}
}
@@ -484,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -496,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use
switch name {
case linux.SO_TYPE:
_, skType, _ := s.Type()
- return int32(skType), nil
+ v := primitive.Int32(skType)
+ return &v, nil
case linux.SO_DOMAIN:
family, _, _ := s.Type()
- return int32(family), nil
+ v := primitive.Int32(family)
+ return &v, nil
case linux.SO_PROTOCOL:
_, _, protocol := s.Type()
- return int32(protocol), nil
+ v := primitive.Int32(protocol)
+ return &v, nil
}
}
@@ -539,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
- if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
return 0, nil, err
}
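Across this file the getsockopt path moves from interface{} plus reflective t.CopyIn/CopyOut to go_marshal's Marshallable values, so results like primitive.Int32 serialize themselves. A rough stand-in for what such a primitive wrapper provides (type name, method set, and byte order here are assumptions, not go_marshal's actual code):

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    // Int32 is an illustrative stand-in for primitive.Int32: a named integer
    // that serializes itself, so no reflection is needed on the copy-out path.
    type Int32 int32

    // SizeBytes returns the serialized size.
    func (i *Int32) SizeBytes() int { return 4 }

    // MarshalBytes encodes the value (little-endian for this sketch;
    // go_marshal actually uses the host's native order).
    func (i *Int32) MarshalBytes(dst []byte) {
        binary.LittleEndian.PutUint32(dst, uint32(*i))
    }

    func main() {
        v := Int32(42)
        buf := make([]byte, v.SizeBytes())
        v.MarshalBytes(buf)
        fmt.Println(buf) // [42 0 0 0]
    }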
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 39f2b79ec..77c78889d 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -80,6 +80,12 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
}
}
+ if total > 0 {
+ // On Linux, inotify behavior is not very consistent with splice(2). We try
+ // our best to emulate Linux for very basic calls to splice, where for some
+ // reason, events are generated for output files, but not input files.
+ outFile.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ }
return total, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
index 2de5e3422..c24946160 100644
--- a/pkg/sentry/syscalls/linux/sys_xattr.go
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -207,7 +207,11 @@ func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, si
return syserror.EOPNOTSUPP
}
- return d.Inode.SetXattr(t, d, name, value, flags)
+ if err := d.Inode.SetXattr(t, d, name, value, flags); err != nil {
+ return err
+ }
+ d.InotifyEvent(linux.IN_ATTRIB, 0)
+ return nil
}
func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
@@ -418,7 +422,11 @@ func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error {
return syserror.EOPNOTSUPP
}
- return d.Inode.RemoveXattr(t, d, name)
+ if err := d.Inode.RemoveXattr(t, d, name); err != nil {
+ return err
+ }
+ d.InotifyEvent(linux.IN_ATTRIB, 0)
+ return nil
}
// LINT.ThenChange(vfs2/xattr.go)
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index f882ef840..64696b438 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -5,6 +5,7 @@ package(licenses = ["notice"])
go_library(
name = "vfs2",
srcs = [
+ "aio.go",
"epoll.go",
"eventfd.go",
"execve.go",
@@ -12,9 +13,12 @@ go_library(
"filesystem.go",
"fscontext.go",
"getdents.go",
+ "inotify.go",
"ioctl.go",
+ "lock.go",
"memfd.go",
"mmap.go",
+ "mount.go",
"path.go",
"pipe.go",
"poll.go",
@@ -22,6 +26,7 @@ go_library(
"setstat.go",
"signal.go",
"socket.go",
+ "splice.go",
"stat.go",
"stat_amd64.go",
"stat_arm64.go",
@@ -36,9 +41,11 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/bits",
+ "//pkg/context",
"//pkg/fspath",
"//pkg/gohacks",
"//pkg/sentry/arch",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/eventfd",
"//pkg/sentry/fsimpl/pipefs",
@@ -47,11 +54,13 @@ go_library(
"//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/fasync",
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/loader",
"//pkg/sentry/memmap",
+ "//pkg/sentry/mm",
"//pkg/sentry/socket",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/unix/transport",
@@ -63,5 +72,7 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
new file mode 100644
index 000000000..e5cdefc50
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -0,0 +1,216 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// IoSubmit implements linux syscall io_submit(2).
+func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Uint64()
+ nrEvents := args[1].Int()
+ addr := args[2].Pointer()
+
+ if nrEvents < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ for i := int32(0); i < nrEvents; i++ {
+ // Copy in the address.
+ cbAddrNative := t.Arch().Native(0)
+ if _, err := t.CopyIn(addr, cbAddrNative); err != nil {
+ if i > 0 {
+ // Some successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Copy in this callback.
+ var cb linux.IOCallback
+ cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative))
+ if _, err := t.CopyIn(cbAddr, &cb); err != nil {
+ if i > 0 {
+ // Some have been successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Process this callback.
+ if err := submitCallback(t, id, &cb, cbAddr); err != nil {
+ if i > 0 {
+ // Partial success.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Advance to the next one.
+ addr += usermem.Addr(t.Arch().Width())
+ }
+
+ return uintptr(nrEvents), nil, nil
+}
+
+// submitCallback processes a single callback.
+func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error {
+ if cb.Reserved2 != 0 {
+ return syserror.EINVAL
+ }
+
+ fd := t.GetFileVFS2(cb.FD)
+ if fd == nil {
+ return syserror.EBADF
+ }
+ defer fd.DecRef()
+
+ // Was there an eventFD? Extract it.
+ var eventFD *vfs.FileDescription
+ if cb.Flags&linux.IOCB_FLAG_RESFD != 0 {
+ eventFD = t.GetFileVFS2(cb.ResFD)
+ if eventFD == nil {
+ return syserror.EBADF
+ }
+ defer eventFD.DecRef()
+
+ // Check that it is an eventfd.
+ if _, ok := eventFD.Impl().(*eventfd.EventFileDescription); !ok {
+ return syserror.EINVAL
+ }
+ }
+
+ ioseq, err := memoryFor(t, cb)
+ if err != nil {
+ return err
+ }
+
+ // Check offset for reads/writes.
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
+ if cb.Offset < 0 {
+ return syserror.EINVAL
+ }
+ }
+
+ // Prepare the request.
+ aioCtx, ok := t.MemoryManager().LookupAIOContext(t, id)
+ if !ok {
+ return syserror.EINVAL
+ }
+ if ready := aioCtx.Prepare(); !ready {
+ // Context is busy.
+ return syserror.EAGAIN
+ }
+
+ if eventFD != nil {
+ // The request is set. Make sure there's a ref on the file.
+ //
+ // This is necessary when the callback executes on completion,
+ // which is also what will release this reference.
+ eventFD.IncRef()
+ }
+
+ // Perform the request asynchronously.
+ fd.IncRef()
+ t.QueueAIO(getAIOCallback(t, fd, eventFD, cbAddr, cb, ioseq, aioCtx))
+ return nil
+}
+
+func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, aioCtx *mm.AIOContext) kernel.AIOCallback {
+ return func(ctx context.Context) {
+ if aioCtx.Dead() {
+ aioCtx.CancelPendingRequest()
+ return
+ }
+ ev := &linux.IOEvent{
+ Data: cb.Data,
+ Obj: uint64(cbAddr),
+ }
+
+ var err error
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV:
+ ev.Result, err = fd.PRead(ctx, ioseq, cb.Offset, vfs.ReadOptions{})
+ case linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
+ ev.Result, err = fd.PWrite(ctx, ioseq, cb.Offset, vfs.WriteOptions{})
+ case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC:
+ err = fd.Sync(ctx)
+ }
+
+ // Update the result.
+ if err != nil {
+ err = slinux.HandleIOErrorVFS2(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", fd)
+ ev.Result = -int64(kernel.ExtractErrno(err, 0))
+ }
+
+ fd.DecRef()
+
+ // Queue the result for delivery.
+ aioCtx.FinishRequest(ev)
+
+ // Notify the event file if one was specified. This needs to happen
+ // *after* queueing the result to avoid racing with the thread we may
+ // wake up.
+ if eventFD != nil {
+ eventFD.Impl().(*eventfd.EventFileDescription).Signal(1)
+ eventFD.DecRef()
+ }
+ }
+}
+
+// memoryFor returns appropriate memory for the given callback.
+func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) {
+ bytes := int(cb.Bytes)
+ if bytes < 0 {
+ // Linux also requires that this field fit in ssize_t.
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+
+ // Since this I/O will be asynchronous with respect to t's task goroutine,
+ // we have no guarantee that t's AddressSpace will be active during the
+ // I/O.
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE:
+ return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ AddressSpaceActive: false,
+ })
+
+ case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV:
+ return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ AddressSpaceActive: false,
+ })
+
+ case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC, linux.IOCB_CMD_NOOP:
+ return usermem.IOSequence{}, nil
+
+ default:
+ // Not a supported command.
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index ca0f7fd1e..67f191551 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -17,10 +17,13 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/fasync"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -134,10 +137,10 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(flags.ToLinuxFDFlags()), nil, nil
case linux.F_SETFD:
flags := args[2].Uint()
- t.FDTable().SetFlags(fd, kernel.FDFlags{
+ err := t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
CloseOnExec: flags&linux.FD_CLOEXEC != 0,
})
- return 0, nil, nil
+ return 0, nil, err
case linux.F_GETFL:
return uintptr(file.StatusFlags()), nil, nil
case linux.F_SETFL:
@@ -152,6 +155,41 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
return uintptr(n), nil, nil
+ case linux.F_GETOWN:
+ owner, hasOwner := getAsyncOwner(t, file)
+ if !hasOwner {
+ return 0, nil, nil
+ }
+ if owner.Type == linux.F_OWNER_PGRP {
+ return uintptr(-owner.PID), nil, nil
+ }
+ return uintptr(owner.PID), nil, nil
+ case linux.F_SETOWN:
+ who := args[2].Int()
+ ownerType := int32(linux.F_OWNER_PID)
+ if who < 0 {
+ // Check for overflow before flipping the sign.
+ if who-1 > who {
+ return 0, nil, syserror.EINVAL
+ }
+ ownerType = linux.F_OWNER_PGRP
+ who = -who
+ }
+ return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ case linux.F_GETOWN_EX:
+ owner, hasOwner := getAsyncOwner(t, file)
+ if !hasOwner {
+ return 0, nil, nil
+ }
+ _, err := t.CopyOut(args[2].Pointer(), &owner)
+ return 0, nil, err
+ case linux.F_SETOWN_EX:
+ var owner linux.FOwnerEx
+ _, err := t.CopyIn(args[2].Pointer(), &owner)
+ if err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID)
case linux.F_GETPIPE_SZ:
pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
if !ok {
@@ -167,8 +205,151 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
err := tmpfs.AddSeals(file, args[2].Uint())
return 0, nil, err
+ case linux.F_SETLK, linux.F_SETLKW:
+ return 0, nil, posixLock(t, args, file, cmd)
+ default:
+ // TODO(gvisor.dev/issue/2920): Everything else is not yet supported.
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwnerEx, hasOwner bool) {
+ a := fd.AsyncHandler()
+ if a == nil {
+ return linux.FOwnerEx{}, false
+ }
+
+ ot, otg, opg := a.(*fasync.FileAsync).Owner()
+ switch {
+ case ot != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_TID,
+ PID: int32(t.PIDNamespace().IDOfTask(ot)),
+ }, true
+ case otg != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PID,
+ PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)),
+ }, true
+ case opg != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PGRP,
+ PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)),
+ }, true
+ default:
+ return linux.FOwnerEx{}, true
+ }
+}
+
+func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error {
+ switch ownerType {
+ case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP:
+ // Acceptable type.
+ default:
+ return syserror.EINVAL
+ }
+
+ a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync)
+ if pid == 0 {
+ a.ClearOwner()
+ return nil
+ }
+
+ switch ownerType {
+ case linux.F_OWNER_TID:
+ task := t.PIDNamespace().TaskWithID(kernel.ThreadID(pid))
+ if task == nil {
+ return syserror.ESRCH
+ }
+ a.SetOwnerTask(t, task)
+ return nil
+ case linux.F_OWNER_PID:
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(pid))
+ if tg == nil {
+ return syserror.ESRCH
+ }
+ a.SetOwnerThreadGroup(t, tg)
+ return nil
+ case linux.F_OWNER_PGRP:
+ pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(pid))
+ if pg == nil {
+ return syserror.ESRCH
+ }
+ a.SetOwnerProcessGroup(t, pg)
+ return nil
+ default:
+ return syserror.EINVAL
+ }
+}
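getAsyncOwner and setAsyncOwner back F_GETOWN/F_SETOWN and their _EX variants: a positive argument names a thread group, a negative one a process group, hence the sign flip above. Seen from userspace the pair behaves like this (sketch):

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
        if err != nil {
            panic(err)
        }
        defer unix.Close(fds[0])
        defer unix.Close(fds[1])

        // A positive argument directs SIGIO to this process (F_OWNER_PID);
        // a negative one would name a process group (F_OWNER_PGRP).
        pid := unix.Getpid()
        if _, err := unix.FcntlInt(uintptr(fds[0]), unix.F_SETOWN, pid); err != nil {
            panic(err)
        }
        owner, err := unix.FcntlInt(uintptr(fds[0]), unix.F_GETOWN, 0)
        fmt.Println(owner == pid, err)
    }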
+
+func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, cmd int32) error {
+ // Copy in the lock request.
+ flockAddr := args[2].Pointer()
+ var flock linux.Flock
+ if _, err := t.CopyIn(flockAddr, &flock); err != nil {
+ return err
+ }
+
+ var blocker lock.Blocker
+ if cmd == linux.F_SETLKW {
+ blocker = t
+ }
+
+ switch flock.Type {
+ case linux.F_RDLCK:
+ if !file.IsReadable() {
+ return syserror.EBADF
+ }
+ return file.LockPOSIX(t, t.FDTable(), lock.ReadLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker)
+
+ case linux.F_WRLCK:
+ if !file.IsWritable() {
+ return syserror.EBADF
+ }
+ return file.LockPOSIX(t, t.FDTable(), lock.WriteLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker)
+
+ case linux.F_UNLCK:
+ return file.UnlockPOSIX(t, t.FDTable(), uint64(flock.Start), uint64(flock.Len), flock.Whence)
+
+ default:
+ return syserror.EINVAL
+ }
+}
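posixLock maps F_SETLK to a nil lock.Blocker, so contention fails immediately with EAGAIN, while F_SETLKW passes the task so the wait can block and be interrupted. The equivalent request from an ordinary Linux process, for comparison (illustrative):

    package main

    import (
        "fmt"
        "os"

        "golang.org/x/sys/unix"
    )

    func main() {
        f, err := os.OpenFile("/tmp/lock-demo", os.O_RDWR|os.O_CREATE, 0644)
        if err != nil {
            panic(err)
        }
        defer f.Close()

        // F_SETLK is the non-blocking path (nil blocker); F_SETLKW would
        // block, which is why the sentry passes the task as the blocker.
        lk := unix.Flock_t{
            Type:   unix.F_WRLCK,
            Whence: int16(unix.SEEK_SET),
            Start:  0,
            Len:    0, // 0 means "through EOF", as with flock.Len above
        }
        if err := unix.FcntlFlock(f.Fd(), unix.F_SETLK, &lk); err != nil {
            fmt.Println("lock contended:", err)
            return
        }
        fmt.Println("write lock acquired")
    }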
+
+// Fadvise64 implements fadvise64(2).
+// This implementation currently ignores the provided advice.
+func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[2].Int64()
+ advice := args[3].Int()
+
+ // Note: offset is allowed to be negative.
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // If the FD refers to a pipe or FIFO, return error.
+ if _, isPipe := file.Impl().(*pipe.VFSPipeFD); isPipe {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ switch advice {
+ case linux.POSIX_FADV_NORMAL:
+ case linux.POSIX_FADV_RANDOM:
+ case linux.POSIX_FADV_SEQUENTIAL:
+ case linux.POSIX_FADV_WILLNEED:
+ case linux.POSIX_FADV_DONTNEED:
+ case linux.POSIX_FADV_NOREUSE:
default:
- // TODO(gvisor.dev/issue/1623): Everything else is not yet supported.
return 0, nil, syserror.EINVAL
}
+
+ // Sure, whatever.
+ return 0, nil, nil
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
index 46d3e189c..b6d2ddd65 100644
--- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -106,7 +106,7 @@ func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
addr := args[0].Pointer()
mode := args[1].ModeT()
dev := args[2].Uint()
- return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev)
+ return 0, nil, mknodat(t, linux.AT_FDCWD, addr, linux.FileMode(mode), dev)
}
// Mknodat implements Linux syscall mknodat(2).
@@ -115,10 +115,10 @@ func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
addr := args[1].Pointer()
mode := args[2].ModeT()
dev := args[3].Uint()
- return 0, nil, mknodat(t, dirfd, addr, mode, dev)
+ return 0, nil, mknodat(t, dirfd, addr, linux.FileMode(mode), dev)
}
-func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error {
+func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode, dev uint32) error {
path, err := copyInPath(t, addr)
if err != nil {
return err
@@ -128,9 +128,14 @@ func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint
return err
}
defer tpop.Release()
+
+ // "Zero file type is equivalent to type S_IFREG." - mknod(2)
+ if mode.FileType() == 0 {
+ mode |= linux.ModeRegular
+ }
major, minor := linux.DecodeDeviceID(dev)
return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{
- Mode: linux.FileMode(mode &^ t.FSContext().Umask()),
+ Mode: mode &^ linux.FileMode(t.FSContext().Umask()),
DevMajor: uint32(major),
DevMinor: minor,
})
@@ -313,6 +318,9 @@ func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpath
if err != nil {
return err
}
+ if len(target) == 0 {
+ return syserror.ENOENT
+ }
linkpath, err := copyInPath(t, linkpathAddr)
if err != nil {
return err
diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go
new file mode 100644
index 000000000..5d98134a5
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const allFlags = linux.IN_NONBLOCK | linux.IN_CLOEXEC
+
+// InotifyInit1 implements the inotify_init1() syscall.
+func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags&^allFlags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ino, err := vfs.NewInotifyFD(t, t.Kernel().VFS(), uint32(flags))
+ if err != nil {
+ return 0, nil, err
+ }
+ defer ino.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, ino, kernel.FDFlags{
+ CloseOnExec: flags&linux.IN_CLOEXEC != 0,
+ })
+
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// InotifyInit implements the inotify_init() syscall.
+func InotifyInit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ args[0].Value = 0
+ return InotifyInit1(t, args)
+}
+
+// fdToInotify resolves an fd to an inotify object. If successful, the file will
+// have an extra ref and the caller is responsible for releasing the ref.
+func fdToInotify(t *kernel.Task, fd int32) (*vfs.Inotify, *vfs.FileDescription, error) {
+ f := t.GetFileVFS2(fd)
+ if f == nil {
+ // Invalid fd.
+ return nil, nil, syserror.EBADF
+ }
+
+ ino, ok := f.Impl().(*vfs.Inotify)
+ if !ok {
+ // Not an inotify fd.
+ f.DecRef()
+ return nil, nil, syserror.EINVAL
+ }
+
+ return ino, f, nil
+}
+
+// InotifyAddWatch implements the inotify_add_watch() syscall.
+func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ mask := args[2].Uint()
+
+ // "EINVAL: The given event mask contains no valid events."
+ // -- inotify_add_watch(2)
+ if mask&linux.ALL_INOTIFY_BITS == 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "IN_DONT_FOLLOW: Don't dereference pathname if it is a symbolic link."
+ // -- inotify(7)
+ follow := followFinalSymlink
+ if mask&linux.IN_DONT_FOLLOW != 0 {
+ follow = nofollowFinalSymlink
+ }
+
+ ino, f, err := fdToInotify(t, fd)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer f.DecRef()
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if mask&linux.IN_ONLYDIR != 0 {
+ path.Dir = true
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, follow)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+ d, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ return 0, nil, err
+ }
+ defer d.DecRef()
+
+ fd, err = ino.AddWatch(d.Dentry(), mask)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// InotifyRmWatch implements the inotify_rm_watch() syscall.
+func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ wd := args[1].Int()
+
+ ino, f, err := fdToInotify(t, fd)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer f.DecRef()
+ return 0, nil, ino.RmWatch(wd)
+}
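The new file wires the four inotify syscalls straight through to the vfs.Inotify implementation, with allFlags limiting inotify_init1 to IN_NONBLOCK and IN_CLOEXEC. A minimal consumer of exactly this surface (illustrative):

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        // inotify_init1 with the only two flags allFlags admits.
        fd, err := unix.InotifyInit1(unix.IN_NONBLOCK | unix.IN_CLOEXEC)
        if err != nil {
            panic(err)
        }
        defer unix.Close(fd)

        // IN_ONLYDIR restricts the watch to directories; the vfs2 code
        // models this by forcing path.Dir above.
        wd, err := unix.InotifyAddWatch(fd, "/tmp", unix.IN_MODIFY|unix.IN_ONLYDIR)
        if err != nil {
            panic(err)
        }
        fmt.Println("watch descriptor:", wd)

        if _, err := unix.InotifyRmWatch(fd, uint32(wd)); err != nil {
            panic(err)
        }
    }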
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
index 5a2418da9..fd6ab94b2 100644
--- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -15,6 +15,7 @@
package vfs2
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -30,6 +31,77 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
defer file.DecRef()
+ // Handle ioctls that apply to all FDs.
+ switch args[1].Int() {
+ case linux.FIONCLEX:
+ t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
+ CloseOnExec: false,
+ })
+ return 0, nil, nil
+
+ case linux.FIOCLEX:
+ t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
+ CloseOnExec: true,
+ })
+ return 0, nil, nil
+
+ case linux.FIONBIO:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ flags := file.StatusFlags()
+ if set != 0 {
+ flags |= linux.O_NONBLOCK
+ } else {
+ flags &^= linux.O_NONBLOCK
+ }
+ return 0, nil, file.SetStatusFlags(t, t.Credentials(), flags)
+
+ case linux.FIOASYNC:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ flags := file.StatusFlags()
+ if set != 0 {
+ flags |= linux.O_ASYNC
+ } else {
+ flags &^= linux.O_ASYNC
+ }
+ file.SetStatusFlags(t, t.Credentials(), flags)
+ return 0, nil, nil
+
+ case linux.FIOGETOWN, linux.SIOCGPGRP:
+ var who int32
+ owner, hasOwner := getAsyncOwner(t, file)
+ if hasOwner {
+ if owner.Type == linux.F_OWNER_PGRP {
+ who = -owner.PID
+ } else {
+ who = owner.PID
+ }
+ }
+ _, err := t.CopyOut(args[2].Pointer(), &who)
+ return 0, nil, err
+
+ case linux.FIOSETOWN, linux.SIOCSPGRP:
+ var who int32
+ if _, err := t.CopyIn(args[2].Pointer(), &who); err != nil {
+ return 0, nil, err
+ }
+ ownerType := int32(linux.F_OWNER_PID)
+ if who < 0 {
+ // Check for overflow before flipping the sign.
+ if who-1 > who {
+ return 0, nil, syserror.EINVAL
+ }
+ ownerType = linux.F_OWNER_PGRP
+ who = -who
+ }
+ return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ }
+
ret, err := file.Ioctl(t, t.MemoryManager(), args)
return ret, nil, err
}
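The ioctl hunk implements the file-generic requests in terms of status flags and async owners, so e.g. FIONBIO becomes a SetStatusFlags toggle of O_NONBLOCK. A userspace sketch showing the equivalence (assumes golang.org/x/sys/unix's IoctlSetPointerInt helper):

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
        if err != nil {
            panic(err)
        }
        defer unix.Close(fds[0])
        defer unix.Close(fds[1])

        // FIONBIO takes a pointer to an int and toggles O_NONBLOCK, the
        // same status-flag update the hunk performs.
        if err := unix.IoctlSetPointerInt(fds[0], unix.FIONBIO, 1); err != nil {
            panic(err)
        }
        flags, err := unix.FcntlInt(uintptr(fds[0]), unix.F_GETFL, 0)
        if err != nil {
            panic(err)
        }
        fmt.Println("O_NONBLOCK set:", flags&unix.O_NONBLOCK != 0)
    }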
diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go
new file mode 100644
index 000000000..bf19028c4
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/lock.go
@@ -0,0 +1,64 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Flock implements linux syscall flock(2).
+func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ operation := args[1].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ // flock(2): EBADF fd is not an open file descriptor.
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ nonblocking := operation&linux.LOCK_NB != 0
+ operation &^= linux.LOCK_NB
+
+ var blocker lock.Blocker
+ if !nonblocking {
+ blocker = t
+ }
+
+ switch operation {
+ case linux.LOCK_EX:
+ if err := file.LockBSD(t, lock.WriteLock, blocker); err != nil {
+ return 0, nil, err
+ }
+ case linux.LOCK_SH:
+ if err := file.LockBSD(t, lock.ReadLock, blocker); err != nil {
+ return 0, nil, err
+ }
+ case linux.LOCK_UN:
+ if err := file.UnlockBSD(t); err != nil {
+ return 0, nil, err
+ }
+ default:
+ // flock(2): EINVAL operation is invalid.
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, nil
+}
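As in the VFS1 version, LOCK_NB translates to a nil lock.Blocker, so a contended request returns EWOULDBLOCK instead of sleeping. The same call from userspace (illustrative):

    package main

    import (
        "fmt"
        "os"

        "golang.org/x/sys/unix"
    )

    func main() {
        f, err := os.Open("/tmp")
        if err != nil {
            panic(err)
        }
        defer f.Close()

        // LOCK_NB means a nil blocker: a conflicting lock yields
        // EWOULDBLOCK rather than putting the caller to sleep.
        if err := unix.Flock(int(f.Fd()), unix.LOCK_SH|unix.LOCK_NB); err != nil {
            fmt.Println("flock failed:", err)
            return
        }
        fmt.Println("shared lock acquired")
        _ = unix.Flock(int(f.Fd()), unix.LOCK_UN)
    }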
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
new file mode 100644
index 000000000..ea337de7c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -0,0 +1,150 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mount implements Linux syscall mount(2).
+func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sourceAddr := args[0].Pointer()
+ targetAddr := args[1].Pointer()
+ typeAddr := args[2].Pointer()
+ flags := args[3].Uint64()
+ dataAddr := args[4].Pointer()
+
+ // For null-terminated strings related to mount(2), Linux copies in at most
+ // a page worth of data. See fs/namespace.c:copy_mount_string().
+ fsType, err := t.CopyInString(typeAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ source, err := t.CopyInString(sourceAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ targetPath, err := copyInPath(t, targetAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ data := ""
+ if dataAddr != 0 {
+ // In Linux, a full page is always copied in regardless of null
+ // character placement, and the address is passed to each file system.
+ // Most file systems always treat this data as a string, though, and so
+ // do all of the ones we implement.
+ data, err = t.CopyInString(dataAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ // Ignore magic value that was required before Linux 2.4.
+ if flags&linux.MS_MGC_MSK == linux.MS_MGC_VAL {
+ flags = flags &^ linux.MS_MGC_MSK
+ }
+
+ // Must have CAP_SYS_ADMIN in the current mount namespace's associated user
+ // namespace.
+ creds := t.Credentials()
+ if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupportedOps = linux.MS_REMOUNT | linux.MS_BIND |
+ linux.MS_SHARED | linux.MS_PRIVATE | linux.MS_SLAVE |
+ linux.MS_UNBINDABLE | linux.MS_MOVE
+
+ // Silently allow MS_NOSUID, since we don't implement set-id bits
+ // anyway.
+ const unsupportedFlags = linux.MS_NODIRATIME | linux.MS_STRICTATIME
+
+ // Linux just allows passing any flags to mount(2) - it won't fail when
+ // unknown or unsupported flags are passed. Since we don't implement
+ // everything, we fail explicitly on flags that are unimplemented.
+ if flags&(unsupportedOps|unsupportedFlags) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var opts vfs.MountOptions
+ if flags&linux.MS_NOATIME == linux.MS_NOATIME {
+ opts.Flags.NoATime = true
+ }
+ if flags&linux.MS_NOEXEC == linux.MS_NOEXEC {
+ opts.Flags.NoExec = true
+ }
+ if flags&linux.MS_NODEV == linux.MS_NODEV {
+ opts.Flags.NoDev = true
+ }
+ if flags&linux.MS_NOSUID == linux.MS_NOSUID {
+ opts.Flags.NoSUID = true
+ }
+ if flags&linux.MS_RDONLY == linux.MS_RDONLY {
+ opts.ReadOnly = true
+ }
+ opts.GetFilesystemOptions.Data = data
+
+ target, err := getTaskPathOperation(t, linux.AT_FDCWD, targetPath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer target.Release()
+
+ return 0, nil, t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts)
+}
+
+// Umount2 implements Linux syscall umount2(2).
+func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Int()
+
+ // Must have CAP_SYS_ADMIN in the mount namespace's associated user
+ // namespace.
+ //
+ // Currently, this is always the init task's user namespace.
+ creds := t.Credentials()
+ if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE
+ if flags&unsupported != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ opts := vfs.UmountOptions{
+ Flags: uint32(flags),
+ }
+
+ return 0, nil, t.Kernel().VFS().UmountAt(t, creds, &tpop.pop, &opts)
+}
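Mount fails explicitly on the flag subset it cannot honor and maps the remainder onto vfs.MountOptions. A condensed view of that translation (the struct below is an illustrative stand-in, flattening vfs.MountFlags and the ReadOnly option into one type):

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    // mountFlags is an illustrative stand-in for the fields of
    // vfs.MountFlags plus vfs.MountOptions.ReadOnly populated above.
    type mountFlags struct {
        NoATime, NoExec, NoDev, NoSUID, ReadOnly bool
    }

    func fromLinux(flags uint64) mountFlags {
        return mountFlags{
            NoATime:  flags&unix.MS_NOATIME != 0,
            NoExec:   flags&unix.MS_NOEXEC != 0,
            NoDev:    flags&unix.MS_NODEV != 0,
            NoSUID:   flags&unix.MS_NOSUID != 0,
            ReadOnly: flags&unix.MS_RDONLY != 0,
        }
    }

    func main() {
        fmt.Printf("%+v\n", fromLinux(unix.MS_RDONLY|unix.MS_NOEXEC))
    }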
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
index 3a7ef24f5..cd25597a7 100644
--- a/pkg/sentry/syscalls/linux/vfs2/read_write.go
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -93,11 +93,17 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
n, err := file.Read(t, dst, opts)
if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
return n, err
}
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
return n, err
}
@@ -128,6 +134,9 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt
}
file.EventUnregister(&w)
+ if total > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
return total, err
}
@@ -248,11 +257,17 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
n, err := file.PRead(t, dst, offset, opts)
if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
return n, err
}
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
return n, err
}
@@ -283,6 +298,9 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of
}
file.EventUnregister(&w)
+ if total > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
return total, err
}
@@ -345,11 +363,17 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
n, err := file.Write(t, src, opts)
if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
return n, err
}
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
return n, err
}
@@ -380,6 +404,9 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op
}
file.EventUnregister(&w)
+ if total > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
return total, err
}
@@ -500,11 +527,17 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
n, err := file.PWrite(t, src, offset, opts)
if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
return n, err
}
allowBlock, deadline, hasDeadline := blockPolicy(t, file)
if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
return n, err
}
@@ -535,6 +568,9 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o
}
file.EventUnregister(&w)
+ if total > 0 {
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
return total, err
}
@@ -570,3 +606,36 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
newoff, err := file.Seek(t, offset, whence)
return uintptr(newoff), nil, err
}
+
+// Readahead implements Linux syscall readahead(2).
+func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable.
+ if !file.IsReadable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is valid.
+ if int(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Return EINVAL: this is what Linux returns when the underlying file
+ // type does not support readahead. In the future, we may extend this
+ // function to actually support readahead hints.
+ return 0, nil, syserror.EINVAL
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 09ecfed26..25cdb7a55 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -178,6 +179,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
Mask: linux.STATX_SIZE,
Size: uint64(length),
},
+ NeedWritePerm: true,
})
return 0, nil, handleSetSizeError(t, err)
}
@@ -197,6 +199,10 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
defer file.DecRef()
+ if !file.IsWritable() {
+ return 0, nil, syserror.EINVAL
+ }
+
err := file.SetStat(t, vfs.SetStatOptions{
Stat: linux.Statx{
Mask: linux.STATX_SIZE,
@@ -206,6 +212,56 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, handleSetSizeError(t, err)
}
+// Fallocate implements Linux syscall fallocate(2).
+func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := args[1].Uint64()
+ offset := args[2].Int64()
+ length := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ if !file.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ if mode != 0 {
+ return 0, nil, syserror.ENOTSUP
+ }
+
+ if offset < 0 || length <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ size := offset + length
+
+ if size < 0 {
+ return 0, nil, syserror.EFBIG
+ }
+
+ limit := limits.FromContext(t).Get(limits.FileSize).Cur
+
+ if uint64(size) >= limit {
+ t.SendSignal(&arch.SignalInfo{
+ Signo: int32(linux.SIGXFSZ),
+ Code: arch.SignalInfoUser,
+ })
+ return 0, nil, syserror.EFBIG
+ }
+
+ if err := file.Allocate(t, mode, uint64(offset), uint64(length)); err != nil {
+ return 0, nil, err
+ }
+
+ file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ return 0, nil, nil
+}
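+
+// As a worked example of the checks above (illustrative, not from the
+// original change): with offset = 1<<62 and length = 1<<62, size = offset +
+// length overflows int64 and becomes negative, so the call fails with EFBIG
+// before reaching Allocate. With size at or above RLIMIT_FSIZE, the task is
+// sent SIGXFSZ and the call also fails with EFBIG.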
+
// Utime implements Linux syscall utime(2).
func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
pathAddr := args[0].Pointer()
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 10b668477..8096a8f9c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -30,6 +30,8 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// minListenBacklog is the minimum reasonable backlog for listening sockets.
@@ -477,7 +479,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if v != nil {
- if _, err := t.CopyOut(optValAddr, v); err != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
return 0, nil, err
}
}
@@ -487,7 +489,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -499,13 +501,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr
switch name {
case linux.SO_TYPE:
_, skType, _ := s.Type()
- return int32(skType), nil
+ v := primitive.Int32(skType)
+ return &v, nil
case linux.SO_DOMAIN:
family, _, _ := s.Type()
- return int32(family), nil
+ v := primitive.Int32(family)
+ return &v, nil
case linux.SO_PROTOCOL:
_, _, protocol := s.Type()
- return int32(protocol), nil
+ v := primitive.Int32(protocol)
+ return &v, nil
}
}
@@ -542,7 +547,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
- if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
new file mode 100644
index 000000000..63ab11f8c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -0,0 +1,486 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Splice implements Linux syscall splice(2).
+func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := args[0].Int()
+ inOffsetPtr := args[1].Pointer()
+ outFD := args[2].Int()
+ outOffsetPtr := args[3].Pointer()
+ count := int64(args[4].SizeT())
+ flags := args[5].Int()
+
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get file descriptions.
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ // Check that both files support the required directionality.
+ if !inFile.IsReadable() || !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
+ // At least one file description must represent a pipe.
+ inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD)
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ if !inIsPipe && !outIsPipe {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy in offsets.
+ inOffset := int64(-1)
+ if inOffsetPtr != 0 {
+ if inIsPipe {
+ return 0, nil, syserror.ESPIPE
+ }
+ if inFile.Options().DenyPRead {
+ return 0, nil, syserror.EINVAL
+ }
+ if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil {
+ return 0, nil, err
+ }
+ if inOffset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+ outOffset := int64(-1)
+ if outOffsetPtr != 0 {
+ if outIsPipe {
+ return 0, nil, syserror.ESPIPE
+ }
+ if outFile.Options().DenyPWrite {
+ return 0, nil, syserror.EINVAL
+ }
+ if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil {
+ return 0, nil, err
+ }
+ if outOffset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+
+ // Move data.
+ var (
+ n int64
+ err error
+ )
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
+ for {
+ // If both input and output are pipes, delegate to the pipe
+ // implementation. Otherwise, exactly one end is a pipe, which
+ // we ensure is consistently ordered after the non-pipe FD's
+ // locks by passing the pipe FD as usermem.IO to the non-pipe
+ // end.
+ switch {
+ case inIsPipe && outIsPipe:
+ n, err = pipe.Splice(t, outPipeFD, inPipeFD, count)
+ case inIsPipe:
+ if outOffset != -1 {
+ n, err = outFile.PWrite(t, inPipeFD.IOSequence(count), outOffset, vfs.WriteOptions{})
+ outOffset += n
+ } else {
+ n, err = outFile.Write(t, inPipeFD.IOSequence(count), vfs.WriteOptions{})
+ }
+ case outIsPipe:
+ if inOffset != -1 {
+ n, err = inFile.PRead(t, outPipeFD.IOSequence(count), inOffset, vfs.ReadOptions{})
+ inOffset += n
+ } else {
+ n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
+ }
+ default:
+ panic("not possible")
+ }
+
+ if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+ break
+ }
+ if err = dw.waitForBoth(t); err != nil {
+ break
+ }
+ }
+
+ // Copy updated offsets out.
+ if inOffsetPtr != 0 {
+ if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil {
+ return 0, nil, err
+ }
+ }
+ if outOffsetPtr != 0 {
+ if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if n == 0 {
+ return 0, nil, err
+ }
+
+ // On Linux, inotify behavior is not very consistent with splice(2). We try
+ // our best to emulate Linux for very basic calls to splice, where for some
+ // reason, events are generated for output files, but not input files.
+ outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// Tee implements Linux syscall tee(2).
+func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := args[0].Int()
+ outFD := args[1].Int()
+ count := int64(args[2].SizeT())
+ flags := args[3].Int()
+
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get file descriptions.
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ // Check that both files support the required directionality.
+ if !inFile.IsReadable() || !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
+ // Both file descriptions must represent pipes.
+ inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD)
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ if !inIsPipe || !outIsPipe {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy data.
+ var (
+ n int64
+ err error
+ )
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
+ for {
+ n, err = pipe.Tee(t, outPipeFD, inPipeFD, count)
+ if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+ break
+ }
+ if err = dw.waitForBoth(t); err != nil {
+ break
+ }
+ }
+ if n == 0 {
+ return 0, nil, err
+ }
+ outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// Sendfile implements Linux syscall sendfile(2).
+func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ outFD := args[0].Int()
+ inFD := args[1].Int()
+ offsetAddr := args[2].Pointer()
+ count := int64(args[3].SizeT())
+
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+ if !inFile.IsReadable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+ if !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Verify that the outFile Append flag is not set.
+ if outFile.StatusFlags()&linux.O_APPEND != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Verify that inFile is a regular file or block device. This is a
+ // requirement; the same check appears in Linux
+ // (fs/splice.c:splice_direct_to_actor).
+ if stat, err := inFile.Stat(t, vfs.StatOptions{Mask: linux.STATX_TYPE}); err != nil {
+ return 0, nil, err
+ } else if stat.Mask&linux.STATX_TYPE == 0 ||
+ (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy offset if it exists.
+ offset := int64(-1)
+ if offsetAddr != 0 {
+ if inFile.Options().DenyPRead {
+ return 0, nil, syserror.ESPIPE
+ }
+ if _, err := t.CopyIn(offsetAddr, &offset); err != nil {
+ return 0, nil, err
+ }
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if offset+count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+
+ // Validate count. This must come after offset checks.
+ if count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
+
+ // Copy data.
+ var (
+ n int64
+ err error
+ )
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ // Reading from the input file should never block, since it is a regular
+ // file or block device. We only need to check whether writing to the
+ // output file can block.
+ nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0
+ if outIsPipe {
+ for n < count {
+ var spliceN int64
+ if offset != -1 {
+ spliceN, err = inFile.PRead(t, outPipeFD.IOSequence(count), offset, vfs.ReadOptions{})
+ offset += spliceN
+ } else {
+ spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
+ }
+ n += spliceN
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForBoth(t)
+ }
+ if err != nil {
+ break
+ }
+ }
+ } else {
+ // Read inFile to buffer, then write the contents to outFile.
+ buf := make([]byte, count)
+ for n < count {
+ var readN int64
+ if offset != -1 {
+ readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{})
+ offset += readN
+ } else {
+ readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ }
+ if readN == 0 && err == io.EOF {
+ // We reached the end of the file. Eat the
+ // error and exit the loop.
+ err = nil
+ break
+ }
+ n += readN
+ if err != nil {
+ break
+ }
+
+ // Write all of the bytes that we read. This may need
+ // multiple write calls to complete.
+ wbuf := buf[:readN]
+ for len(wbuf) > 0 {
+ var writeN int64
+ writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{})
+ wbuf = wbuf[writeN:]
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForOut(t)
+ }
+ if err != nil {
+ // We didn't complete the write. Only
+ // report the bytes that were actually
+ // written, and rewind the offset.
+ notWritten := int64(len(wbuf))
+ n -= notWritten
+ if offset != -1 {
+ offset -= notWritten
+ }
+ break
+ }
+ }
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForBoth(t)
+ }
+ if err != nil {
+ break
+ }
+ }
+ }
+
+ if offsetAddr != 0 {
+ // Copy out the new offset.
+ if _, err := t.CopyOut(offsetAddr, offset); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if n == 0 {
+ return 0, nil, err
+ }
+
+ inFile.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not
+// thread-safe, and does not take a reference on the vfs.FileDescriptions.
+//
+// Users must call destroy() when finished.
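+//
+// A typical use, mirroring the loops in Splice and Tee above (sketch; the
+// transfer step stands in for the actual read, write, or pipe operation):
+//
+//	dw := dualWaiter{inFile: inFile, outFile: outFile}
+//	defer dw.destroy()
+//	for {
+//		n, err = transfer(t, inFile, outFile)
+//		if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+//			break
+//		}
+//		if err = dw.waitForBoth(t); err != nil {
+//			break
+//		}
+//	}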
+type dualWaiter struct {
+ inFile *vfs.FileDescription
+ outFile *vfs.FileDescription
+
+ inW waiter.Entry
+ inCh chan struct{}
+ outW waiter.Entry
+ outCh chan struct{}
+}
+
+// waitForBoth waits for both dw.inFile and dw.outFile to be ready.
+func (dw *dualWaiter) waitForBoth(t *kernel.Task) error {
+ if dw.inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
+ if dw.inCh == nil {
+ dw.inW, dw.inCh = waiter.NewChannelEntry(nil)
+ dw.inFile.EventRegister(&dw.inW, eventMaskRead)
+ // We might be ready now. Try again before blocking.
+ return nil
+ }
+ if err := t.Block(dw.inCh); err != nil {
+ return err
+ }
+ }
+ return dw.waitForOut(t)
+}
+
+// waitForOut waits for dw.outFile to be ready for writing.
+func (dw *dualWaiter) waitForOut(t *kernel.Task) error {
+ if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
+ if dw.outCh == nil {
+ dw.outW, dw.outCh = waiter.NewChannelEntry(nil)
+ dw.outFile.EventRegister(&dw.outW, eventMaskWrite)
+ // We might be ready now. Try again before blocking.
+ return nil
+ }
+ if err := t.Block(dw.outCh); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// destroy cleans up resources held by dw. No more calls to wait* can occur
+// after destroy is called.
+func (dw *dualWaiter) destroy() {
+ if dw.inCh != nil {
+ dw.inFile.EventUnregister(&dw.inW)
+ dw.inCh = nil
+ }
+ if dw.outCh != nil {
+ dw.outFile.EventUnregister(&dw.outW)
+ dw.outCh = nil
+ }
+ dw.inFile = nil
+ dw.outFile = nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
index 365250b0b..0d0ebf46a 100644
--- a/pkg/sentry/syscalls/linux/vfs2/sync.go
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -65,10 +65,8 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
nbytes := args[2].Int64()
flags := args[3].Uint()
- if offset < 0 {
- return 0, nil, syserror.EINVAL
- }
- if nbytes < 0 {
+ // Check for negative values and overflow.
+ if offset < 0 || offset+nbytes < 0 {
return 0, nil, syserror.EINVAL
}
if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
@@ -81,7 +79,37 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
}
defer file.DecRef()
- // TODO(gvisor.dev/issue/1897): Avoid writeback of data ranges outside of
- // [offset, offset+nbytes).
- return 0, nil, file.Sync(t)
+ // TODO(gvisor.dev/issue/1897): Currently, the only file syncing we support
+ // is a full-file sync, i.e. fsync(2). As a result, there are severe
+ // limitations on how much we support sync_file_range:
+ // - In Linux, sync_file_range(2) doesn't write out the file's metadata, even
+ // if the file size is changed. We do.
+ // - We always sync the entire file instead of [offset, offset+nbytes).
+ // - We do not support the use of WAIT_BEFORE without WAIT_AFTER. For
+ // correctness, we would have to perform a write-out every time WAIT_BEFORE
+ // was used, but this would be much more expensive than expected if there
+ // were no write-out operations in progress.
+ // - Whenever WAIT_AFTER is used, we sync the file.
+ // - We ignore WRITE. If this flag is used with WAIT_AFTER, then the file will
+ // be synced anyway. If this flag is used without WAIT_AFTER, then it is
+ // safe (and less expensive) to do nothing, because the syscall will not
+ // wait for the write-out to complete--we only need to make sure that the
+ // next time WAIT_BEFORE or WAIT_AFTER are used, the write-out completes.
+ // - According to fs/sync.c, WAIT_BEFORE|WAIT_AFTER "will detect any I/O
+ // errors or ENOSPC conditions and will return those to the caller, after
+ // clearing the EIO and ENOSPC flags in the address_space." We don't do
+ // this.
+
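+ // In short, for the checks below: WAIT_BEFORE without WAIT_AFTER returns
+ // ENOSYS; any flag combination including WAIT_AFTER syncs the entire
+ // file; and WRITE alone is a no-op.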
+ if flags&linux.SYNC_FILE_RANGE_WAIT_BEFORE != 0 &&
+ flags&linux.SYNC_FILE_RANGE_WAIT_AFTER == 0 {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
+ if flags&linux.SYNC_FILE_RANGE_WAIT_AFTER != 0 {
+ if err := file.Sync(t); err != nil {
+ return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+ }
+ }
+ return 0, nil, nil
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index a332d01bd..c576d9475 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -44,7 +44,7 @@ func Override() {
s.Table[23] = syscalls.Supported("select", Select)
s.Table[32] = syscalls.Supported("dup", Dup)
s.Table[33] = syscalls.Supported("dup2", Dup2)
- delete(s.Table, 40) // sendfile
+ s.Table[40] = syscalls.Supported("sendfile", Sendfile)
s.Table[41] = syscalls.Supported("socket", Socket)
s.Table[42] = syscalls.Supported("connect", Connect)
s.Table[43] = syscalls.Supported("accept", Accept)
@@ -62,7 +62,7 @@ func Override() {
s.Table[55] = syscalls.Supported("getsockopt", GetSockOpt)
s.Table[59] = syscalls.Supported("execve", Execve)
s.Table[72] = syscalls.Supported("fcntl", Fcntl)
- delete(s.Table, 73) // flock
+ s.Table[73] = syscalls.Supported("flock", Flock)
s.Table[74] = syscalls.Supported("fsync", Fsync)
s.Table[75] = syscalls.Supported("fdatasync", Fdatasync)
s.Table[76] = syscalls.Supported("truncate", Truncate)
@@ -90,9 +90,9 @@ func Override() {
s.Table[138] = syscalls.Supported("fstatfs", Fstatfs)
s.Table[161] = syscalls.Supported("chroot", Chroot)
s.Table[162] = syscalls.Supported("sync", Sync)
- delete(s.Table, 165) // mount
- delete(s.Table, 166) // umount2
- delete(s.Table, 187) // readahead
+ s.Table[165] = syscalls.Supported("mount", Mount)
+ s.Table[166] = syscalls.Supported("umount2", Umount2)
+ s.Table[187] = syscalls.Supported("readahead", Readahead)
s.Table[188] = syscalls.Supported("setxattr", Setxattr)
s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr)
s.Table[190] = syscalls.Supported("fsetxattr", Fsetxattr)
@@ -105,20 +105,16 @@ func Override() {
s.Table[197] = syscalls.Supported("removexattr", Removexattr)
s.Table[198] = syscalls.Supported("lremovexattr", Lremovexattr)
s.Table[199] = syscalls.Supported("fremovexattr", Fremovexattr)
- delete(s.Table, 206) // io_setup
- delete(s.Table, 207) // io_destroy
- delete(s.Table, 208) // io_getevents
- delete(s.Table, 209) // io_submit
- delete(s.Table, 210) // io_cancel
+ s.Table[209] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"})
s.Table[213] = syscalls.Supported("epoll_create", EpollCreate)
s.Table[217] = syscalls.Supported("getdents64", Getdents64)
- delete(s.Table, 221) // fdavise64
+ s.Table[221] = syscalls.PartiallySupported("fadvise64", Fadvise64, "The syscall is 'supported', but ignores all provided advice.", nil)
s.Table[232] = syscalls.Supported("epoll_wait", EpollWait)
s.Table[233] = syscalls.Supported("epoll_ctl", EpollCtl)
s.Table[235] = syscalls.Supported("utimes", Utimes)
- delete(s.Table, 253) // inotify_init
- delete(s.Table, 254) // inotify_add_watch
- delete(s.Table, 255) // inotify_rm_watch
+ s.Table[253] = syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil)
+ s.Table[254] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[255] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil)
s.Table[257] = syscalls.Supported("openat", Openat)
s.Table[258] = syscalls.Supported("mkdirat", Mkdirat)
s.Table[259] = syscalls.Supported("mknodat", Mknodat)
@@ -134,15 +130,15 @@ func Override() {
s.Table[269] = syscalls.Supported("faccessat", Faccessat)
s.Table[270] = syscalls.Supported("pselect", Pselect)
s.Table[271] = syscalls.Supported("ppoll", Ppoll)
- delete(s.Table, 275) // splice
- delete(s.Table, 276) // tee
+ s.Table[275] = syscalls.Supported("splice", Splice)
+ s.Table[276] = syscalls.Supported("tee", Tee)
s.Table[277] = syscalls.Supported("sync_file_range", SyncFileRange)
s.Table[280] = syscalls.Supported("utimensat", Utimensat)
s.Table[281] = syscalls.Supported("epoll_pwait", EpollPwait)
s.Table[282] = syscalls.Supported("signalfd", Signalfd)
s.Table[283] = syscalls.Supported("timerfd_create", TimerfdCreate)
s.Table[284] = syscalls.Supported("eventfd", Eventfd)
- delete(s.Table, 285) // fallocate
+ s.Table[285] = syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil)
s.Table[286] = syscalls.Supported("timerfd_settime", TimerfdSettime)
s.Table[287] = syscalls.Supported("timerfd_gettime", TimerfdGettime)
s.Table[288] = syscalls.Supported("accept4", Accept4)
@@ -151,7 +147,7 @@ func Override() {
s.Table[291] = syscalls.Supported("epoll_create1", EpollCreate1)
s.Table[292] = syscalls.Supported("dup3", Dup3)
s.Table[293] = syscalls.Supported("pipe2", Pipe2)
- delete(s.Table, 294) // inotify_init1
+ s.Table[294] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil)
s.Table[295] = syscalls.Supported("preadv", Preadv)
s.Table[296] = syscalls.Supported("pwritev", Pwritev)
s.Table[299] = syscalls.Supported("recvmmsg", RecvMMsg)
@@ -167,6 +163,106 @@ func Override() {
// Override ARM64.
s = linux.ARM64
+ s.Table[5] = syscalls.Supported("setxattr", Setxattr)
+ s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr)
+ s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr)
+ s.Table[8] = syscalls.Supported("getxattr", Getxattr)
+ s.Table[9] = syscalls.Supported("lgetxattr", Lgetxattr)
+ s.Table[10] = syscalls.Supported("fgetxattr", Fgetxattr)
+ s.Table[11] = syscalls.Supported("listxattr", Listxattr)
+ s.Table[12] = syscalls.Supported("llistxattr", Llistxattr)
+ s.Table[13] = syscalls.Supported("flistxattr", Flistxattr)
+ s.Table[14] = syscalls.Supported("removexattr", Removexattr)
+ s.Table[15] = syscalls.Supported("lremovexattr", Lremovexattr)
+ s.Table[16] = syscalls.Supported("fremovexattr", Fremovexattr)
+ s.Table[17] = syscalls.Supported("getcwd", Getcwd)
+ s.Table[19] = syscalls.Supported("eventfd2", Eventfd2)
+ s.Table[20] = syscalls.Supported("epoll_create1", EpollCreate1)
+ s.Table[21] = syscalls.Supported("epoll_ctl", EpollCtl)
+ s.Table[22] = syscalls.Supported("epoll_pwait", EpollPwait)
+ s.Table[23] = syscalls.Supported("dup", Dup)
+ s.Table[24] = syscalls.Supported("dup3", Dup3)
+ s.Table[25] = syscalls.Supported("fcntl", Fcntl)
+ s.Table[26] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil)
+ s.Table[27] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[28] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[29] = syscalls.Supported("ioctl", Ioctl)
+ s.Table[32] = syscalls.Supported("flock", Flock)
+ s.Table[33] = syscalls.Supported("mknodat", Mknodat)
+ s.Table[34] = syscalls.Supported("mkdirat", Mkdirat)
+ s.Table[35] = syscalls.Supported("unlinkat", Unlinkat)
+ s.Table[36] = syscalls.Supported("symlinkat", Symlinkat)
+ s.Table[37] = syscalls.Supported("linkat", Linkat)
+ s.Table[38] = syscalls.Supported("renameat", Renameat)
+ s.Table[39] = syscalls.Supported("umount2", Umount2)
+ s.Table[40] = syscalls.Supported("mount", Mount)
+ s.Table[43] = syscalls.Supported("statfs", Statfs)
+ s.Table[44] = syscalls.Supported("fstatfs", Fstatfs)
+ s.Table[45] = syscalls.Supported("truncate", Truncate)
+ s.Table[46] = syscalls.Supported("ftruncate", Ftruncate)
+ s.Table[48] = syscalls.Supported("faccessat", Faccessat)
+ s.Table[49] = syscalls.Supported("chdir", Chdir)
+ s.Table[50] = syscalls.Supported("fchdir", Fchdir)
+ s.Table[51] = syscalls.Supported("chroot", Chroot)
+ s.Table[52] = syscalls.Supported("fchmod", Fchmod)
+ s.Table[53] = syscalls.Supported("fchmodat", Fchmodat)
+ s.Table[54] = syscalls.Supported("fchownat", Fchownat)
+ s.Table[55] = syscalls.Supported("fchown", Fchown)
+ s.Table[56] = syscalls.Supported("openat", Openat)
+ s.Table[57] = syscalls.Supported("close", Close)
+ s.Table[59] = syscalls.Supported("pipe2", Pipe2)
+ s.Table[61] = syscalls.Supported("getdents64", Getdents64)
+ s.Table[62] = syscalls.Supported("lseek", Lseek)
s.Table[63] = syscalls.Supported("read", Read)
+ s.Table[64] = syscalls.Supported("write", Write)
+ s.Table[65] = syscalls.Supported("readv", Readv)
+ s.Table[66] = syscalls.Supported("writev", Writev)
+ s.Table[67] = syscalls.Supported("pread64", Pread64)
+ s.Table[68] = syscalls.Supported("pwrite64", Pwrite64)
+ s.Table[69] = syscalls.Supported("preadv", Preadv)
+ s.Table[70] = syscalls.Supported("pwritev", Pwritev)
+ s.Table[72] = syscalls.Supported("pselect", Pselect)
+ s.Table[73] = syscalls.Supported("ppoll", Ppoll)
+ s.Table[74] = syscalls.Supported("signalfd4", Signalfd4)
+ s.Table[76] = syscalls.Supported("splice", Splice)
+ s.Table[77] = syscalls.Supported("tee", Tee)
+ s.Table[78] = syscalls.Supported("readlinkat", Readlinkat)
+ s.Table[80] = syscalls.Supported("fstat", Fstat)
+ s.Table[81] = syscalls.Supported("sync", Sync)
+ s.Table[82] = syscalls.Supported("fsync", Fsync)
+ s.Table[83] = syscalls.Supported("fdatasync", Fdatasync)
+ s.Table[84] = syscalls.Supported("sync_file_range", SyncFileRange)
+ s.Table[85] = syscalls.Supported("timerfd_create", TimerfdCreate)
+ s.Table[86] = syscalls.Supported("timerfd_settime", TimerfdSettime)
+ s.Table[87] = syscalls.Supported("timerfd_gettime", TimerfdGettime)
+ s.Table[88] = syscalls.Supported("utimensat", Utimensat)
+ s.Table[198] = syscalls.Supported("socket", Socket)
+ s.Table[199] = syscalls.Supported("socketpair", SocketPair)
+ s.Table[200] = syscalls.Supported("bind", Bind)
+ s.Table[201] = syscalls.Supported("listen", Listen)
+ s.Table[202] = syscalls.Supported("accept", Accept)
+ s.Table[203] = syscalls.Supported("connect", Connect)
+ s.Table[204] = syscalls.Supported("getsockname", GetSockName)
+ s.Table[205] = syscalls.Supported("getpeername", GetPeerName)
+ s.Table[206] = syscalls.Supported("sendto", SendTo)
+ s.Table[207] = syscalls.Supported("recvfrom", RecvFrom)
+ s.Table[208] = syscalls.Supported("setsockopt", SetSockOpt)
+ s.Table[209] = syscalls.Supported("getsockopt", GetSockOpt)
+ s.Table[210] = syscalls.Supported("shutdown", Shutdown)
+ s.Table[211] = syscalls.Supported("sendmsg", SendMsg)
+ s.Table[212] = syscalls.Supported("recvmsg", RecvMsg)
+ s.Table[221] = syscalls.Supported("execve", Execve)
+ s.Table[222] = syscalls.Supported("mmap", Mmap)
+ s.Table[242] = syscalls.Supported("accept4", Accept4)
+ s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg)
+ s.Table[267] = syscalls.Supported("syncfs", Syncfs)
+ s.Table[269] = syscalls.Supported("sendmmsg", SendMMsg)
+ s.Table[276] = syscalls.Supported("renameat2", Renameat2)
+ s.Table[279] = syscalls.Supported("memfd_create", MemfdCreate)
+ s.Table[281] = syscalls.Supported("execveat", Execveat)
+ s.Table[286] = syscalls.Supported("preadv2", Preadv2)
+ s.Table[287] = syscalls.Supported("pwritev2", Pwritev2)
+ s.Table[291] = syscalls.Supported("statx", Statx)
+
s.Init()
}
diff --git a/pkg/sentry/time/muldiv_arm64.s b/pkg/sentry/time/muldiv_arm64.s
index 5ad57a8a3..8afc62d53 100644
--- a/pkg/sentry/time/muldiv_arm64.s
+++ b/pkg/sentry/time/muldiv_arm64.s
@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "funcdata.h"
#include "textflag.h"
// Documentation is available in parameters.go.
//
// func muldiv64(value, multiplier, divisor uint64) (uint64, bool)
TEXT ·muldiv64(SB),NOSPLIT,$40-33
+ GO_ARGS
+ NO_LOCAL_POINTERS
MOVD value+0(FP), R0
MOVD multiplier+8(FP), R1
MOVD divisor+16(FP), R2
diff --git a/pkg/sentry/time/parameters.go b/pkg/sentry/time/parameters.go
index 65868cb26..cd1b95117 100644
--- a/pkg/sentry/time/parameters.go
+++ b/pkg/sentry/time/parameters.go
@@ -228,11 +228,15 @@ func errorAdjust(prevParams Parameters, newParams Parameters, now TSCValue) (Par
//
// The log level is determined by the error severity.
func logErrorAdjustment(clock ClockID, errorNS ReferenceNS, orig, adjusted Parameters) {
- fn := log.Debugf
- if int64(errorNS.Magnitude()) > time.Millisecond.Nanoseconds() {
+ magNS := int64(errorNS.Magnitude())
+ if magNS <= 10*time.Microsecond.Nanoseconds() {
+ // Don't log small errors.
+ return
+ }
+ fn := log.Infof
+ if magNS > time.Millisecond.Nanoseconds() {
+ // Upgrade large errors to warning.
fn = log.Warningf
- } else if int64(errorNS.Magnitude()) > 10*time.Microsecond.Nanoseconds() {
- fn = log.Infof
}
fn("Clock(%v): error: %v ns, adjusted frequency from %v Hz to %v Hz", clock, errorNS, orig.Frequency, adjusted.Frequency)
diff --git a/pkg/sentry/time/parameters_test.go b/pkg/sentry/time/parameters_test.go
index e1b9084ac..0ce1257f6 100644
--- a/pkg/sentry/time/parameters_test.go
+++ b/pkg/sentry/time/parameters_test.go
@@ -484,3 +484,18 @@ func TestMulDivOverflow(t *testing.T) {
})
}
}
+
+func BenchmarkMuldiv64(b *testing.B) {
+ var v uint64 = math.MaxUint64
+ for i := uint64(1); i <= 1000000; i++ {
+ mult := uint64(1000000000)
+ div := i * mult
+ res, ok := muldiv64(v, mult, div)
+ if !ok {
b.Errorf("muldiv64(%v, %v, %v): got ok = false, want true", v, mult, div)
+ }
+ if want := v / i; res != want {
+ b.Errorf("Bad result of %v * %v / %v: got %v, want %v", v, mult, div, res, want)
+ }
+ }
+}
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 94d69c1cc..642769e7c 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -15,6 +15,18 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "event_list",
+ out = "event_list.go",
+ package = "vfs",
+ prefix = "event",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Event",
+ "Linker": "*Event",
+ },
+)
+
go_library(
name = "vfs",
srcs = [
@@ -25,11 +37,14 @@ go_library(
"device.go",
"epoll.go",
"epoll_interest_list.go",
+ "event_list.go",
"file_description.go",
"file_description_impl_util.go",
"filesystem.go",
"filesystem_impl_util.go",
"filesystem_type.go",
+ "inotify.go",
+ "lock.go",
"mount.go",
"mount_unsafe.go",
"options.go",
@@ -57,6 +72,7 @@ go_library(
"//pkg/sentry/limits",
"//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/uniqueid",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md
index 9aa133bcb..4b9faf2ea 100644
--- a/pkg/sentry/vfs/README.md
+++ b/pkg/sentry/vfs/README.md
@@ -39,8 +39,8 @@ Mount references are held by:
- Mount: Each referenced Mount holds a reference on its parent, which is the
mount containing its mount point.
-- VirtualFilesystem: A reference is held on each Mount that has not been
- umounted.
+- VirtualFilesystem: A reference is held on each Mount that has been connected
+ to a mount point, but not yet umounted.
MountNamespace and FileDescription references are held by users of VFS. The
expectation is that each `kernel.Task` holds a reference on its corresponding
@@ -169,8 +169,6 @@ This construction, which is essentially a type-safe analogue to Linux's
- binder, which is similarly far too incomplete to use.
- - whitelistfs, which we are already actively attempting to remove.
-
- Save/restore. For instance, it is unclear if the current implementation of
the `state` package supports the inheritance pattern described above.
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index caf770fd5..641e3e502 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -297,3 +297,18 @@ func (d *anonDentry) TryIncRef() bool {
func (d *anonDentry) DecRef() {
// no-op
}
+
+// InotifyWithParent implements DentryImpl.InotifyWithParent.
+//
+// Although Linux technically supports inotify on pseudo filesystems (inotify
+// is implemented at the vfs layer), it is not particularly useful. It is left
+// unimplemented until someone actually needs it.
+func (d *anonDentry) InotifyWithParent(events, cookie uint32, et EventType) {}
+
+// Watches implements DentryImpl.Watches.
+func (d *anonDentry) Watches() *Watches {
+ return nil
+}
+
+// OnZeroWatches implements DentryImpl.OnZeroWatches.
+func (d *anonDentry) OnZeroWatches() {}
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index 8624dbd5d..cea3e6955 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -103,6 +103,39 @@ type DentryImpl interface {
// DecRef decrements the Dentry's reference count.
DecRef()
+
+ // InotifyWithParent notifies all watches on the targets represented by this
+ // dentry and its parent. The parent's watches are notified first, followed
+ // by this dentry's.
+ //
+ // InotifyWithParent automatically adds the IN_ISDIR flag for dentries
+ // representing directories.
+ //
+ // Note that the events may not actually propagate up to the user, depending
+ // on the event masks.
+ InotifyWithParent(events, cookie uint32, et EventType)
+
+ // Watches returns the set of inotify watches for the file corresponding to
+ // the Dentry. Dentries that are hard links to the same underlying file
+ // share the same watches.
+ //
+ // Watches may return nil if the dentry belongs to a FilesystemImpl that
+ // does not support inotify. If an implementation returns a non-nil watch
+ // set, it must always return a non-nil watch set. Likewise, if an
+ // implementation returns a nil watch set, it must always return a nil watch
+ // set.
+ //
+ // The caller does not need to hold a reference on the dentry.
+ Watches() *Watches
+
+ // OnZeroWatches is called whenever the number of watches on a dentry drops
+ // to zero. This is needed by some FilesystemImpls (e.g. gofer) to manage
+ // dentry lifetime.
+ //
+ // The caller does not need to hold a reference on the dentry. OnZeroWatches
+ // may acquire inotify locks, so to prevent deadlock, no inotify locks should
+ // be held by the caller.
+ OnZeroWatches()
}
// IncRef increments d's reference count.
@@ -133,6 +166,26 @@ func (d *Dentry) isMounted() bool {
return atomic.LoadUint32(&d.mounts) != 0
}
+// InotifyWithParent notifies all watches on the targets represented by d and
+// its parent of events.
+func (d *Dentry) InotifyWithParent(events, cookie uint32, et EventType) {
+ d.impl.InotifyWithParent(events, cookie, et)
+}
+
+// Watches returns the set of inotify watches associated with d.
+//
+// Watches will return nil if d belongs to a FilesystemImpl that does not
+// support inotify.
+func (d *Dentry) Watches() *Watches {
+ return d.impl.Watches()
+}
+
+// OnZeroWatches performs cleanup tasks whenever the number of watches on a
+// dentry drops to zero.
+func (d *Dentry) OnZeroWatches() {
+ d.impl.OnZeroWatches()
+}
+
// The following functions are exported so that filesystem implementations can
// use them. The vfs package, and users of VFS, should not call these
// functions.
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 8297f964b..5b009b928 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -31,6 +31,7 @@ type EpollInstance struct {
vfsfd FileDescription
FileDescriptionDefaultImpl
DentryMetadataFileDescriptionImpl
+ NoLockFD
// q holds waiters on this EpollInstance.
q waiter.Queue
@@ -185,7 +186,7 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin
}
// Register interest in file.
- mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP
+ mask := event.Events | linux.EPOLLERR | linux.EPOLLHUP
epi := &epollInterest{
epoll: ep,
key: key,
@@ -256,7 +257,7 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event
}
// Update epi for the next call to ep.ReadEvents().
- mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP
+ mask := event.Events | linux.EPOLLERR | linux.EPOLLHUP
ep.mu.Lock()
epi.mask = mask
epi.userData = event.Data
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index cfabd936c..93861fb4a 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -42,11 +42,20 @@ type FileDescription struct {
// operations.
refs int64
+ // flagsMu protects statusFlags and asyncHandler below.
+ flagsMu sync.Mutex
+
// statusFlags contains status flags, "initialized by open(2) and possibly
- // modified by fcntl()" - fcntl(2). statusFlags is accessed using atomic
- // memory operations.
+ // modified by fcntl()" - fcntl(2). statusFlags can be read using atomic
+ // memory operations when it does not need to be synchronized with an
+ // access to asyncHandler.
statusFlags uint32
+ // asyncHandler handles O_ASYNC signal generation. It is set with the
+ // F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must
+ // also be set by fcntl(2).
+ asyncHandler FileAsync
+
// epolls is the set of epollInterests registered for this FileDescription.
// epolls is protected by epollMu.
epollMu sync.Mutex
@@ -73,6 +82,8 @@ type FileDescription struct {
// writable is analogous to Linux's FMODE_WRITE.
writable bool
+
+ // usedLockBSD is non-zero if LockBSD has ever been called on this file
+ // description; DecRef uses it to release any BSD lock still held. It is
+ // accessed using atomic memory operations.
+ usedLockBSD uint32
+
// impl is the FileDescriptionImpl associated with this Filesystem. impl is
// immutable. This should be the last field in FileDescription.
impl FileDescriptionImpl
@@ -80,8 +91,7 @@ type FileDescription struct {
// FileDescriptionOptions contains options to FileDescription.Init().
type FileDescriptionOptions struct {
- // If AllowDirectIO is true, allow O_DIRECT to be set on the file. This is
- // usually only the case if O_DIRECT would actually have an effect.
+ // If AllowDirectIO is true, allow O_DIRECT to be set on the file.
AllowDirectIO bool
// If DenyPRead is true, calls to FileDescription.PRead() return ESPIPE.
@@ -106,6 +116,10 @@ type FileDescriptionOptions struct {
UseDentryMetadata bool
}
+// FileCreationFlags are the set of flags passed to FileDescription.Init() but
+// omitted from FileDescription.StatusFlags().
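+//
+// For example (a sketch of the intended behavior): a file opened with
+// O_RDWR|O_CREAT|O_TRUNC reports only O_RDWR via StatusFlags(), matching
+// fcntl(F_GETFL) on Linux.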
+const FileCreationFlags = linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC
+
// Init must be called before first use of fd. If it succeeds, it takes
// references on mnt and d. flags is the initial file description flags, which
// is usually the full set of flags passed to open(2).
@@ -120,8 +134,8 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mou
fd.refs = 1
// Remove "file creation flags" to mirror the behavior from file.f_flags in
- // fs/open.c:do_dentry_open
- fd.statusFlags = flags &^ (linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC)
+ // fs/open.c:do_dentry_open.
+ fd.statusFlags = flags &^ FileCreationFlags
fd.vd = VirtualDentry{
mount: mnt,
dentry: d,
@@ -175,12 +189,25 @@ func (fd *FileDescription) DecRef() {
}
ep.interestMu.Unlock()
}
+
+ // If BSD locks were used, release any lock that it may have acquired.
+ if atomic.LoadUint32(&fd.usedLockBSD) != 0 {
+ fd.impl.UnlockBSD(context.Background(), fd)
+ }
+
// Release implementation resources.
fd.impl.Release()
if fd.writable {
fd.vd.mount.EndWrite()
}
fd.vd.DecRef()
+ fd.flagsMu.Lock()
+ // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1.
+ if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ fd.asyncHandler.Unregister(fd)
+ }
+ fd.asyncHandler = nil
+ fd.flagsMu.Unlock()
} else if refs < 0 {
panic("FileDescription.DecRef() called without holding a reference")
}
@@ -210,6 +237,11 @@ func (fd *FileDescription) VirtualDentry() VirtualDentry {
return fd.vd
}
+// Options returns the options passed to fd.Init().
+func (fd *FileDescription) Options() FileDescriptionOptions {
+ return fd.opts
+}
+
// StatusFlags returns file description status flags, as for fcntl(F_GETFL).
func (fd *FileDescription) StatusFlags() uint32 {
return atomic.LoadUint32(&fd.statusFlags)
@@ -259,7 +291,18 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede
}
// TODO(jamieliu): FileDescriptionImpl.SetOAsync()?
const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK
- atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags))
+ fd.flagsMu.Lock()
+ if fd.asyncHandler != nil {
+ // Use fd.statusFlags instead of oldFlags, which may have become outdated,
+ // to avoid double registering/unregistering.
+ if fd.statusFlags&linux.O_ASYNC == 0 && flags&linux.O_ASYNC != 0 {
+ fd.asyncHandler.Register(fd)
+ } else if fd.statusFlags&linux.O_ASYNC != 0 && flags&linux.O_ASYNC == 0 {
+ fd.asyncHandler.Unregister(fd)
+ }
+ }
+ fd.statusFlags = (oldFlags &^ settableFlags) | (flags & settableFlags)
+ fd.flagsMu.Unlock()
return nil
}
@@ -311,6 +354,10 @@ type FileDescriptionImpl interface {
// represented by the FileDescription.
StatFS(ctx context.Context) (linux.Statfs, error)
+ // Allocate grows the file to offset + length bytes.
+ // Only mode == 0 is supported currently.
+ Allocate(ctx context.Context, mode, offset, length uint64) error
+
// waiter.Waitable methods may be used to poll for I/O events.
waiter.Waitable
@@ -415,24 +462,16 @@ type FileDescriptionImpl interface {
Removexattr(ctx context.Context, name string) error
// LockBSD tries to acquire a BSD-style advisory file lock.
- //
- // TODO(gvisor.dev/issue/1480): BSD-style file locking
LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error
- // LockBSD releases a BSD-style advisory file lock.
- //
- // TODO(gvisor.dev/issue/1480): BSD-style file locking
+ // UnlockBSD releases a BSD-style advisory file lock.
UnlockBSD(ctx context.Context, uid lock.UniqueID) error
// LockPOSIX tries to acquire a POSIX-style advisory file lock.
- //
- // TODO(gvisor.dev/issue/1480): POSIX-style file locking
- LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error
+ LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, length uint64, whence int16, block lock.Blocker) error
// UnlockPOSIX releases a POSIX-style advisory file lock.
- //
- // TODO(gvisor.dev/issue/1480): POSIX-style file locking
- UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error
+ UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, length uint64, whence int16) error
}
// Dirent holds the information contained in struct linux_dirent64.
@@ -462,6 +501,15 @@ type IterDirentsCallback interface {
Handle(dirent Dirent) error
}
+// IterDirentsCallbackFunc implements IterDirentsCallback for a function with
+// the semantics of IterDirentsCallback.Handle.
+type IterDirentsCallbackFunc func(dirent Dirent) error
+
+// Handle implements IterDirentsCallback.Handle.
+func (f IterDirentsCallbackFunc) Handle(dirent Dirent) error {
+ return f(dirent)
+}
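+
+// For example (sketch), a caller can collect directory names without
+// declaring a named callback type:
+//
+//	var names []string
+//	cb := IterDirentsCallbackFunc(func(d Dirent) error {
+//		names = append(names, d.Name)
+//		return nil
+//	})
+//	err := fd.IterDirents(ctx, cb)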
+
// OnClose is called when a file descriptor representing the FileDescription is
// closed. Returning a non-nil error should not prevent the file descriptor
// from being closed.
@@ -515,17 +563,28 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
return fd.impl.StatFS(ctx)
}
-// Readiness returns fd's I/O readiness.
+// Allocate grows the file represented by fd to offset + length bytes.
+func (fd *FileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return fd.impl.Allocate(ctx, mode, offset, length)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+//
+// It returns fd's I/O readiness.
func (fd *FileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
return fd.impl.Readiness(mask)
}
-// EventRegister registers e for I/O readiness events in mask.
+// EventRegister implements waiter.Waitable.EventRegister.
+//
+// It registers e for I/O readiness events in mask.
func (fd *FileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
fd.impl.EventRegister(e, mask)
}
-// EventUnregister unregisters e for I/O readiness events.
+// EventUnregister implements waiter.Waitable.EventUnregister.
+//
+// It unregisters e for I/O readiness events.
func (fd *FileDescription) EventUnregister(e *waiter.Entry) {
fd.impl.EventUnregister(e)
}
@@ -731,3 +790,53 @@ func (fd *FileDescription) InodeID() uint64 {
func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error {
return fd.Sync(ctx)
}
+
+// LockBSD tries to acquire a BSD-style advisory file lock.
+func (fd *FileDescription) LockBSD(ctx context.Context, lockType lock.LockType, blocker lock.Blocker) error {
+ atomic.StoreUint32(&fd.usedLockBSD, 1)
+ return fd.impl.LockBSD(ctx, fd, lockType, blocker)
+}
+
+// UnlockBSD releases a BSD-style advisory file lock.
+func (fd *FileDescription) UnlockBSD(ctx context.Context) error {
+ return fd.impl.UnlockBSD(ctx, fd)
+}
+
+// LockPOSIX acquires a POSIX-style file range lock.
+func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, length uint64, whence int16, block lock.Blocker) error {
+ return fd.impl.LockPOSIX(ctx, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX releases a POSIX-style file range lock.
+func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, length uint64, whence int16) error {
+ return fd.impl.UnlockPOSIX(ctx, uid, start, length, whence)
+}
+
+// A FileAsync sends signals to its owner when the registered waiter.Waitable
+// is ready for IO. This is only implemented by pkg/sentry/fasync:FileAsync,
+// but we unfortunately need this interface to avoid circular dependencies.
+type FileAsync interface {
+ Register(w waiter.Waitable)
+ Unregister(w waiter.Waitable)
+}
+
+// AsyncHandler returns the FileAsync for fd.
+func (fd *FileDescription) AsyncHandler() FileAsync {
+ fd.flagsMu.Lock()
+ defer fd.flagsMu.Unlock()
+ return fd.asyncHandler
+}
+
+// SetAsyncHandler sets fd.asyncHandler if it has not been set before and
+// returns it.
+func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsync {
+ fd.flagsMu.Lock()
+ defer fd.flagsMu.Unlock()
+ if fd.asyncHandler == nil {
+ fd.asyncHandler = newHandler()
+ if fd.statusFlags&linux.O_ASYNC != 0 {
+ fd.asyncHandler.Register(fd)
+ }
+ }
+ return fd.asyncHandler
+}
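+
+// A typical SetAsyncHandler call site, e.g. in an fcntl(F_SETOWN) path,
+// might look like the following (sketch; newFileAsync is a hypothetical
+// constructor for the pkg/sentry/fasync implementation):
+//
+//	a := fd.SetAsyncHandler(func() FileAsync { return newFileAsync() })
+//
+// The handler is registered for readiness events only while O_ASYNC is set
+// in fd's status flags.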
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index f4c111926..6b8b4ad49 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -21,7 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -56,6 +56,12 @@ func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, err
return linux.Statfs{}, syserror.ENOSYS
}
+// Allocate implements FileDescriptionImpl.Allocate analogously to
+// fallocate called on a regular file, directory, or FIFO in Linux.
+func (FileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.ENODEV
+}
+
// Readiness implements waiter.Waitable.Readiness analogously to
// file_operations::poll == NULL in Linux.
func (FileDescriptionDefaultImpl) Readiness(mask waiter.EventMask) waiter.EventMask {
@@ -153,31 +159,16 @@ func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string)
return syserror.ENOTSUP
}
-// LockBSD implements FileDescriptionImpl.LockBSD.
-func (FileDescriptionDefaultImpl) LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error {
- return syserror.EBADF
-}
-
-// UnlockBSD implements FileDescriptionImpl.UnlockBSD.
-func (FileDescriptionDefaultImpl) UnlockBSD(ctx context.Context, uid lock.UniqueID) error {
- return syserror.EBADF
-}
-
-// LockPOSIX implements FileDescriptionImpl.LockPOSIX.
-func (FileDescriptionDefaultImpl) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error {
- return syserror.EBADF
-}
-
-// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX.
-func (FileDescriptionDefaultImpl) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error {
- return syserror.EBADF
-}
-
// DirectoryFileDescriptionDefaultImpl may be embedded by implementations of
// FileDescriptionImpl that always represent directories to obtain
// implementations of non-directory I/O methods that return EISDIR.
type DirectoryFileDescriptionDefaultImpl struct{}
+// Allocate implements FileDescriptionImpl.Allocate.
+func (DirectoryFileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.EISDIR
+}
+
// PRead implements FileDescriptionImpl.PRead.
func (DirectoryFileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
return 0, syserror.EISDIR
@@ -347,7 +338,7 @@ func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src
writable, ok := fd.data.(WritableDynamicBytesSource)
if !ok {
- return 0, syserror.EINVAL
+ return 0, syserror.EIO
}
n, err := writable.Write(ctx, src, offset)
if err != nil {
@@ -384,3 +375,54 @@ func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.M
fd.IncRef()
return nil
}
+
+// LockFD may be used by most implementations of FileDescriptionImpl.Lock*
+// functions. Caller must call Init().
+type LockFD struct {
+ locks *FileLocks
+}
+
+// Init initializes fd with FileLocks to use.
+func (fd *LockFD) Init(locks *FileLocks) {
+ fd.locks = locks
+}
+
+// Locks returns the locks associated with this file.
+func (fd *LockFD) Locks() *FileLocks {
+ return fd.locks
+}
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ return fd.locks.LockBSD(uid, t, block)
+}
+
+// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
+ fd.locks.UnlockBSD(uid)
+ return nil
+}
+
+// NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface
+// returning ENOLCK.
+type NoLockFD struct{}
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ return syserror.ENOLCK
+}
+
+// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
+ return syserror.ENOLCK
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return syserror.ENOLCK
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return syserror.ENOLCK
+}
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index 3a75d4d62..3b7e1c273 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -33,6 +33,7 @@ import (
type fileDescription struct {
vfsfd FileDescription
FileDescriptionDefaultImpl
+ NoLockFD
}
// genCount contains the number of times its DynamicBytesSource.Generate()
@@ -154,11 +155,11 @@ func TestGenCountFD(t *testing.T) {
}
// Write and PWrite fails.
- if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EINVAL {
- t.Errorf("Write: got err %v, wanted %v", err, syserror.EINVAL)
+ if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EIO {
+ t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO)
}
- if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EINVAL {
- t.Errorf("Write: got err %v, wanted %v", err, syserror.EINVAL)
+ if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EIO {
+ t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO)
}
}
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 1edd584c9..6bb9ca180 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -524,8 +524,6 @@ type FilesystemImpl interface {
//
// Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl.
PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error
-
- // TODO(gvisor.dev/issue/1479): inotify_add_watch()
}
// PrependPathAtVFSRootError is returned by implementations of
diff --git a/pkg/sentry/vfs/g3doc/inotify.md b/pkg/sentry/vfs/g3doc/inotify.md
new file mode 100644
index 000000000..e7da49faa
--- /dev/null
+++ b/pkg/sentry/vfs/g3doc/inotify.md
@@ -0,0 +1,210 @@
+# Inotify
+
+Inotify is a mechanism for monitoring filesystem events in Linux; see
+inotify(7). An inotify instance can be used to monitor files and directories for
+modifications, creation/deletion, etc. The inotify API consists of system calls
+that create inotify instances (inotify_init/inotify_init1) and add/remove
+watches on files to an instance (inotify_add_watch/inotify_rm_watch). Events are
+generated from various places in the sentry, including the syscall layer, the
+vfs layer, the process fd table, and within each filesystem implementation. This
+document outlines the implementation details of inotify in VFS2.
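+
+As a rough, hedged sketch of how the sentry-internal API introduced by this
+change fits together (`ctx`, `vfsObj`, and `dentry` are assumed to already
+exist; this is not a verbatim excerpt from the tests):
+
+```go
+// Create an inotify instance; like other anonymous fds, it is backed by anonfs.
+fd, err := vfs.NewInotifyFD(ctx, vfsObj, linux.O_NONBLOCK)
+if err != nil {
+	return err
+}
+defer fd.DecRef()
+
+// Watch a target dentry for modifications.
+ino := fd.Impl().(*vfs.Inotify)
+wd, err := ino.AddWatch(dentry, linux.IN_MODIFY)
+if err != nil {
+	return err
+}
+
+// Queued events are drained via read(2) on the fd; the watch can be dropped
+// early with RmWatch, which also queues an IN_IGNORED event.
+_ = ino.RmWatch(wd)
+```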
+
+## Inotify Objects
+
+Inotify data structures are implemented in the vfs package.
+
+### vfs.Inotify
+
+Inotify instances are represented by vfs.Inotify objects, which implement
+vfs.FileDescriptionImpl. As in Linux, inotify fds are backed by a
+pseudo-filesystem (anonfs). Each inotify instance receives events from a set of
+vfs.Watch objects, which can be modified with inotify_add_watch(2) and
+inotify_rm_watch(2). An application can retrieve events by reading the inotify
+fd.
+
+### vfs.Watches
+
+The set of all watches held on a single file (i.e., the watch target) is stored
+in vfs.Watches. Each watch will belong to a different inotify instance (an
+instance can only have one watch on any watch target). The watches are stored in
+a map indexed by their vfs.Inotify owner’s id. Hard links and file descriptions
+to a single file will all share the same vfs.Watches. Activity on the target
+causes its vfs.Watches to generate notifications on its watches’ inotify
+instances.
+
+### vfs.Watch
+
+A single watch, owned by one inotify instance and applied to one watch target.
+Both the vfs.Inotify owner and vfs.Watches on the target will hold a vfs.Watch,
+which leads to some complicated locking behavior (see Lock Ordering). Whenever a
+watch is notified of an event on its target, it will queue events to its inotify
+instance for delivery to the user.
+
+### vfs.Event
+
+vfs.Event is a simple struct encapsulating all the fields for an inotify event.
+It is generated by vfs.Watches and forwarded to the watches' owners. It is
+serialized to the user during read(2) syscalls on the associated vfs.Inotify's
+fd.
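+
+For reference, Event.CopyTo serializes each event in Linux's struct
+inotify_event layout (16-byte base; the offsets below follow the CopyTo code
+in this change):
+
+```
+[0:4)       wd     int32
+[4:8)       mask   uint32
+[8:12)      cookie uint32
+[12:16)     len    uint32
+[16:16+len) name   null-padded to a multiple of 16 bytes
+```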
+
+## Lock Ordering
+
+There are three locks related to the inotify implementation:
+
+*   Inotify.mu: the inotify instance lock.
+*   Inotify.evMu: the inotify event queue lock.
+*   Watches.mu: the watch set lock, used to protect the collection of watches
+    on a target.
+
+The correct lock ordering for inotify code is:
+
+Inotify.mu -> Watches.mu -> Inotify.evMu.
+
+Note that we use a distinct lock to protect the inotify event queue. If we
+simply used Inotify.mu, we could simultaneously have locks being acquired in the
+order of Inotify.mu -> Watches.mu and Watches.mu -> Inotify.mu, which would
+cause deadlocks. For instance, adding a watch to an inotify instance would
+require locking Inotify.mu, and then adding the same watch to the target would
+cause Watches.mu to be held. At the same time, generating an event on the target
+would require Watches.mu to be held before iterating through each watch, and
+then notifying the owner of each watch would cause Inotify.mu to be held.
+
+See the vfs package comment to understand how inotify locks fit into the overall
+ordering of filesystem locks.
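+
+A minimal sketch of the discipline inside the vfs package (assuming an
+*Inotify i and a *Watches ws; not a verbatim excerpt):
+
+```go
+// Modifying a watch set: the instance lock is taken before the watch set
+// lock, matching Inotify.AddWatch in this change.
+i.mu.Lock()
+ws.mu.Lock()
+// ... add or remove a watch ...
+ws.mu.Unlock()
+i.mu.Unlock()
+
+// Queuing an event: evMu is taken last and alone, so a notifier that
+// already holds Watches.mu can still deliver to the instance.
+i.evMu.Lock()
+// ... append to i.events ...
+i.evMu.Unlock()
+```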
+
+## Watch Targets in Different Filesystem Implementations
+
+In Linux, watches reside on inodes at the virtual filesystem layer. As a
+result, all hard links and file descriptions on a single file share the same
+watch set. In VFS2, there is no common inode structure across filesystem types
+(some may not even have inodes), so we have to plumb inotify support through
+each specific filesystem implementation. Some of the technical considerations
+are outlined below.
+
+### Tmpfs
+
+For filesystems with inodes, like tmpfs, the design is quite similar to that of
+Linux, where watches reside on the inode.
+
+### Pseudo-filesystems
+
+Technically, because inotify is implemented at the vfs layer in Linux,
+pseudo-filesystems on top of kernfs support inotify passively. However, watches
+can only track explicit filesystem operations like read/write, open/close,
+mknod, etc., so watches on a target like /proc/self/fd will not generate events
+every time a new fd is added or removed. As of this writing, we leave inotify
+unimplemented in kernfs and anonfs; it does not seem particularly useful.
+
+### Gofer Filesystem (fsimpl/gofer)
+
+The gofer filesystem has several traits that make it difficult to support
+inotify:
+
+* **There are no inodes.** A file is represented as a dentry that holds an
+ unopened p9 file (and possibly an open FID), through which the Sentry
+ interacts with the gofer.
+ * *Solution:* Because there is no inode structure stored in the sandbox,
+ inotify watches must be held on the dentry. This would be an issue in
+ the presence of hard links, where multiple dentries would need to share
+ the same set of watches, but in VFS2, we do not support the internal
+ creation of hard links on gofer fs. As a result, we make the assumption
+ that every dentry corresponds to a unique inode. However, the next point
+ raises an issue with this assumption:
+* **The Sentry cannot always be aware of hard links on the remote
+ filesystem.** There is no way for us to confirm whether two files on the
+ remote filesystem are actually links to the same inode. QIDs and inodes are
+ not always 1:1. The assumption that dentries and inodes are 1:1 is
+ inevitably broken if there are remote hard links that we cannot detect.
+ * *Solution:* this is an issue with gofer fs in general, not only inotify,
+ and we will have to live with it.
+* **Dentries can be cached, and then evicted.** Dentry lifetime does not
+ correspond to file lifetime. Because gofer fs is not entirely in-memory, the
+ absence of a dentry does not mean that the corresponding file does not
+ exist, nor does a dentry reaching zero references mean that the
+ corresponding file no longer exists. When a dentry reaches zero references,
+ it will be cached, in case the file at that path is needed again in the
+ future. However, the dentry may be evicted from the cache, which will cause
+ a new dentry to be created next time the same file path is used. The
+ existing watches will be lost.
+ * *Solution:* When a dentry reaches zero references, do not cache it if it
+ has any watches, so we can avoid eviction/destruction. Note that if the
+ dentry was deleted or invalidated (d.vfsd.IsDead()), we should still
+ destroy it along with its watches. Additionally, when a dentry’s last
+ watch is removed, we cache it if it also has zero references. This way,
+ the dentry can eventually be evicted from memory if it is no longer
+ needed.
+* **Dentries can be invalidated.** Another issue with dentry lifetime is that
+ the remote file at the file path represented may change from underneath the
+ dentry. In this case, the next time that the dentry is used, it will be
+ invalidated and a new dentry will replace it. In this case, it is not clear
+ what should be done with the watches on the old dentry.
+ * *Solution:* Silently destroy the watches when invalidation occurs. We
+ have no way of knowing exactly what happened or when it happened. Inotify
+ instances on NFS files in Linux probably behave in a similar fashion,
+ since inotify is implemented at the vfs layer and is not aware of the
+ complexities of remote file systems.
+ * An alternative would be to issue some kind of event upon invalidation,
+ e.g. a delete event, but this has several issues:
+ * We cannot discern whether the remote file was invalidated because it was
+ moved, deleted, etc. This information is crucial, because these cases
+ should result in different events. Furthermore, the watches should only
+ be destroyed if the file has been deleted.
+ * Moreover, the mechanism for detecting whether the underlying file has
+ changed is to check whether a new QID is given by the gofer. This may
+ result in false positives, e.g. suppose that the server closed and
+ re-opened the same file, which may result in a new QID.
+ * Finally, the time of the event may be completely different from the time
+ of the file modification, since a dentry is not immediately notified
+ when the underlying file has changed. It would be quite unexpected to
+ receive the notification when invalidation was triggered, i.e. the next
+ time the file was accessed within the sandbox, because then the
+ read/write/etc. operation on the file would not result in the expected
+ event.
+ * Another point in favor of the first solution: inotify in Linux can
+ already be lossy on local filesystems (one of the sacrifices made so
+ that filesystem performance isn’t killed), and it is lossy on NFS for
+ similar reasons to gofer fs. Therefore, it is better for inotify to be
+ silent than to emit incorrect notifications.
+* **There may be external users of the remote filesystem.** We can only track
+ operations performed on the file within the sandbox. This is sufficient
+ under InteropModeExclusive, but whenever there are external users, the set
+ of actions we are aware of is incomplete.
+ * *Solution:* We could either return an error or just issue a warning when
+ inotify is used without InteropModeExclusive. Although faulty, VFS1
+ allows it when the filesystem is shared, and Linux does the same for
+ remote filesystems (as mentioned above, inotify sits at the vfs level).
+
+## Dentry Interface
+
+For events that must be generated above the vfs layer, we provide the following
+DentryImpl methods to allow interactions with targets on any FilesystemImpl:
+
+* **InotifyWithParent()** generates events on the dentry’s watches as well as
+ its parent’s.
+* **Watches()** retrieves the watch set of the target represented by the
+ dentry. This is used to access and modify watches on a target.
+* **OnZeroWatches()** performs cleanup tasks after the last watch is removed
+ from a dentry. This is needed by gofer fs, which must allow a watched dentry
+ to be cached once it has no more watches. Most implementations can just do
+ nothing. Note that OnZeroWatches() must be called after all inotify locks
+ are released to preserve lock ordering, since it may acquire
+ FilesystemImpl-specific locks.
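+
+A sketch of the interface additions, with signatures assumed from their uses
+in this change:
+
+```go
+type DentryImpl interface {
+	// ... existing methods ...
+
+	// InotifyWithParent notifies watches on the dentry and on its parent.
+	InotifyWithParent(events, cookie uint32, et EventType)
+
+	// Watches returns the watch set of the target represented by the
+	// dentry, or nil if the dentry is unwatchable.
+	Watches() *Watches
+
+	// OnZeroWatches is called after the last watch is removed.
+	OnZeroWatches()
+}
+```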
+
+## IN_EXCL_UNLINK
+
+There are several options that can be set for a watch, specified as part of the
+mask in inotify_add_watch(2). In particular, IN_EXCL_UNLINK requires some
+additional support in each filesystem.
+
+A watch with IN_EXCL_UNLINK will not generate events for its target if it
+corresponds to a path that was unlinked. For instance, if an fd is opened on
+“foo/bar” and “foo/bar” is subsequently unlinked, any reads/writes/etc. on the
+fd will be ignored by watches on “foo” or “foo/bar” with IN_EXCL_UNLINK. This
+requires each DentryImpl to keep track of whether it has been unlinked, in order
+to determine whether events should be sent to watches with IN_EXCL_UNLINK.
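+
+Concretely, event generation skips such watches; from Watches.Notify in this
+change:
+
+```go
+if unlinked && watch.ExcludeUnlinked() && et == PathEvent {
+	continue
+}
+```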
+
+## IN_ONESHOT
+
+One-shot watches expire after generating a single event. When an event occurs,
+all one-shot watches on the target that successfully generated an event are
+removed. Lock ordering can cause the management of one-shot watches to be quite
+expensive; see Watches.Notify() for more information.
diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go
index 286510195..8882fa84a 100644
--- a/pkg/sentry/vfs/genericfstree/genericfstree.go
+++ b/pkg/sentry/vfs/genericfstree/genericfstree.go
@@ -43,7 +43,7 @@ type Dentry struct {
// IsAncestorDentry returns true if d is an ancestor of d2; that is, d is
// either d2's parent or an ancestor of d2's parent.
func IsAncestorDentry(d, d2 *Dentry) bool {
- for {
+ for d2 != nil { // Stop at root, where d2.parent == nil.
if d2.parent == d {
return true
}
@@ -52,6 +52,7 @@ func IsAncestorDentry(d, d2 *Dentry) bool {
}
d2 = d2.parent
}
+ return false
}
// ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d.
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
new file mode 100644
index 000000000..167b731ac
--- /dev/null
+++ b/pkg/sentry/vfs/inotify.go
@@ -0,0 +1,774 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "bytes"
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// inotifyEventBaseSize is the base size of linux's struct inotify_event. This
+// must be a power of 2 for the rounding below.
+const inotifyEventBaseSize = 16
+
+// EventType defines different kinds of inotify events.
+//
+// The way events are labelled appears somewhat arbitrary, but they must match
+// Linux so that IN_EXCL_UNLINK behaves as it does in Linux.
+type EventType uint8
+
+// PathEvent and InodeEvent correspond to FSNOTIFY_EVENT_PATH and
+// FSNOTIFY_EVENT_INODE in Linux.
+const (
+ PathEvent EventType = iota
+ InodeEvent EventType = iota
+)
+
+// Inotify represents an inotify instance created by inotify_init(2) or
+// inotify_init1(2). Inotify implements FileDescriptionImpl.
+//
+// +stateify savable
+type Inotify struct {
+ vfsfd FileDescription
+ FileDescriptionDefaultImpl
+ DentryMetadataFileDescriptionImpl
+ NoLockFD
+
+ // Unique identifier for this inotify instance. We don't just reuse the
+ // inotify fd because fds can be duped. This id should not be exposed to
+ // the user, since we may aggressively reuse ids on S/R.
+ id uint64
+
+ // queue is used to notify interested parties when the inotify instance
+ // becomes readable or writable.
+ queue waiter.Queue `state:"nosave"`
+
+ // evMu *only* protects the events list. We need a separate lock while
+ // queuing events: using mu may violate lock ordering, since at that point
+ // the calling goroutine may already hold Watches.mu.
+ evMu sync.Mutex `state:"nosave"`
+
+ // A list of pending events for this inotify instance. Protected by evMu.
+ events eventList
+
+ // A scratch buffer, used to serialize inotify events. Allocate this
+ // ahead of time for the sake of performance. Protected by evMu.
+ scratch []byte
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // nextWatchMinusOne is used to allocate watch descriptors on this Inotify
+ // instance. Note that Linux starts numbering watch descriptors from 1.
+ nextWatchMinusOne int32
+
+ // Map from watch descriptors to watch objects.
+ watches map[int32]*Watch
+}
+
+var _ FileDescriptionImpl = (*Inotify)(nil)
+
+// NewInotifyFD constructs a new Inotify instance.
+func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) (*FileDescription, error) {
+ // O_CLOEXEC affects file descriptors, so it must be handled outside of vfs.
+ flags &^= linux.O_CLOEXEC
+ if flags&^linux.O_NONBLOCK != 0 {
+ return nil, syserror.EINVAL
+ }
+
+ id := uniqueid.GlobalFromContext(ctx)
+ vd := vfsObj.NewAnonVirtualDentry(fmt.Sprintf("[inotifyfd:%d]", id))
+ defer vd.DecRef()
+ fd := &Inotify{
+ id: id,
+ scratch: make([]byte, inotifyEventBaseSize),
+ watches: make(map[int32]*Watch),
+ }
+ if err := fd.vfsfd.Init(fd, flags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{
+ UseDentryMetadata: true,
+ DenyPRead: true,
+ DenyPWrite: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Release implements FileDescriptionImpl.Release. Release removes all
+// watches and frees all resources for an inotify instance.
+func (i *Inotify) Release() {
+ var ds []*Dentry
+
+ // We need to hold i.mu to avoid a race with concurrent calls to
+ // Inotify.handleDeletion from Watches. There's no risk of Watches
+ // accessing this Inotify after the destructor ends, because we remove all
+ // references to it below.
+ i.mu.Lock()
+ for _, w := range i.watches {
+ // Remove references to the watch from the watches set on the target. We
+ // don't need to worry about the references from i.watches, since this
+ // file description is about to be destroyed.
+ d := w.target
+ ws := d.Watches()
+ // Watchable dentries should never return a nil watch set.
+ if ws == nil {
+ panic("Cannot remove watch from an unwatchable dentry")
+ }
+ ws.Remove(i.id)
+ if ws.Size() == 0 {
+ ds = append(ds, d)
+ }
+ }
+ i.mu.Unlock()
+
+ for _, d := range ds {
+ d.OnZeroWatches()
+ }
+}
+
+// Allocate implements FileDescriptionImpl.Allocate.
+func (i *Inotify) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ panic("Allocate should not be called on read-only inotify fds")
+}
+
+// EventRegister implements waiter.Waitable.
+func (i *Inotify) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ i.queue.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.
+func (i *Inotify) EventUnregister(e *waiter.Entry) {
+ i.queue.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+//
+// Readiness indicates whether there are pending events for an inotify instance.
+func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask {
+ ready := waiter.EventMask(0)
+
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+
+ if !i.events.Empty() {
+ ready |= waiter.EventIn
+ }
+
+ return mask & ready
+}
+
+// PRead implements FileDescriptionImpl.PRead.
+func (*Inotify) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// PWrite implements FileDescriptionImpl.PWrite.
+func (*Inotify) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements FileDescriptionImpl.Write.
+func (*Inotify) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ if dst.NumBytes() < inotifyEventBaseSize {
+ return 0, syserror.EINVAL
+ }
+
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+
+ if i.events.Empty() {
+ // Nothing to read yet, tell caller to block.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ var writeLen int64
+ for it := i.events.Front(); it != nil; {
+ // Advance `it` before the element is removed from the list, or else
+ // it.Next() will always be nil.
+ event := it
+ it = it.Next()
+
+ // Does the buffer have enough remaining space to hold the event we're
+ // about to write out?
+ if dst.NumBytes() < int64(event.sizeOf()) {
+ if writeLen > 0 {
+ // Buffer wasn't big enough for all pending events, but we did
+ // write some events out.
+ return writeLen, nil
+ }
+ return 0, syserror.EINVAL
+ }
+
+ // Linux always dequeues an available event as long as there's enough
+ // buffer space to copy it out, even if the copy below fails. Emulate
+ // this behaviour.
+ i.events.Remove(event)
+
+ // Buffer has enough space, copy event to the read buffer.
+ n, err := event.CopyTo(ctx, i.scratch, dst)
+ if err != nil {
+ return 0, err
+ }
+
+ writeLen += n
+ dst = dst.DropFirst64(n)
+ }
+ return writeLen, nil
+}
+
+// Ioctl implements FileDescriptionImpl.Ioctl.
+func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch args[1].Int() {
+ case linux.FIONREAD:
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+ var n uint32
+ for e := i.events.Front(); e != nil; e = e.Next() {
+ n += uint32(e.sizeOf())
+ }
+ var buf [4]byte
+ usermem.ByteOrder.PutUint32(buf[:], n)
+ _, err := uio.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{})
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+func (i *Inotify) queueEvent(ev *Event) {
+ i.evMu.Lock()
+
+ // Check if we should coalesce the event we're about to queue with the last
+ // one currently in the queue. Events are coalesced if they are identical.
+ if last := i.events.Back(); last != nil {
+ if ev.equals(last) {
+ // "Coalesce" the two events by simply not queuing the new one. We
+ // don't need to raise a waiter.EventIn notification because no new
+ // data is available for reading.
+ i.evMu.Unlock()
+ return
+ }
+ }
+
+ i.events.PushBack(ev)
+
+ // Release mutex before notifying waiters because we don't control what they
+ // can do.
+ i.evMu.Unlock()
+
+ i.queue.Notify(waiter.EventIn)
+}
+
+// newWatchLocked creates and adds a new watch to target.
+//
+// Precondition: i.mu must be locked. ws must be the watch set for target d.
+func (i *Inotify) newWatchLocked(d *Dentry, ws *Watches, mask uint32) *Watch {
+ w := &Watch{
+ owner: i,
+ wd: i.nextWatchIDLocked(),
+ target: d,
+ mask: mask,
+ }
+
+ // Hold the watch in this inotify instance as well as the watch set on the
+ // target.
+ i.watches[w.wd] = w
+ ws.Add(w)
+ return w
+}
+
+// nextWatchIDLocked allocates and returns a new watch descriptor.
+//
+// Precondition: i.mu must be locked.
+func (i *Inotify) nextWatchIDLocked() int32 {
+ i.nextWatchMinusOne++
+ return i.nextWatchMinusOne
+}
+
+// AddWatch constructs a new inotify watch and adds it to the target. It
+// returns the watch descriptor returned by inotify_add_watch(2).
+//
+// The caller must hold a reference on target.
+func (i *Inotify) AddWatch(target *Dentry, mask uint32) (int32, error) {
+ // Note: Locking this inotify instance protects the result returned by
+ // Lookup() below. With the lock held, we know for sure the lookup result
+ // won't become stale because it's impossible for *this* instance to
+ // add/remove watches on target.
+ i.mu.Lock()
+ defer i.mu.Unlock()
+
+ ws := target.Watches()
+ if ws == nil {
+ // While Linux supports inotify watches on all filesystem types, watches on
+ // filesystems like kernfs are not generally useful, so we do not support them.
+ return 0, syserror.EPERM
+ }
+ // Does the target already have a watch from this inotify instance?
+ if existing := ws.Lookup(i.id); existing != nil {
+ newmask := mask
+ if mask&linux.IN_MASK_ADD != 0 {
+ // "Add (OR) events to watch mask for this pathname if it already
+ // exists (instead of replacing mask)." -- inotify(7)
+ newmask |= atomic.LoadUint32(&existing.mask)
+ }
+ atomic.StoreUint32(&existing.mask, newmask)
+ return existing.wd, nil
+ }
+
+ // No existing watch, create a new watch.
+ w := i.newWatchLocked(target, ws, mask)
+ return w.wd, nil
+}
+
+// RmWatch looks up an inotify watch for the given 'wd' and configures the
+// target to stop sending events to this inotify instance.
+func (i *Inotify) RmWatch(wd int32) error {
+ i.mu.Lock()
+
+ // Find the watch we were asked to remove.
+ w, ok := i.watches[wd]
+ if !ok {
+ i.mu.Unlock()
+ return syserror.EINVAL
+ }
+
+ // Remove the watch from this instance.
+ delete(i.watches, wd)
+
+ // Remove the watch from the watch target.
+ ws := w.target.Watches()
+ // AddWatch ensures that w.target has a non-nil watch set.
+ if ws == nil {
+ panic("Watched dentry cannot have nil watch set")
+ }
+ ws.Remove(w.OwnerID())
+ remaining := ws.Size()
+ i.mu.Unlock()
+
+ if remaining == 0 {
+ w.target.OnZeroWatches()
+ }
+
+ // Generate the event for the removal.
+ i.queueEvent(newEvent(wd, "", linux.IN_IGNORED, 0))
+
+ return nil
+}
+
+// Watches is the collection of all inotify watches on a single file.
+//
+// +stateify savable
+type Watches struct {
+ // mu protects the fields below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // ws is the map of active watches in this collection, keyed by the inotify
+ // instance id of the owner.
+ ws map[uint64]*Watch
+}
+
+// Size returns the number of watches held by w.
+func (w *Watches) Size() int {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return len(w.ws)
+}
+
+// Lookup returns the watch owned by an inotify instance with the given id.
+// Returns nil if no such watch exists.
+//
+// Precondition: the inotify instance with the given id must be locked to
+// prevent the returned watch from being concurrently modified or replaced in
+// Inotify.watches.
+func (w *Watches) Lookup(id uint64) *Watch {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.ws[id]
+}
+
+// Add adds watch into this set of watches.
+//
+// Precondition: the inotify instance with the given id must be locked.
+func (w *Watches) Add(watch *Watch) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ owner := watch.OwnerID()
+ // Sanity check, we should never have two watches for one owner on the
+ // same target.
+ if _, exists := w.ws[owner]; exists {
+ panic(fmt.Sprintf("Watch collision with ID %+v", owner))
+ }
+ if w.ws == nil {
+ w.ws = make(map[uint64]*Watch)
+ }
+ w.ws[owner] = watch
+}
+
+// Remove removes a watch with the given id from this set of watches and
+// releases it. The caller is responsible for generating any watch removal
+// event, as appropriate. Remove is a no-op if no watch with the given id
+// exists in this collection (see the one-shot handling in Watches.Notify).
+//
+// Precondition: the inotify instance with the given id must be locked.
+func (w *Watches) Remove(id uint64) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ if w.ws == nil {
+ // This watch set is being destroyed. The thread executing the
+ // destructor is already in the process of deleting all our watches. We
+ // got here with no references on the target because we raced with the
+ // destructor notifying all the watch owners of destruction. See the
+ // comment in Watches.HandleDeletion for why this race exists.
+ return
+ }
+
+ // It is possible for w.Remove() to be called for the same watch multiple
+ // times. See the treatment of one-shot watches in Watches.Notify().
+ if _, ok := w.ws[id]; ok {
+ delete(w.ws, id)
+ }
+}
+
+// Notify queues a new event with watches in this set. Watches with
+// IN_EXCL_UNLINK are skipped if the event is coming from a child that has been
+// unlinked.
+func (w *Watches) Notify(name string, events, cookie uint32, et EventType, unlinked bool) {
+ var hasExpired bool
+ w.mu.RLock()
+ for _, watch := range w.ws {
+ if unlinked && watch.ExcludeUnlinked() && et == PathEvent {
+ continue
+ }
+ if watch.Notify(name, events, cookie) {
+ hasExpired = true
+ }
+ }
+ w.mu.RUnlock()
+
+ if hasExpired {
+ w.cleanupExpiredWatches()
+ }
+}
+
+// This function is relatively expensive and should only be called when there
+// are expired watches.
+func (w *Watches) cleanupExpiredWatches() {
+ // Because of lock ordering, we cannot acquire Inotify.mu for each watch
+ // owner while holding w.mu. As a result, store expired watches locally
+ // before removing.
+ var toRemove []*Watch
+ w.mu.RLock()
+ for _, watch := range w.ws {
+ if atomic.LoadInt32(&watch.expired) == 1 {
+ toRemove = append(toRemove, watch)
+ }
+ }
+ w.mu.RUnlock()
+ for _, watch := range toRemove {
+ watch.owner.RmWatch(watch.wd)
+ }
+}
+
+// HandleDeletion is called when the watch target is destroyed. It clears the
+// watch set, detaches watches from the inotify instances they belong to, and
+// generates the appropriate events.
+func (w *Watches) HandleDeletion() {
+ w.Notify("", linux.IN_DELETE_SELF, 0, InodeEvent, true /* unlinked */)
+
+ // As in Watches.Notify, we can't hold w.mu while acquiring Inotify.mu for
+ // the owner of each watch being deleted. Instead, under w.mu, move the
+ // watches map into a local variable and set w.ws to nil so we can iterate
+ // over it with the assurance that there will be no concurrent accesses.
+ var ws map[uint64]*Watch
+ w.mu.Lock()
+ ws = w.ws
+ w.ws = nil
+ w.mu.Unlock()
+
+ // Remove each watch from its owner's watch set, and generate a corresponding
+ // watch removal event.
+ for _, watch := range ws {
+ i := watch.owner
+ i.mu.Lock()
+ _, found := i.watches[watch.wd]
+ delete(i.watches, watch.wd)
+
+ // Release mutex before notifying waiters because we don't control what
+ // they can do.
+ i.mu.Unlock()
+
+ // If the watch was not found, it was removed from the inotify instance before
+ // we could get to it, in which case we should not generate an event.
+ if found {
+ i.queueEvent(newEvent(watch.wd, "", linux.IN_IGNORED, 0))
+ }
+ }
+}
+
+// Watch represents a particular inotify watch created by inotify_add_watch(2).
+//
+// +stateify savable
+type Watch struct {
+ // Inotify instance which owns this watch.
+ //
+ // This field is immutable after creation.
+ owner *Inotify
+
+ // Descriptor for this watch. This is unique within an inotify instance.
+ //
+ // This field is immutable after creation.
+ wd int32
+
+ // target is a dentry representing the watch target. Its watch set contains this watch.
+ //
+ // This field is immutable after creation.
+ target *Dentry
+
+ // Events being monitored via this watch. Must be accessed with atomic
+ // memory operations.
+ mask uint32
+
+ // expired is set to 1 to indicate that this watch is a one-shot that has
+ // already sent a notification and therefore can be removed. Must be accessed
+ // with atomic memory operations.
+ expired int32
+}
+
+// OwnerID returns the id of the inotify instance that owns this watch.
+func (w *Watch) OwnerID() uint64 {
+ return w.owner.id
+}
+
+// ExcludeUnlinked indicates whether the watch excludes (i.e., stops being
+// notified of) events originating from a path that has been unlinked.
+//
+// For example, if "foo/bar" is opened and then unlinked, operations on the
+// open fd may be ignored by watches on "foo" and "foo/bar" with IN_EXCL_UNLINK.
+func (w *Watch) ExcludeUnlinked() bool {
+ return atomic.LoadUint32(&w.mask)&linux.IN_EXCL_UNLINK != 0
+}
+
+// Notify queues a new event on this watch. It returns true if this is a
+// one-shot watch that successfully queued its event and should be removed.
+func (w *Watch) Notify(name string, events uint32, cookie uint32) bool {
+ if atomic.LoadInt32(&w.expired) == 1 {
+ // This is a one-shot watch that is already in the process of being
+ // removed. This may happen if a second event reaches the watch target
+ // before this watch has been removed.
+ return false
+ }
+
+ mask := atomic.LoadUint32(&w.mask)
+ if mask&events == 0 {
+ // We weren't watching for this event.
+ return false
+ }
+
+ // Event mask should include bits matched from the watch plus all control
+ // event bits.
+ unmaskableBits := ^uint32(0) &^ linux.IN_ALL_EVENTS
+ effectiveMask := unmaskableBits | mask
+ matchedEvents := effectiveMask & events
+ w.owner.queueEvent(newEvent(w.wd, name, matchedEvents, cookie))
+ if mask&linux.IN_ONESHOT != 0 {
+ atomic.StoreInt32(&w.expired, 1)
+ return true
+ }
+ return false
+}
+
+// Event represents a struct inotify_event from linux.
+//
+// +stateify savable
+type Event struct {
+ eventEntry
+
+ wd int32
+ mask uint32
+ cookie uint32
+
+ // len is computed from the name field and is set automatically by
+ // Event.setName. It should be 0 when no name is set; otherwise it is the
+ // length of the name slice.
+ len uint32
+
+ // The name field has special padding requirements and should only be set by
+ // calling Event.setName.
+ name []byte
+}
+
+func newEvent(wd int32, name string, events, cookie uint32) *Event {
+ e := &Event{
+ wd: wd,
+ mask: events,
+ cookie: cookie,
+ }
+ if name != "" {
+ e.setName(name)
+ }
+ return e
+}
+
+// paddedBytes converts a Go string to a null-terminated C string, padded with
+// null bytes to a total size of 'l'. 'l' must be large enough for all the
+// bytes in 's' plus at least one null byte.
+func paddedBytes(s string, l uint32) []byte {
+ if l < uint32(len(s)+1) {
+ panic("Converting string to byte array results in truncation, this can lead to buffer-overflow due to the missing null-byte!")
+ }
+ b := make([]byte, l)
+ copy(b, s)
+
+ // b was zero-value initialized during make(), so the rest of the slice is
+ // already filled with null bytes.
+
+ return b
+}
+
+// setName sets the optional name for this event.
+func (e *Event) setName(name string) {
+ // We need to pad the name such that the entire event length ends up a
+ // multiple of inotifyEventBaseSize.
+ unpaddedLen := len(name) + 1
+ // Round up to nearest multiple of inotifyEventBaseSize.
+ e.len = uint32((unpaddedLen + inotifyEventBaseSize - 1) & ^(inotifyEventBaseSize - 1))
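+ // E.g., a name of length 3 ("foo") gives unpaddedLen = 4, so e.len rounds up to 16.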
+ // Make sure we haven't overflowed and wrapped around when rounding.
+ if unpaddedLen > int(e.len) {
+ panic("Overflow when rounding inotify event size, the 'name' field was too big.")
+ }
+ e.name = paddedBytes(name, e.len)
+}
+
+func (e *Event) sizeOf() int {
+ s := inotifyEventBaseSize + int(e.len)
+ if s < inotifyEventBaseSize {
+ panic("Overflowed event size")
+ }
+ return s
+}
+
+// CopyTo serializes this event to dst. buf is used as a scratch buffer to
+// construct the output. We use a buffer allocated ahead of time for
+// performance. buf must be at least inotifyEventBaseSize bytes.
+func (e *Event) CopyTo(ctx context.Context, buf []byte, dst usermem.IOSequence) (int64, error) {
+ usermem.ByteOrder.PutUint32(buf[0:], uint32(e.wd))
+ usermem.ByteOrder.PutUint32(buf[4:], e.mask)
+ usermem.ByteOrder.PutUint32(buf[8:], e.cookie)
+ usermem.ByteOrder.PutUint32(buf[12:], e.len)
+
+ writeLen := 0
+
+ n, err := dst.CopyOut(ctx, buf)
+ if err != nil {
+ return 0, err
+ }
+ writeLen += n
+ dst = dst.DropFirst(n)
+
+ if e.len > 0 {
+ n, err = dst.CopyOut(ctx, e.name)
+ if err != nil {
+ return 0, err
+ }
+ writeLen += n
+ }
+
+ // Sanity check.
+ if writeLen != e.sizeOf() {
+ panic(fmt.Sprintf("Serialized unexpected amount of data for an event, expected %d, wrote %d.", e.sizeOf(), writeLen))
+ }
+
+ return int64(writeLen), nil
+}
+
+func (e *Event) equals(other *Event) bool {
+ return e.wd == other.wd &&
+ e.mask == other.mask &&
+ e.cookie == other.cookie &&
+ e.len == other.len &&
+ bytes.Equal(e.name, other.name)
+}
+
+// InotifyEventFromStatMask generates the appropriate events for an operation
+// that sets the stat fields specified in mask.
+func InotifyEventFromStatMask(mask uint32) uint32 {
+ var ev uint32
+ if mask&(linux.STATX_UID|linux.STATX_GID|linux.STATX_MODE) != 0 {
+ ev |= linux.IN_ATTRIB
+ }
+ if mask&linux.STATX_SIZE != 0 {
+ ev |= linux.IN_MODIFY
+ }
+
+ if (mask & (linux.STATX_ATIME | linux.STATX_MTIME)) == (linux.STATX_ATIME | linux.STATX_MTIME) {
+ // Both times indicates a utime(s) call.
+ ev |= linux.IN_ATTRIB
+ } else if mask&linux.STATX_ATIME != 0 {
+ ev |= linux.IN_ACCESS
+ } else if mask&linux.STATX_MTIME != 0 {
+ ev |= linux.IN_MODIFY
+ }
+ return ev
+}
+
+// InotifyRemoveChild sends the appropriate notifications to the watch sets of
+// the child being removed and its parent. Note that unlike most pairs of
+// parent/child notifications, the child is notified first in this case.
+func InotifyRemoveChild(self, parent *Watches, name string) {
+ if self != nil {
+ self.Notify("", linux.IN_ATTRIB, 0, InodeEvent, true /* unlinked */)
+ }
+ if parent != nil {
+ parent.Notify(name, linux.IN_DELETE, 0, InodeEvent, true /* unlinked */)
+ }
+}
+
+// InotifyRename sends the appropriate notifications to the watch sets of the
+// file being renamed and its old/new parents.
+func InotifyRename(ctx context.Context, renamed, oldParent, newParent *Watches, oldName, newName string, isDir bool) {
+ var dirEv uint32
+ if isDir {
+ dirEv = linux.IN_ISDIR
+ }
+ cookie := uniqueid.InotifyCookie(ctx)
+ if oldParent != nil {
+ oldParent.Notify(oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent, false /* unlinked */)
+ }
+ if newParent != nil {
+ newParent.Notify(newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent, false /* unlinked */)
+ }
+ // Somewhat surprisingly, self move events do not have a cookie.
+ if renamed != nil {
+ renamed.Notify("", linux.IN_MOVE_SELF, 0, InodeEvent, false /* unlinked */)
+ }
+}
diff --git a/pkg/sentry/vfs/lock/lock.go b/pkg/sentry/vfs/lock.go
index 724dfe743..6c7583a81 100644
--- a/pkg/sentry/vfs/lock/lock.go
+++ b/pkg/sentry/vfs/lock.go
@@ -17,9 +17,11 @@
//
// The actual implementations can be found in the lock package under
// sentry/fs/lock.
-package lock
+package vfs
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -56,7 +58,11 @@ func (fl *FileLocks) UnlockBSD(uid fslock.UniqueID) {
}
// LockPOSIX tries to acquire a POSIX-style lock on a file region.
-func (fl *FileLocks) LockPOSIX(uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error {
+func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ rng, err := computeRange(ctx, fd, start, length, whence)
+ if err != nil {
+ return err
+ }
if fl.posix.LockRegion(uid, t, rng, block) {
return nil
}
@@ -67,6 +73,37 @@ func (fl *FileLocks) LockPOSIX(uid fslock.UniqueID, t fslock.LockType, rng fsloc
//
// This operation is always successful, even if there did not exist a lock on
// the requested region held by uid in the first place.
-func (fl *FileLocks) UnlockPOSIX(uid fslock.UniqueID, rng fslock.LockRange) {
+func (fl *FileLocks) UnlockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ rng, err := computeRange(ctx, fd, start, length, whence)
+ if err != nil {
+ return err
+ }
fl.posix.UnlockRegion(uid, rng)
+ return nil
+}
+
+func computeRange(ctx context.Context, fd *FileDescription, start uint64, length uint64, whence int16) (fslock.LockRange, error) {
+ var off int64
+ switch whence {
+ case linux.SEEK_SET:
+ off = 0
+ case linux.SEEK_CUR:
+ // Note that Linux does not hold any mutexes while retrieving the file
+ // offset, see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk.
+ curOff, err := fd.Seek(ctx, 0, linux.SEEK_CUR)
+ if err != nil {
+ return fslock.LockRange{}, err
+ }
+ off = curOff
+ case linux.SEEK_END:
+ stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_SIZE})
+ if err != nil {
+ return fslock.LockRange{}, err
+ }
+ off = int64(stat.Size)
+ default:
+ return fslock.LockRange{}, syserror.EINVAL
+ }
+
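+ // E.g., whence=SEEK_SET, start=100, and length=0 describe a lock from byte
+ // 100 through EOF, per fcntl(2).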
+ return fslock.ComputeRange(int64(start), int64(length), off)
}
diff --git a/pkg/sentry/vfs/lock/BUILD b/pkg/sentry/vfs/lock/BUILD
deleted file mode 100644
index d9ab063b7..000000000
--- a/pkg/sentry/vfs/lock/BUILD
+++ /dev/null
@@ -1,13 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "lock",
- srcs = ["lock.go"],
- visibility = ["//pkg/sentry:internal"],
- deps = [
- "//pkg/sentry/fs/lock",
- "//pkg/syserror",
- ],
-)
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 02850b65c..32f901bd8 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -28,9 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-// lastMountID is used to allocate mount ids. Must be accessed atomically.
-var lastMountID uint64
-
// A Mount is a replacement of a Dentry (Mount.key.point) from one Filesystem
// (Mount.key.parent.fs) with a Dentry (Mount.root) from another Filesystem
// (Mount.fs), which applies to path resolution in the context of a particular
@@ -58,6 +55,10 @@ type Mount struct {
// ID is the immutable mount ID.
ID uint64
+ // Flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except
+ // for MS_RDONLY which is tracked in "writers". Immutable.
+ Flags MountFlags
+
// key is protected by VirtualFilesystem.mountMu and
// VirtualFilesystem.mounts.seq, and may be nil. References are held on
// key.parent and key.point if they are not nil.
@@ -84,10 +85,6 @@ type Mount struct {
// umounted is true. umounted is protected by VirtualFilesystem.mountMu.
umounted bool
- // flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except
- // for MS_RDONLY which is tracked in "writers".
- flags MountFlags
-
// The lower 63 bits of writers is the number of calls to
// Mount.CheckBeginWrite() that have not yet been paired with a call to
// Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect.
@@ -97,11 +94,11 @@ type Mount struct {
func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *MountNamespace, opts *MountOptions) *Mount {
mnt := &Mount{
- ID: atomic.AddUint64(&lastMountID, 1),
+ ID: atomic.AddUint64(&vfs.lastMountID, 1),
+ Flags: opts.Flags,
vfs: vfs,
fs: fs,
root: root,
- flags: opts.Flags,
ns: mntns,
refs: 1,
}
@@ -111,8 +108,17 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount
return mnt
}
-// A MountNamespace is a collection of Mounts.
-//
+// Options returns a copy of the MountOptions currently applicable to mnt.
+func (mnt *Mount) Options() MountOptions {
+ mnt.vfs.mountMu.Lock()
+ defer mnt.vfs.mountMu.Unlock()
+ return MountOptions{
+ Flags: mnt.Flags,
+ ReadOnly: mnt.readOnly(),
+ }
+}
+
+// A MountNamespace is a collection of Mounts.
+//
// MountNamespaces are reference-counted. Unless otherwise specified, all
// MountNamespace methods require that a reference is held.
//
@@ -120,6 +126,9 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount
//
// +stateify savable
type MountNamespace struct {
+ // Owner is the user namespace that owns this mount namespace.
+ Owner *auth.UserNamespace
+
// root is the MountNamespace's root mount. root is immutable.
root *Mount
@@ -148,7 +157,7 @@ type MountNamespace struct {
func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) {
rft := vfs.getFilesystemType(fsTypeName)
if rft == nil {
- ctx.Warningf("Unknown filesystem: %s", fsTypeName)
+ ctx.Warningf("Unknown filesystem type: %s", fsTypeName)
return nil, syserror.ENODEV
}
fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
@@ -156,6 +165,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
return nil, err
}
mntns := &MountNamespace{
+ Owner: creds.UserNamespace,
refs: 1,
mountpoints: make(map[*Dentry]uint32),
}
@@ -175,26 +185,34 @@ func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry,
return newMount(vfs, fs, root, nil /* mntns */, opts), nil
}
-// MountAt creates and mounts a Filesystem configured by the given arguments.
-func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error {
+// MountDisconnected creates a Filesystem configured by the given arguments,
+// then returns a Mount representing it. The new Mount is not associated with
+// any MountNamespace and is not connected to any other Mounts.
+func (vfs *VirtualFilesystem) MountDisconnected(ctx context.Context, creds *auth.Credentials, source string, fsTypeName string, opts *MountOptions) (*Mount, error) {
rft := vfs.getFilesystemType(fsTypeName)
if rft == nil {
- return syserror.ENODEV
+ return nil, syserror.ENODEV
}
if !opts.InternalMount && !rft.opts.AllowUserMount {
- return syserror.ENODEV
+ return nil, syserror.ENODEV
}
fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions)
if err != nil {
- return err
+ return nil, err
}
+ defer root.DecRef()
+ defer fs.DecRef()
+ return vfs.NewDisconnectedMount(fs, root, opts)
+}
+
+// ConnectMountAt connects mnt at the path represented by target.
+//
+// Preconditions: mnt must be disconnected.
+func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Credentials, mnt *Mount, target *PathOperation) error {
// We can't hold vfs.mountMu while calling FilesystemImpl methods due to
// lock ordering.
vd, err := vfs.GetDentryAt(ctx, creds, target, &GetDentryOptions{})
if err != nil {
- root.DecRef()
- fs.DecRef()
return err
}
vfs.mountMu.Lock()
@@ -204,8 +222,6 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia
vd.dentry.mu.Unlock()
vfs.mountMu.Unlock()
vd.DecRef()
- root.DecRef()
- fs.DecRef()
return syserror.ENOENT
}
// vd might have been mounted over between vfs.GetDentryAt() and
@@ -238,7 +254,6 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia
// point and the mount root are directories, or neither are, and returns
// ENOTDIR if this is not the case.
mntns := vd.mount.ns
- mnt := newMount(vfs, fs, root, mntns, opts)
vfs.mounts.seq.BeginWrite()
vfs.connectLocked(mnt, vd, mntns)
vfs.mounts.seq.EndWrite()
@@ -247,6 +262,19 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia
return nil
}
+// MountAt creates and mounts a Filesystem configured by the given arguments.
+func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error {
+ mnt, err := vfs.MountDisconnected(ctx, creds, source, fsTypeName, opts)
+ if err != nil {
+ return err
+ }
+ defer mnt.DecRef()
+ if err := vfs.ConnectMountAt(ctx, creds, mnt, target); err != nil {
+ return err
+ }
+ return nil
+}
+
// UmountAt removes the Mount at the given path.
func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error {
if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 {
@@ -254,6 +282,9 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti
}
// MNT_FORCE is currently unimplemented except for the permission check.
+ // Force unmounting specifically requires CAP_SYS_ADMIN in the root user
+ // namespace, and not in the owner user namespace for the target mount. See
+ // fs/namespace.c:SYSCALL_DEFINE2(umount, ...)
if opts.Flags&linux.MNT_FORCE != 0 && creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) {
return syserror.EPERM
}
@@ -369,14 +400,22 @@ func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecu
// references held by vd.
//
// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a
-// writer critical section. d.mu must be locked. mnt.parent() == nil.
+// writer critical section. d.mu must be locked. mnt.parent() == nil, i.e. mnt
+// must not already be connected.
func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) {
+ if checkInvariants {
+ if mnt.parent() != nil {
+ panic("VFS.connectLocked called on connected mount")
+ }
+ }
+ mnt.IncRef() // dropped by callers of umountRecursiveLocked
mnt.storeKey(vd)
if vd.mount.children == nil {
vd.mount.children = make(map[*Mount]struct{})
}
vd.mount.children[mnt] = struct{}{}
atomic.AddUint32(&vd.dentry.mounts, 1)
+ mnt.ns = mntns
mntns.mountpoints[vd.dentry]++
vfs.mounts.insertSeqed(mnt)
vfsmpmounts, ok := vfs.mountpoints[vd.dentry]
@@ -394,6 +433,11 @@ func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns
// writer critical section. mnt.parent() != nil.
func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry {
vd := mnt.loadKey()
+ if checkInvariants {
+ if vd.mount == nil {
+ panic("VFS.disconnectLocked called on disconnected mount")
+ }
+ }
mnt.storeKey(VirtualDentry{})
delete(vd.mount.children, mnt)
atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1
@@ -715,7 +759,10 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi
if mnt.readOnly() {
opts = "ro"
}
- if mnt.flags.NoExec {
+ if mnt.Flags.NoATime {
+ opts = ",noatime"
+ }
+ if mnt.Flags.NoExec {
opts += ",noexec"
}
@@ -800,11 +847,12 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo
if mnt.readOnly() {
opts = "ro"
}
- if mnt.flags.NoExec {
+ if mnt.Flags.NoATime {
+ opts = ",noatime"
+ }
+ if mnt.Flags.NoExec {
opts += ",noexec"
}
- // TODO(gvisor.dev/issue/1193): Add "noatime" if MS_NOATIME is
- // set.
fmt.Fprintf(buf, "%s ", opts)
// (7) Optional fields: zero or more fields of the form "tag[:value]".
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index bc7581698..70f850ca4 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.15
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 53d364c5c..dfc8573fd 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -75,6 +75,21 @@ type MknodOptions struct {
type MountFlags struct {
// NoExec is equivalent to MS_NOEXEC.
NoExec bool
+
+ // NoATime is equivalent to MS_NOATIME and indicates that the
+ // filesystem should not update access time in-place.
+ NoATime bool
+
+ // NoDev is equivalent to MS_NODEV and indicates that the
+ // filesystem should not allow access to devices (special files).
+ // TODO(gvisor.dev/issue/3186): respect this flag in non-FUSE
+ // filesystems.
+ NoDev bool
+
+ // NoSUID is equivalent to MS_NOSUID and indicates that the
+ // filesystem should not honor set-user-ID and set-group-ID bits or
+ // file capabilities when executing programs.
+ NoSUID bool
}
// MountOptions contains options to VirtualFilesystem.MountAt().
@@ -149,6 +164,12 @@ type SetStatOptions struct {
// == UTIME_OMIT (VFS users must unset the corresponding bit in Stat.Mask
// instead).
Stat linux.Statx
+
+ // NeedWritePerm indicates that write permission on the file is needed for
+ // this operation. This is needed for truncate(2) (note that ftruncate(2)
+ // does not require the same check--instead, it checks that the fd is
+ // writable).
+ NeedWritePerm bool
}
// BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt()
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index f9647f90e..33389c1df 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -94,6 +94,37 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, mode linu
return syserror.EACCES
}
+// MayLink determines whether creating a hard link to a file with the given
+// mode, kuid, and kgid is permitted.
+//
+// This corresponds to Linux's fs/namei.c:may_linkat.
+func MayLink(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ // Source inode owner can hardlink all they like; otherwise, it must be a
+ // safe source.
+ if CanActAsOwner(creds, kuid) {
+ return nil
+ }
+
+ // Only regular files can be hard linked.
+ if mode.FileType() != linux.S_IFREG {
+ return syserror.EPERM
+ }
+
+ // Setuid files should not get pinned to the filesystem.
+ if mode&linux.S_ISUID != 0 {
+ return syserror.EPERM
+ }
+
+ // Executable setgid files should not get pinned to the filesystem, but we
+ // don't support S_IXGRP anyway.
+
+ // Hardlinking to unreadable or unwritable sources is dangerous.
+ if err := GenericCheckPermissions(creds, MayRead|MayWrite, mode, kuid, kgid); err != nil {
+ return syserror.EPERM
+ }
+ return nil
+}
+
// AccessTypesForOpenFlags returns the access types required to open a file
// with the given OpenOptions.Flags. Note that this is NOT the same thing as
// the set of accesses permitted for the opened file:
@@ -152,7 +183,8 @@ func MayWriteFileWithOpenFlags(flags uint32) bool {
// CheckSetStat checks that creds has permission to change the metadata of a
// file with the given permissions, UID, and GID as specified by stat, subject
// to the rules of Linux's fs/attr.c:setattr_prepare().
-func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOptions, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ stat := &opts.Stat
if stat.Mask&linux.STATX_SIZE != 0 {
limit, err := CheckLimit(ctx, 0, int64(stat.Size))
if err != nil {
@@ -184,6 +216,11 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat
return syserror.EPERM
}
}
+ if opts.NeedWritePerm && !creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
+ if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil {
+ return err
+ }
+ }
if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 {
if !CanActAsOwner(creds, kuid) {
if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) ||
@@ -199,6 +236,20 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat
return nil
}
+// CheckDeleteSticky checks whether the sticky bit is set on a directory with
+// the given file mode, and if so, checks whether creds has permission to
+// remove a file owned by childKUID from a directory with the given mode.
+// CheckDeleteSticky is consistent with Linux's include/linux/fs.h:check_sticky().
+func CheckDeleteSticky(creds *auth.Credentials, parentMode linux.FileMode, childKUID auth.KUID) error {
+ if parentMode&linux.ModeSticky == 0 {
+ return nil
+ }
+ if CanActAsOwner(creds, childKUID) {
+ return nil
+ }
+ return syserror.EPERM
+}
+
// CanActAsOwner returns true if creds can act as the owner of a file with the
// given owning UID, consistent with Linux's
// fs/inode.c:inode_owner_or_capable().
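
As a usage illustration for the two new helpers, a hypothetical filesystem's
link and unlink paths might consult them like this (the inode and parent
fields are stand-ins for whatever the implementation stores):

    // Hard link: refuse unsafe sources before creating the link.
    if err := vfs.MayLink(creds, inode.mode, inode.kuid, inode.kgid); err != nil {
        return err
    }

    // Unlink in a sticky directory: only the child's owner, or creds that
    // CanActAsOwner (e.g. CAP_FOWNER), may remove the child.
    if err := vfs.CheckDeleteSticky(creds, parent.mode, child.kuid); err != nil {
        return err
    }
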
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 8d7f8f8af..522e27475 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -24,6 +24,9 @@
// Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry
// VirtualFilesystem.filesystemsMu
// EpollInstance.mu
+// Inotify.mu
+// Watches.mu
+// Inotify.evMu
// VirtualFilesystem.fsTypesMu
//
// Locking Dentry.mu in multiple Dentries requires holding
@@ -82,6 +85,10 @@ type VirtualFilesystem struct {
// mountpoints is analogous to Linux's mountpoint_hashtable.
mountpoints map[*Dentry]map[*Mount]struct{}
+ // lastMountID is the last allocated mount ID. lastMountID is accessed
+ // using atomic memory operations.
+ lastMountID uint64
+
// anonMount is a Mount, not included in mounts or mountpoints,
// representing an anonFilesystem. anonMount is used to back
// VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry().
@@ -116,6 +123,9 @@ type VirtualFilesystem struct {
// Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes.
func (vfs *VirtualFilesystem) Init() error {
+ if vfs.mountpoints != nil {
+ panic("VFS already initialized")
+ }
vfs.mountpoints = make(map[*Dentry]map[*Mount]struct{})
vfs.devices = make(map[devTuple]*registeredDevice)
vfs.anonBlockDevMinorNext = 1
@@ -401,7 +411,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
vfs.putResolvingPath(rp)
if opts.FileExec {
- if fd.Mount().flags.NoExec {
+ if fd.Mount().Flags.NoExec {
fd.DecRef()
return nil, syserror.EACCES
}
@@ -418,6 +428,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
}
}
+ fd.Dentry().InotifyWithParent(linux.IN_OPEN, 0, PathEvent)
return fd, nil
}
if !rp.handleError(err) {
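
lastMountID is documented as accessed with atomic memory operations; a
minimal sketch of the allocator that implies (the helper name is
hypothetical, since the actual allocation site is outside this hunk):

    // nextMountID hands out a unique mount ID without holding any lock.
    // Requires "sync/atomic".
    func (vfs *VirtualFilesystem) nextMountID() uint64 {
        return atomic.AddUint64(&vfs.lastMountID, 1)
    }
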
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 101497ed6..748273366 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -77,7 +77,10 @@ var DefaultOpts = Opts{
// trigger it.
const descheduleThreshold = 1 * time.Second
-var stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected")
+var (
+ stuckStartup = metric.MustCreateNewUint64Metric("/watchdog/stuck_startup_detected", true /* sync */, "Incremented once on startup watchdog timeout")
+ stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected")
+)
// Amount of time to wait before dumping the stack to the log again when the same task(s) remains stuck.
var stackDumpSameTaskPeriod = time.Minute
@@ -220,6 +223,9 @@ func (w *Watchdog) waitForStart() {
// We are fine.
return
}
+
+ stuckStartup.Increment()
+
var buf bytes.Buffer
buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout))
w.doAction(w.StartupTimeoutAction, false, &buf)
@@ -323,13 +329,13 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
func (w *Watchdog) reportStuckWatchdog() {
var buf bytes.Buffer
- buf.WriteString("Watchdog goroutine is stuck:")
+ buf.WriteString("Watchdog goroutine is stuck")
w.doAction(w.TaskTimeoutAction, false, &buf)
}
// doAction will take the given action. If the action is LogWarning, the stack
-// is not always dumpped to the log to prevent log flooding. "forceStack"
-// guarantees that the stack will be dumped regarless.
+// is not always dumped to the log to prevent log flooding. "forceStack"
+// guarantees that the stack will be dumped regardless.
func (w *Watchdog) doAction(action Action, forceStack bool, msg *bytes.Buffer) {
switch action {
case LogWarning:
diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD
new file mode 100644
index 000000000..f08599ebd
--- /dev/null
+++ b/pkg/shim/runsc/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "runsc",
+ srcs = [
+ "runsc.go",
+ "utils.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "@com_github_containerd_go_runc//:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ ],
+)
diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go
new file mode 100644
index 000000000..c5cf68efa
--- /dev/null
+++ b/pkg/shim/runsc/runsc.go
@@ -0,0 +1,514 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package runsc
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "syscall"
+ "time"
+
+ runc "github.com/containerd/go-runc"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+var Monitor runc.ProcessMonitor = runc.Monitor
+
+// DefaultCommand is the default command for Runsc.
+const DefaultCommand = "runsc"
+
+// Runsc is the client to the runsc cli.
+type Runsc struct {
+ Command string
+ PdeathSignal syscall.Signal
+ Setpgid bool
+ Root string
+ Log string
+ LogFormat runc.Format
+ Config map[string]string
+}
+
+// List returns all containers created inside the provided runsc root directory.
+func (r *Runsc) List(context context.Context) ([]*runc.Container, error) {
+ data, err := cmdOutput(r.command(context, "list", "--format=json"), false)
+ if err != nil {
+ return nil, err
+ }
+ var out []*runc.Container
+ if err := json.Unmarshal(data, &out); err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// State returns the state for the container provided by id.
+func (r *Runsc) State(context context.Context, id string) (*runc.Container, error) {
+ data, err := cmdOutput(r.command(context, "state", id), true)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %s", err, data)
+ }
+ var c runc.Container
+ if err := json.Unmarshal(data, &c); err != nil {
+ return nil, err
+ }
+ return &c, nil
+}
+
+type CreateOpts struct {
+ runc.IO
+ ConsoleSocket runc.ConsoleSocket
+
+ // PidFile is a path to where a pid file should be created.
+ PidFile string
+
+ // UserLog is a path to where runsc user log should be generated.
+ UserLog string
+}
+
+func (o *CreateOpts) args() (out []string, err error) {
+ if o.PidFile != "" {
+ abs, err := filepath.Abs(o.PidFile)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, "--pid-file", abs)
+ }
+ if o.ConsoleSocket != nil {
+ out = append(out, "--console-socket", o.ConsoleSocket.Path())
+ }
+ if o.UserLog != "" {
+ out = append(out, "--user-log", o.UserLog)
+ }
+ return out, nil
+}
+
+// Create creates a new container. The container's pid is written to the PidFile given in opts, if any.
+func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateOpts) error {
+ args := []string{"create", "--bundle", bundle}
+ if opts != nil {
+ oargs, err := opts.args()
+ if err != nil {
+ return err
+ }
+ args = append(args, oargs...)
+ }
+ cmd := r.command(context, append(args, id)...)
+ if opts != nil && opts.IO != nil {
+ opts.Set(cmd)
+ }
+
+ if cmd.Stdout == nil && cmd.Stderr == nil {
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ if opts != nil && opts.IO != nil {
+ if c, ok := opts.IO.(runc.StartCloser); ok {
+ if err := c.CloseAfterStart(); err != nil {
+ return err
+ }
+ }
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+ err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0])
+ }
+
+ return err
+}
+
+// Start will start an already created container.
+func (r *Runsc) Start(context context.Context, id string, cio runc.IO) error {
+ cmd := r.command(context, "start", id)
+ if cio != nil {
+ cio.Set(cmd)
+ }
+
+ if cmd.Stdout == nil && cmd.Stderr == nil {
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+ }
+
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ if cio != nil {
+ if c, ok := cio.(runc.StartCloser); ok {
+ if err := c.CloseAfterStart(); err != nil {
+ return err
+ }
+ }
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+ err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0])
+ }
+
+ return err
+}
+
+type waitResult struct {
+ ID string `json:"id"`
+ ExitStatus int `json:"exitStatus"`
+}
+
+// Wait will wait for a running container, and return its exit status.
+//
+// TODO(random-liu): Add exec process support.
+func (r *Runsc) Wait(context context.Context, id string) (int, error) {
+ data, err := cmdOutput(r.command(context, "wait", id), true)
+ if err != nil {
+ return 0, fmt.Errorf("%s: %s", err, data)
+ }
+ var res waitResult
+ if err := json.Unmarshal(data, &res); err != nil {
+ return 0, err
+ }
+ return res.ExitStatus, nil
+}
+
+type ExecOpts struct {
+ runc.IO
+ PidFile string
+ InternalPidFile string
+ ConsoleSocket runc.ConsoleSocket
+ Detach bool
+}
+
+func (o *ExecOpts) args() (out []string, err error) {
+ if o.ConsoleSocket != nil {
+ out = append(out, "--console-socket", o.ConsoleSocket.Path())
+ }
+ if o.Detach {
+ out = append(out, "--detach")
+ }
+ if o.PidFile != "" {
+ abs, err := filepath.Abs(o.PidFile)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, "--pid-file", abs)
+ }
+ if o.InternalPidFile != "" {
+ abs, err := filepath.Abs(o.InternalPidFile)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, "--internal-pid-file", abs)
+ }
+ return out, nil
+}
+
+// Exec executes an additional process inside the container based on a full OCI
+// Process specification.
+func (r *Runsc) Exec(context context.Context, id string, spec specs.Process, opts *ExecOpts) error {
+ f, err := ioutil.TempFile(os.Getenv("XDG_RUNTIME_DIR"), "runsc-process")
+ if err != nil {
+ return err
+ }
+ defer os.Remove(f.Name())
+ err = json.NewEncoder(f).Encode(spec)
+ f.Close()
+ if err != nil {
+ return err
+ }
+ args := []string{"exec", "--process", f.Name()}
+ if opts != nil {
+ oargs, err := opts.args()
+ if err != nil {
+ return err
+ }
+ args = append(args, oargs...)
+ }
+ cmd := r.command(context, append(args, id)...)
+ if opts != nil && opts.IO != nil {
+ opts.Set(cmd)
+ }
+ if cmd.Stdout == nil && cmd.Stderr == nil {
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ if opts != nil && opts.IO != nil {
+ if c, ok := opts.IO.(runc.StartCloser); ok {
+ if err := c.CloseAfterStart(); err != nil {
+ return err
+ }
+ }
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+ err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0])
+ }
+ return err
+}
+
+// Run runs the create, start, delete lifecycle of the container and returns
+// its exit status after it has exited.
+func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts) (int, error) {
+ args := []string{"run", "--bundle", bundle}
+ if opts != nil {
+ oargs, err := opts.args()
+ if err != nil {
+ return -1, err
+ }
+ args = append(args, oargs...)
+ }
+ cmd := r.command(context, append(args, id)...)
+ if opts != nil && opts.IO != nil {
+ opts.Set(cmd)
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return -1, err
+ }
+ return Monitor.Wait(cmd, ec)
+}
+
+type DeleteOpts struct {
+ Force bool
+}
+
+func (o *DeleteOpts) args() (out []string) {
+ if o.Force {
+ out = append(out, "--force")
+ }
+ return out
+}
+
+// Delete deletes the container.
+func (r *Runsc) Delete(context context.Context, id string, opts *DeleteOpts) error {
+ args := []string{"delete"}
+ if opts != nil {
+ args = append(args, opts.args()...)
+ }
+ return r.runOrError(r.command(context, append(args, id)...))
+}
+
+// KillOpts specifies options for killing a container and its processes.
+type KillOpts struct {
+ All bool
+ Pid int
+}
+
+func (o *KillOpts) args() (out []string) {
+ if o.All {
+ out = append(out, "--all")
+ }
+ if o.Pid != 0 {
+ out = append(out, "--pid", strconv.Itoa(o.Pid))
+ }
+ return out
+}
+
+// Kill sends the specified signal to the container.
+func (r *Runsc) Kill(context context.Context, id string, sig int, opts *KillOpts) error {
+ args := []string{
+ "kill",
+ }
+ if opts != nil {
+ args = append(args, opts.args()...)
+ }
+ return r.runOrError(r.command(context, append(args, id, strconv.Itoa(sig))...))
+}
+
+// Stats returns the stats for a container, such as CPU, memory, and I/O.
+func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) {
+ cmd := r.command(context, "events", "--stats", id)
+ rd, err := cmd.StdoutPipe()
+ if err != nil {
+ return nil, err
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ rd.Close()
+ Monitor.Wait(cmd, ec)
+ }()
+ var e runc.Event
+ if err := json.NewDecoder(rd).Decode(&e); err != nil {
+ return nil, err
+ }
+ return e.Stats, nil
+}
+
+// Events returns an event stream from runsc for a container with stats and OOM notifications.
+func (r *Runsc) Events(context context.Context, id string, interval time.Duration) (chan *runc.Event, error) {
+ cmd := r.command(context, "events", fmt.Sprintf("--interval=%ds", int(interval.Seconds())), id)
+ rd, err := cmd.StdoutPipe()
+ if err != nil {
+ return nil, err
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ rd.Close()
+ return nil, err
+ }
+ var (
+ dec = json.NewDecoder(rd)
+ c = make(chan *runc.Event, 128)
+ )
+ go func() {
+ defer func() {
+ close(c)
+ rd.Close()
+ Monitor.Wait(cmd, ec)
+ }()
+ for {
+ var e runc.Event
+ if err := dec.Decode(&e); err != nil {
+ if err == io.EOF {
+ return
+ }
+ e = runc.Event{
+ Type: "error",
+ Err: err,
+ }
+ }
+ c <- &e
+ }
+ }()
+ return c, nil
+}
+
+// Ps lists all the processes inside the container returning their pids.
+func (r *Runsc) Ps(context context.Context, id string) ([]int, error) {
+ data, err := cmdOutput(r.command(context, "ps", "--format", "json", id), true)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %s", err, data)
+ }
+ var pids []int
+ if err := json.Unmarshal(data, &pids); err != nil {
+ return nil, err
+ }
+ return pids, nil
+}
+
+// Top lists all the processes inside the container returning the full ps data.
+func (r *Runsc) Top(context context.Context, id string) (*runc.TopResults, error) {
+ data, err := cmdOutput(r.command(context, "ps", "--format", "table", id), true)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %s", err, data)
+ }
+
+ topResults, err := runc.ParsePSOutput(data)
+ if err != nil {
+ return nil, fmt.Errorf("%s: ", err)
+ }
+ return topResults, nil
+}
+
+func (r *Runsc) args() []string {
+ var args []string
+ if r.Root != "" {
+ args = append(args, fmt.Sprintf("--root=%s", r.Root))
+ }
+ if r.Log != "" {
+ args = append(args, fmt.Sprintf("--log=%s", r.Log))
+ }
+ if r.LogFormat != "" {
+ args = append(args, fmt.Sprintf("--log-format=%s", r.LogFormat))
+ }
+ for k, v := range r.Config {
+ args = append(args, fmt.Sprintf("--%s=%s", k, v))
+ }
+ return args
+}
+
+// runOrError will run the provided command.
+//
+// If an error is encountered and neither Stdout nor Stderr was set, the error
+// will be returned in the format of <error>: <stderr>.
+func (r *Runsc) runOrError(cmd *exec.Cmd) error {
+ if cmd.Stdout != nil || cmd.Stderr != nil {
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+ err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0])
+ }
+ return err
+ }
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+}
+
+func (r *Runsc) command(context context.Context, args ...string) *exec.Cmd {
+ command := r.Command
+ if command == "" {
+ command = DefaultCommand
+ }
+ cmd := exec.CommandContext(context, command, append(r.args(), args...)...)
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setpgid: r.Setpgid,
+ }
+ if r.PdeathSignal != 0 {
+ cmd.SysProcAttr.Pdeathsig = r.PdeathSignal
+ }
+
+ return cmd
+}
+
+func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, error) {
+ b := getBuf()
+ defer putBuf(b)
+
+ cmd.Stdout = b
+ if combined {
+ cmd.Stderr = b
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return nil, err
+ }
+
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+ err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0])
+ }
+
+ return b.Bytes(), err
+}
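
Putting the client together, a hedged end-to-end sketch of the
create/start/wait lifecycle against an existing OCI bundle; the container
ID, paths, and state root below are illustrative:

    package main

    import (
        "context"
        "log"

        runc "github.com/containerd/go-runc"
        "gvisor.dev/gvisor/pkg/shim/runsc"
    )

    func main() {
        r := &runsc.Runsc{
            Root:      "/run/containerd/runsc/demo", // hypothetical state dir
            Log:       "/tmp/runsc-demo/log.json",
            LogFormat: runc.JSON,
        }
        ctx := context.Background()
        if err := r.Create(ctx, "demo", "/tmp/bundle", nil); err != nil {
            log.Fatalf("create: %v", err)
        }
        if err := r.Start(ctx, "demo", nil); err != nil {
            log.Fatalf("start: %v", err)
        }
        status, err := r.Wait(ctx, "demo")
        if err != nil {
            log.Fatalf("wait: %v", err)
        }
        log.Printf("container exited with status %d", status)
    }
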
diff --git a/pkg/shim/runsc/utils.go b/pkg/shim/runsc/utils.go
new file mode 100644
index 000000000..c514b3bc7
--- /dev/null
+++ b/pkg/shim/runsc/utils.go
@@ -0,0 +1,44 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package runsc
+
+import (
+ "bytes"
+ "strings"
+ "sync"
+)
+
+var bytesBufferPool = sync.Pool{
+ New: func() interface{} {
+ return bytes.NewBuffer(nil)
+ },
+}
+
+func getBuf() *bytes.Buffer {
+ return bytesBufferPool.Get().(*bytes.Buffer)
+}
+
+func putBuf(b *bytes.Buffer) {
+ b.Reset()
+ bytesBufferPool.Put(b)
+}
+
+// FormatLogPath parses the runsc config and fills in %ID% in the log path.
+func FormatLogPath(id string, config map[string]string) {
+ if path, ok := config["debug-log"]; ok {
+ config["debug-log"] = strings.Replace(path, "%ID%", id, -1)
+ }
+}
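
For example (container ID and path hypothetical):

    config := map[string]string{"debug-log": "/var/log/runsc/%ID%/debug.log"}
    runsc.FormatLogPath("abc123", config)
    // config["debug-log"] is now "/var/log/runsc/abc123/debug.log"
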
diff --git a/pkg/shim/v1/proc/BUILD b/pkg/shim/v1/proc/BUILD
new file mode 100644
index 000000000..4377306af
--- /dev/null
+++ b/pkg/shim/v1/proc/BUILD
@@ -0,0 +1,36 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "proc",
+ srcs = [
+ "deleted_state.go",
+ "exec.go",
+ "exec_state.go",
+ "init.go",
+ "init_state.go",
+ "io.go",
+ "process.go",
+ "types.go",
+ "utils.go",
+ ],
+ visibility = [
+ "//pkg/shim:__subpackages__",
+ "//shim:__subpackages__",
+ ],
+ deps = [
+ "//pkg/shim/runsc",
+ "@com_github_containerd_console//:go_default_library",
+ "@com_github_containerd_containerd//errdefs:go_default_library",
+ "@com_github_containerd_containerd//log:go_default_library",
+ "@com_github_containerd_containerd//mount:go_default_library",
+ "@com_github_containerd_containerd//pkg/process:go_default_library",
+ "@com_github_containerd_containerd//pkg/stdio:go_default_library",
+ "@com_github_containerd_fifo//:go_default_library",
+ "@com_github_containerd_go_runc//:go_default_library",
+ "@com_github_gogo_protobuf//types:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v1/proc/deleted_state.go b/pkg/shim/v1/proc/deleted_state.go
new file mode 100644
index 000000000..d9b970c4d
--- /dev/null
+++ b/pkg/shim/v1/proc/deleted_state.go
@@ -0,0 +1,49 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/pkg/process"
+)
+
+type deletedState struct{}
+
+func (*deletedState) Resize(ws console.WinSize) error {
+ return fmt.Errorf("cannot resize a deleted process.ss")
+}
+
+func (*deletedState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a deleted process.ss")
+}
+
+func (*deletedState) Delete(ctx context.Context) error {
+ return fmt.Errorf("cannot delete a deleted process.ss: %w", errdefs.ErrNotFound)
+}
+
+func (*deletedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return fmt.Errorf("cannot kill a deleted process.ss: %w", errdefs.ErrNotFound)
+}
+
+func (*deletedState) SetExited(status int) {}
+
+func (*deletedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return nil, fmt.Errorf("cannot exec in a deleted state")
+}
diff --git a/pkg/shim/v1/proc/exec.go b/pkg/shim/v1/proc/exec.go
new file mode 100644
index 000000000..1d1d90488
--- /dev/null
+++ b/pkg/shim/v1/proc/exec.go
@@ -0,0 +1,281 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/fifo"
+ runc "github.com/containerd/go-runc"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+)
+
+type execProcess struct {
+ wg sync.WaitGroup
+
+ execState execState
+
+ mu sync.Mutex
+ id string
+ console console.Console
+ io runc.IO
+ status int
+ exited time.Time
+ pid int
+ internalPid int
+ closers []io.Closer
+ stdin io.Closer
+ stdio stdio.Stdio
+ path string
+ spec specs.Process
+
+ parent *Init
+ waitBlock chan struct{}
+}
+
+func (e *execProcess) Wait() {
+ <-e.waitBlock
+}
+
+func (e *execProcess) ID() string {
+ return e.id
+}
+
+func (e *execProcess) Pid() int {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.pid
+}
+
+func (e *execProcess) ExitStatus() int {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.status
+}
+
+func (e *execProcess) ExitedAt() time.Time {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.exited
+}
+
+func (e *execProcess) SetExited(status int) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.execState.SetExited(status)
+}
+
+func (e *execProcess) setExited(status int) {
+ e.status = status
+ e.exited = time.Now()
+ e.parent.Platform.ShutdownConsole(context.Background(), e.console)
+ close(e.waitBlock)
+}
+
+func (e *execProcess) Delete(ctx context.Context) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Delete(ctx)
+}
+
+func (e *execProcess) delete(ctx context.Context) error {
+ e.wg.Wait()
+ if e.io != nil {
+ for _, c := range e.closers {
+ c.Close()
+ }
+ e.io.Close()
+ }
+ pidfile := filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id))
+ // silently ignore error
+ os.Remove(pidfile)
+ internalPidfile := filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id))
+ // silently ignore error
+ os.Remove(internalPidfile)
+ return nil
+}
+
+func (e *execProcess) Resize(ws console.WinSize) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Resize(ws)
+}
+
+func (e *execProcess) resize(ws console.WinSize) error {
+ if e.console == nil {
+ return nil
+ }
+ return e.console.Resize(ws)
+}
+
+func (e *execProcess) Kill(ctx context.Context, sig uint32, _ bool) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Kill(ctx, sig, false)
+}
+
+func (e *execProcess) kill(ctx context.Context, sig uint32, _ bool) error {
+ internalPid := e.internalPid
+ if internalPid != 0 {
+ if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &runsc.KillOpts{
+ Pid: internalPid,
+ }); err != nil {
+ // If this returns an error, consider the process to have
+ // already stopped.
+ //
+ // TODO: Fix after signal handling is fixed.
+ return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound)
+ }
+ }
+ return nil
+}
+
+func (e *execProcess) Stdin() io.Closer {
+ return e.stdin
+}
+
+func (e *execProcess) Stdio() stdio.Stdio {
+ return e.stdio
+}
+
+func (e *execProcess) Start(ctx context.Context) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Start(ctx)
+}
+
+func (e *execProcess) start(ctx context.Context) (err error) {
+ var (
+ socket *runc.Socket
+ pidfile = filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id))
+ internalPidfile = filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id))
+ )
+ if e.stdio.Terminal {
+ if socket, err = runc.NewTempConsoleSocket(); err != nil {
+ return fmt.Errorf("failed to create runc console socket: %w", err)
+ }
+ defer socket.Close()
+ } else if e.stdio.IsNull() {
+ if e.io, err = runc.NewNullIO(); err != nil {
+ return fmt.Errorf("creating new NULL IO: %w", err)
+ }
+ } else {
+ if e.io, err = runc.NewPipeIO(e.parent.IoUID, e.parent.IoGID, withConditionalIO(e.stdio)); err != nil {
+ return fmt.Errorf("failed to create runc io pipes: %w", err)
+ }
+ }
+ opts := &runsc.ExecOpts{
+ PidFile: pidfile,
+ InternalPidFile: internalPidfile,
+ IO: e.io,
+ Detach: true,
+ }
+ if socket != nil {
+ opts.ConsoleSocket = socket
+ }
+ eventCh := e.parent.Monitor.Subscribe()
+ defer func() {
+ // Unsubscribe if an error is returned.
+ if err != nil {
+ e.parent.Monitor.Unsubscribe(eventCh)
+ }
+ }()
+ if err := e.parent.runtime.Exec(ctx, e.parent.id, e.spec, opts); err != nil {
+ close(e.waitBlock)
+ return e.parent.runtimeError(err, "OCI runtime exec failed")
+ }
+ if e.stdio.Stdin != "" {
+ sc, err := fifo.OpenFifo(context.Background(), e.stdio.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return fmt.Errorf("failed to open stdin fifo %s: %w", e.stdio.Stdin, err)
+ }
+ e.closers = append(e.closers, sc)
+ e.stdin = sc
+ }
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+ if socket != nil {
+ console, err := socket.ReceiveMaster()
+ if err != nil {
+ return fmt.Errorf("failed to retrieve console master: %w", err)
+ }
+ if e.console, err = e.parent.Platform.CopyConsole(ctx, console, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil {
+ return fmt.Errorf("failed to start console copy: %w", err)
+ }
+ } else if !e.stdio.IsNull() {
+ if err := copyPipes(ctx, e.io, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil {
+ return fmt.Errorf("failed to start io pipe copy: %w", err)
+ }
+ }
+ pid, err := runc.ReadPidFile(opts.PidFile)
+ if err != nil {
+ return fmt.Errorf("failed to retrieve OCI runtime exec pid: %w", err)
+ }
+ e.pid = pid
+ internalPid, err := runc.ReadPidFile(opts.InternalPidFile)
+ if err != nil {
+ return fmt.Errorf("failed to retrieve OCI runtime exec internal pid: %w", err)
+ }
+ e.internalPid = internalPid
+ go func() {
+ defer e.parent.Monitor.Unsubscribe(eventCh)
+ for event := range eventCh {
+ if event.Pid == e.pid {
+ ExitCh <- Exit{
+ Timestamp: event.Timestamp,
+ ID: e.id,
+ Status: event.Status,
+ }
+ break
+ }
+ }
+ }()
+ return nil
+}
+
+func (e *execProcess) Status(ctx context.Context) (string, error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ // if we don't have a pid then the exec process has just been created
+ if e.pid == 0 {
+ return "created", nil
+ }
+ // if we have a pid and it can be signaled, the process is running
+ // TODO(random-liu): Use `runsc kill --pid`.
+ if err := unix.Kill(e.pid, 0); err == nil {
+ return "running", nil
+ }
+ // else if we have a pid but it can no longer be signaled, it has stopped
+ return "stopped", nil
+}
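
For reference, the flag set that ExecOpts.args() assembles for the start
path above; args() is unexported, so this sketch only runs inside the runsc
package, and the paths are hypothetical:

    opts := &ExecOpts{
        PidFile:         "/run/shim/demo.pid",
        InternalPidFile: "/run/shim/demo-internal.pid",
        Detach:          true,
    }
    args, _ := opts.args()
    // args == ["--detach",
    //          "--pid-file", "/run/shim/demo.pid",
    //          "--internal-pid-file", "/run/shim/demo-internal.pid"]
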
diff --git a/pkg/shim/v1/proc/exec_state.go b/pkg/shim/v1/proc/exec_state.go
new file mode 100644
index 000000000..4dcda8b44
--- /dev/null
+++ b/pkg/shim/v1/proc/exec_state.go
@@ -0,0 +1,154 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/containerd/console"
+)
+
+type execState interface {
+ Resize(console.WinSize) error
+ Start(context.Context) error
+ Delete(context.Context) error
+ Kill(context.Context, uint32, bool) error
+ SetExited(int)
+}
+
+type execCreatedState struct {
+ p *execProcess
+}
+
+func (s *execCreatedState) transition(name string) error {
+ switch name {
+ case "running":
+ s.p.execState = &execRunningState{p: s.p}
+ case "stopped":
+ s.p.execState = &execStoppedState{p: s.p}
+ case "deleted":
+ s.p.execState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *execCreatedState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *execCreatedState) Start(ctx context.Context) error {
+ if err := s.p.start(ctx); err != nil {
+ return err
+ }
+ return s.transition("running")
+}
+
+func (s *execCreatedState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *execCreatedState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+type execRunningState struct {
+ p *execProcess
+}
+
+func (s *execRunningState) transition(name string) error {
+ switch name {
+ case "stopped":
+ s.p.execState = &execStoppedState{p: s.p}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *execRunningState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *execRunningState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a running process")
+}
+
+func (s *execRunningState) Delete(ctx context.Context) error {
+ return fmt.Errorf("cannot delete a running process")
+}
+
+func (s *execRunningState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *execRunningState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+type execStoppedState struct {
+ p *execProcess
+}
+
+func (s *execStoppedState) transition(name string) error {
+ switch name {
+ case "deleted":
+ s.p.execState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *execStoppedState) Resize(ws console.WinSize) error {
+ return fmt.Errorf("cannot resize a stopped container")
+}
+
+func (s *execStoppedState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a stopped process")
+}
+
+func (s *execStoppedState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *execStoppedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *execStoppedState) SetExited(status int) {
+ // no op
+}
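
Taken together, the exec states enforce this transition table; any edge not
listed returns an "invalid state transition" error:

    // created -Start->     running
    // created -Delete->    deleted
    // created -SetExited-> stopped
    // running -SetExited-> stopped
    // stopped -Delete->    deleted
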
diff --git a/pkg/shim/v1/proc/init.go b/pkg/shim/v1/proc/init.go
new file mode 100644
index 000000000..dab3123d6
--- /dev/null
+++ b/pkg/shim/v1/proc/init.go
@@ -0,0 +1,460 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "path/filepath"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/containerd/mount"
+ "github.com/containerd/containerd/pkg/process"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/fifo"
+ runc "github.com/containerd/go-runc"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+)
+
+// InitPidFile is the name of the file that contains the init pid.
+const InitPidFile = "init.pid"
+
+// Init represents an initial process for a container.
+type Init struct {
+ wg sync.WaitGroup
+ initState initState
+
+ // mu is used to ensure that `Start()` and `Exited()` calls return in
+ // the right order when invoked in separate goroutines. This is the
+ // case within the shim implementation as it makes use of the reaper
+ // interface.
+ mu sync.Mutex
+
+ waitBlock chan struct{}
+
+ WorkDir string
+
+ id string
+ Bundle string
+ console console.Console
+ Platform stdio.Platform
+ io runc.IO
+ runtime *runsc.Runsc
+ status int
+ exited time.Time
+ pid int
+ closers []io.Closer
+ stdin io.Closer
+ stdio stdio.Stdio
+ Rootfs string
+ IoUID int
+ IoGID int
+ Sandbox bool
+ UserLog string
+ Monitor ProcessMonitor
+}
+
+// NewRunsc returns a new runsc instance for a process.
+func NewRunsc(root, path, namespace, runtime string, config map[string]string) *runsc.Runsc {
+ if root == "" {
+ root = RunscRoot
+ }
+ return &runsc.Runsc{
+ Command: runtime,
+ PdeathSignal: syscall.SIGKILL,
+ Log: filepath.Join(path, "log.json"),
+ LogFormat: runc.JSON,
+ Root: filepath.Join(root, namespace),
+ Config: config,
+ }
+}
+
+// New returns a new init process.
+func New(id string, runtime *runsc.Runsc, stdio stdio.Stdio) *Init {
+ p := &Init{
+ id: id,
+ runtime: runtime,
+ stdio: stdio,
+ status: 0,
+ waitBlock: make(chan struct{}),
+ }
+ p.initState = &createdState{p: p}
+ return p
+}
+
+// Create the process with the provided config.
+func (p *Init) Create(ctx context.Context, r *CreateConfig) (err error) {
+ var socket *runc.Socket
+ if r.Terminal {
+ if socket, err = runc.NewTempConsoleSocket(); err != nil {
+ return fmt.Errorf("failed to create OCI runtime console socket: %w", err)
+ }
+ defer socket.Close()
+ } else if hasNoIO(r) {
+ if p.io, err = runc.NewNullIO(); err != nil {
+ return fmt.Errorf("creating new NULL IO: %w", err)
+ }
+ } else {
+ if p.io, err = runc.NewPipeIO(p.IoUID, p.IoGID, withConditionalIO(p.stdio)); err != nil {
+ return fmt.Errorf("failed to create OCI runtime io pipes: %w", err)
+ }
+ }
+ pidFile := filepath.Join(p.Bundle, InitPidFile)
+ opts := &runsc.CreateOpts{
+ PidFile: pidFile,
+ }
+ if socket != nil {
+ opts.ConsoleSocket = socket
+ }
+ if p.Sandbox {
+ opts.IO = p.io
+ // UserLog is only useful for sandbox.
+ opts.UserLog = p.UserLog
+ }
+ if err := p.runtime.Create(ctx, r.ID, r.Bundle, opts); err != nil {
+ return p.runtimeError(err, "OCI runtime create failed")
+ }
+ if r.Stdin != "" {
+ sc, err := fifo.OpenFifo(context.Background(), r.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return fmt.Errorf("failed to open stdin fifo %s: %w", r.Stdin, err)
+ }
+ p.stdin = sc
+ p.closers = append(p.closers, sc)
+ }
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+ if socket != nil {
+ console, err := socket.ReceiveMaster()
+ if err != nil {
+ return fmt.Errorf("failed to retrieve console master: %w", err)
+ }
+ console, err = p.Platform.CopyConsole(ctx, console, r.Stdin, r.Stdout, r.Stderr, &p.wg)
+ if err != nil {
+ return fmt.Errorf("failed to start console copy: %w", err)
+ }
+ p.console = console
+ } else if !hasNoIO(r) {
+ if err := copyPipes(ctx, p.io, r.Stdin, r.Stdout, r.Stderr, &p.wg); err != nil {
+ return fmt.Errorf("failed to start io pipe copy: %w", err)
+ }
+ }
+ pid, err := runc.ReadPidFile(pidFile)
+ if err != nil {
+ return fmt.Errorf("failed to retrieve OCI runtime container pid: %w", err)
+ }
+ p.pid = pid
+ return nil
+}
+
+// Wait waits for the process to exit.
+func (p *Init) Wait() {
+ <-p.waitBlock
+}
+
+// ID returns the ID of the process.
+func (p *Init) ID() string {
+ return p.id
+}
+
+// Pid returns the PID of the process.
+func (p *Init) Pid() int {
+ return p.pid
+}
+
+// ExitStatus returns the exit status of the process.
+func (p *Init) ExitStatus() int {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.status
+}
+
+// ExitedAt returns the time when the process exited.
+func (p *Init) ExitedAt() time.Time {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.exited
+}
+
+// Status returns the status of the process.
+func (p *Init) Status(ctx context.Context) (string, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ c, err := p.runtime.State(ctx, p.id)
+ if err != nil {
+ if strings.Contains(err.Error(), "does not exist") {
+ return "stopped", nil
+ }
+ return "", p.runtimeError(err, "OCI runtime state failed")
+ }
+ return p.convertStatus(c.Status), nil
+}
+
+// Start starts the init process.
+func (p *Init) Start(ctx context.Context) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Start(ctx)
+}
+
+func (p *Init) start(ctx context.Context) error {
+ var cio runc.IO
+ if !p.Sandbox {
+ cio = p.io
+ }
+ if err := p.runtime.Start(ctx, p.id, cio); err != nil {
+ return p.runtimeError(err, "OCI runtime start failed")
+ }
+ go func() {
+ status, err := p.runtime.Wait(context.Background(), p.id)
+ if err != nil {
+ log.G(ctx).WithError(err).Errorf("Failed to wait for container %q", p.id)
+ // TODO(random-liu): Handle runsc kill error.
+ if err := p.killAll(ctx); err != nil {
+ log.G(ctx).WithError(err).Errorf("Failed to kill container %q", p.id)
+ }
+ status = internalErrorCode
+ }
+ ExitCh <- Exit{
+ Timestamp: time.Now(),
+ ID: p.id,
+ Status: status,
+ }
+ }()
+ return nil
+}
+
+// SetExited sets the exit status of the init process.
+func (p *Init) SetExited(status int) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.initState.SetExited(status)
+}
+
+func (p *Init) setExited(status int) {
+ p.exited = time.Now()
+ p.status = status
+ p.Platform.ShutdownConsole(context.Background(), p.console)
+ close(p.waitBlock)
+}
+
+// Delete deletes the init process.
+func (p *Init) Delete(ctx context.Context) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Delete(ctx)
+}
+
+func (p *Init) delete(ctx context.Context) error {
+ p.killAll(ctx)
+ p.wg.Wait()
+ err := p.runtime.Delete(ctx, p.id, nil)
+ // ignore errors if a runtime has already deleted the process
+ // but we still hold metadata and pipes
+ //
+ // this is common during a checkpoint: runc deletes the container state
+ // after a checkpoint and the container will no longer exist within runc
+ if err != nil {
+ if strings.Contains(err.Error(), "does not exist") {
+ err = nil
+ } else {
+ err = p.runtimeError(err, "failed to delete task")
+ }
+ }
+ if p.io != nil {
+ for _, c := range p.closers {
+ c.Close()
+ }
+ p.io.Close()
+ }
+ if err2 := mount.UnmountAll(p.Rootfs, 0); err2 != nil {
+ log.G(ctx).WithError(err2).Warn("failed to cleanup rootfs mount")
+ if err == nil {
+ err = fmt.Errorf("failed rootfs umount: %w", err2)
+ }
+ }
+ return err
+}
+
+// Resize resizes the init processes console.
+func (p *Init) Resize(ws console.WinSize) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.console == nil {
+ return nil
+ }
+ return p.console.Resize(ws)
+}
+
+func (p *Init) resize(ws console.WinSize) error {
+ if p.console == nil {
+ return nil
+ }
+ return p.console.Resize(ws)
+}
+
+// Kill kills the init process.
+func (p *Init) Kill(ctx context.Context, signal uint32, all bool) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Kill(ctx, signal, all)
+}
+
+func (p *Init) kill(context context.Context, signal uint32, all bool) error {
+ var (
+ killErr error
+ backoff = 100 * time.Millisecond
+ )
+ timeout := 1 * time.Second
+ for start := time.Now(); time.Since(start) < timeout; {
+ c, err := p.runtime.State(context, p.id)
+ if err != nil {
+ if strings.Contains(err.Error(), "does not exist") {
+ return fmt.Errorf("no such process: %w", errdefs.ErrNotFound)
+ }
+ return p.runtimeError(err, "OCI runtime state failed")
+ }
+ // For runsc, signals are only delivered while the container is in the
+ // running state. If the container is not running, directly return
+ // "no such process".
+ if p.convertStatus(c.Status) == "stopped" {
+ return fmt.Errorf("no such process: %w", errdefs.ErrNotFound)
+ }
+ killErr = p.runtime.Kill(context, p.id, int(signal), &runsc.KillOpts{
+ All: all,
+ })
+ if killErr == nil {
+ return nil
+ }
+ time.Sleep(backoff)
+ backoff *= 2
+ }
+ return p.runtimeError(killErr, "kill timeout")
+}
+
+// KillAll kills all processes belonging to the init process.
+func (p *Init) KillAll(context context.Context) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.killAll(context)
+}
+
+func (p *Init) killAll(context context.Context) error {
+ p.runtime.Kill(context, p.id, int(syscall.SIGKILL), &runsc.KillOpts{
+ All: true,
+ })
+ // Ignore error handling for `runsc kill --all` for now.
+ // * If it doesn't return error, it is good;
+ // * If it returns error, consider the container has already stopped.
+ // TODO: Fix `runsc kill --all` error handling.
+ return nil
+}
+
+// Stdin returns the stdin of the process.
+func (p *Init) Stdin() io.Closer {
+ return p.stdin
+}
+
+// Runtime returns the OCI runtime configured for the init process.
+func (p *Init) Runtime() *runsc.Runsc {
+ return p.runtime
+}
+
+// Exec returns a new child process.
+func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Exec(ctx, path, r)
+}
+
+// exec returns a new exec'd process.
+func (p *Init) exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ // process exec request
+ var spec specs.Process
+ if err := json.Unmarshal(r.Spec.Value, &spec); err != nil {
+ return nil, err
+ }
+ spec.Terminal = r.Terminal
+
+ e := &execProcess{
+ id: r.ID,
+ path: path,
+ parent: p,
+ spec: spec,
+ stdio: stdio.Stdio{
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Terminal: r.Terminal,
+ },
+ waitBlock: make(chan struct{}),
+ }
+ e.execState = &execCreatedState{p: e}
+ return e, nil
+}
+
+// Stdio returns the stdio of the process.
+func (p *Init) Stdio() stdio.Stdio {
+ return p.stdio
+}
+
+func (p *Init) runtimeError(rErr error, msg string) error {
+ if rErr == nil {
+ return nil
+ }
+
+ rMsg, err := getLastRuntimeError(p.runtime)
+ switch {
+ case err != nil:
+ return fmt.Errorf("%s: %w (unable to retrieve OCI runtime error: %v)", msg, rErr, err)
+ case rMsg == "":
+ return fmt.Errorf("%s: %w", msg, rErr)
+ default:
+ return fmt.Errorf("%s: %s", msg, rMsg)
+ }
+}
+
+func (p *Init) convertStatus(status string) string {
+ if status == "created" && !p.Sandbox && p.status == internalErrorCode {
+ // Treat start failure state for non-root container as stopped.
+ return "stopped"
+ }
+ return status
+}
+
+func withConditionalIO(c stdio.Stdio) runc.IOOpt {
+ return func(o *runc.IOOption) {
+ o.OpenStdin = c.Stdin != ""
+ o.OpenStdout = c.Stdout != ""
+ o.OpenStderr = c.Stderr != ""
+ }
+}
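
A hedged sketch of how a shim service might wire these constructors
together for a sandbox container; the ID, namespace, and paths are
illustrative, and Platform/Monitor are assumed to be provided by the
caller as they are elsewhere in the shim:

    rt := proc.NewRunsc("", "/run/shim/demo", "k8s.io", "runsc",
        map[string]string{"debug-log": "/tmp/runsc-%ID%.log"})
    p := proc.New("demo", rt, stdio.Stdio{Terminal: false})
    p.Bundle = "/tmp/bundle"
    p.Sandbox = true
    p.UserLog = "/tmp/runsc-demo/user.log"
    // p.Platform and p.Monitor must be set before Create/Start is called.
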
diff --git a/pkg/shim/v1/proc/init_state.go b/pkg/shim/v1/proc/init_state.go
new file mode 100644
index 000000000..9233ecc85
--- /dev/null
+++ b/pkg/shim/v1/proc/init_state.go
@@ -0,0 +1,182 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/pkg/process"
+)
+
+type initState interface {
+ Resize(console.WinSize) error
+ Start(context.Context) error
+ Delete(context.Context) error
+ Exec(context.Context, string, *ExecConfig) (process.Process, error)
+ Kill(context.Context, uint32, bool) error
+ SetExited(int)
+}
+
+type createdState struct {
+ p *Init
+}
+
+func (s *createdState) transition(name string) error {
+ switch name {
+ case "running":
+ s.p.initState = &runningState{p: s.p}
+ case "stopped":
+ s.p.initState = &stoppedState{p: s.p}
+ case "deleted":
+ s.p.initState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *createdState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *createdState) Start(ctx context.Context) error {
+ if err := s.p.start(ctx); err != nil {
+ // Containerd doesn't allow deleting container in created state.
+ // However, for gvisor, a non-root container in created state can
+ // only go to running state. If the container can't be started,
+ // it can only stay in created state, and never be deleted.
+ // To work around that, we treat non-root container in start failure
+ // state as stopped.
+ if !s.p.Sandbox {
+ s.p.io.Close()
+ s.p.setExited(internalErrorCode)
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+ }
+ return err
+ }
+ return s.transition("running")
+}
+
+func (s *createdState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *createdState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return s.p.exec(ctx, path, r)
+}
+
+type runningState struct {
+ p *Init
+}
+
+func (s *runningState) transition(name string) error {
+ switch name {
+ case "stopped":
+ s.p.initState = &stoppedState{p: s.p}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *runningState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *runningState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a running process.ss")
+}
+
+func (s *runningState) Delete(ctx context.Context) error {
+ return fmt.Errorf("cannot delete a running process.ss")
+}
+
+func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *runningState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return s.p.exec(ctx, path, r)
+}
+
+type stoppedState struct {
+ p *Init
+}
+
+func (s *stoppedState) transition(name string) error {
+ switch name {
+ case "deleted":
+ s.p.initState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *stoppedState) Resize(ws console.WinSize) error {
+ return fmt.Errorf("cannot resize a stopped container")
+}
+
+func (s *stoppedState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a stopped process.ss")
+}
+
+func (s *stoppedState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *stoppedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return errdefs.ToGRPCf(errdefs.ErrNotFound, "process %s not found", s.p.id)
+}
+
+func (s *stoppedState) SetExited(status int) {
+ // no op
+}
+
+func (s *stoppedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return nil, fmt.Errorf("cannot exec in a stopped state")
+}
diff --git a/pkg/shim/v1/proc/io.go b/pkg/shim/v1/proc/io.go
new file mode 100644
index 000000000..34d825fb7
--- /dev/null
+++ b/pkg/shim/v1/proc/io.go
@@ -0,0 +1,162 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/fifo"
+ runc "github.com/containerd/go-runc"
+)
+
+// TODO(random-liu): This file can be a util.
+
+var bufPool = sync.Pool{
+ New: func() interface{} {
+ buffer := make([]byte, 32<<10)
+ return &buffer
+ },
+}
+
+func copyPipes(ctx context.Context, rio runc.IO, stdin, stdout, stderr string, wg *sync.WaitGroup) error {
+ var sameFile *countingWriteCloser
+ for _, i := range []struct {
+ name string
+ dest func(wc io.WriteCloser, rc io.Closer)
+ }{
+ {
+ name: stdout,
+ dest: func(wc io.WriteCloser, rc io.Closer) {
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ if _, err := io.CopyBuffer(wc, rio.Stdout(), *p); err != nil {
+ log.G(ctx).Warn("error copying stdout")
+ }
+ wg.Done()
+ wc.Close()
+ if rc != nil {
+ rc.Close()
+ }
+ }()
+ },
+ }, {
+ name: stderr,
+ dest: func(wc io.WriteCloser, rc io.Closer) {
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ if _, err := io.CopyBuffer(wc, rio.Stderr(), *p); err != nil {
+ log.G(ctx).Warn("error copying stderr")
+ }
+ wg.Done()
+ wc.Close()
+ if rc != nil {
+ rc.Close()
+ }
+ }()
+ },
+ },
+ } {
+ ok, err := isFifo(i.name)
+ if err != nil {
+ return err
+ }
+ var (
+ fw io.WriteCloser
+ fr io.Closer
+ )
+ if ok {
+ if fw, err = fifo.OpenFifo(ctx, i.name, syscall.O_WRONLY, 0); err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err)
+ }
+ if fr, err = fifo.OpenFifo(ctx, i.name, syscall.O_RDONLY, 0); err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err)
+ }
+ } else {
+ if sameFile != nil {
+ sameFile.count++
+ i.dest(sameFile, nil)
+ continue
+ }
+ if fw, err = os.OpenFile(i.name, syscall.O_WRONLY|syscall.O_APPEND, 0); err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err)
+ }
+ if stdout == stderr {
+ sameFile = &countingWriteCloser{
+ WriteCloser: fw,
+ count: 1,
+ }
+ }
+ }
+ i.dest(fw, fr)
+ }
+ if stdin == "" {
+ return nil
+ }
+ f, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", stdin, err)
+ }
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+
+ io.CopyBuffer(rio.Stdin(), f, *p)
+ rio.Stdin().Close()
+ f.Close()
+ }()
+ return nil
+}
+
+// countingWriteCloser masks the underlying Close() until it has been invoked a certain number of times.
+type countingWriteCloser struct {
+ io.WriteCloser
+ count int64
+}
+
+func (c *countingWriteCloser) Close() error {
+ if atomic.AddInt64(&c.count, -1) > 0 {
+ return nil
+ }
+ return c.WriteCloser.Close()
+}
+
+// isFifo checks if a file is a fifo.
+//
+// If the file does not exist then it returns false.
+func isFifo(path string) (bool, error) {
+ stat, err := os.Stat(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+ return false, err
+ }
+ if stat.Mode()&os.ModeNamedPipe == os.ModeNamedPipe {
+ return true, nil
+ }
+ return false, nil
+}
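
To see why the counting wrapper matters: when stdout and stderr name the
same regular file, both copy goroutines close the same writer, and the
count defers the real Close until the last one. A minimal in-package
illustration (path hypothetical):

    f, err := os.OpenFile("/tmp/out.log", syscall.O_WRONLY|syscall.O_APPEND, 0)
    if err != nil {
        panic(err)
    }
    shared := &countingWriteCloser{WriteCloser: f, count: 2}
    shared.Close() // count drops to 1; the file stays open
    shared.Close() // count drops to 0; the file is actually closed
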
diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go
new file mode 100644
index 000000000..d462c3eef
--- /dev/null
+++ b/pkg/shim/v1/proc/process.go
@@ -0,0 +1,37 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+)
+
+// RunscRoot is the path to the root runsc state directory.
+const RunscRoot = "/run/containerd/runsc"
+
+func stateName(v interface{}) string {
+ switch v.(type) {
+ case *runningState, *execRunningState:
+ return "running"
+ case *createdState, *execCreatedState:
+ return "created"
+ case *deletedState:
+ return "deleted"
+ case *stoppedState:
+ return "stopped"
+ }
+ panic(fmt.Errorf("invalid state %v", v))
+}
diff --git a/pkg/shim/v1/proc/types.go b/pkg/shim/v1/proc/types.go
new file mode 100644
index 000000000..2b0df4663
--- /dev/null
+++ b/pkg/shim/v1/proc/types.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "time"
+
+ runc "github.com/containerd/go-runc"
+ "github.com/gogo/protobuf/types"
+)
+
+// Mount holds filesystem mount configuration.
+type Mount struct {
+ Type string
+ Source string
+ Target string
+ Options []string
+}
+
+// CreateConfig holds task creation configuration.
+type CreateConfig struct {
+ ID string
+ Bundle string
+ Runtime string
+ Rootfs []Mount
+ Terminal bool
+ Stdin string
+ Stdout string
+ Stderr string
+ Options *types.Any
+}
+
+// ExecConfig holds exec creation configuration.
+type ExecConfig struct {
+ ID string
+ Terminal bool
+ Stdin string
+ Stdout string
+ Stderr string
+ Spec *types.Any
+}
+
+// Exit is the type of exit events.
+type Exit struct {
+ Timestamp time.Time
+ ID string
+ Status int
+}
+
+// ProcessMonitor monitors process exit changes.
+type ProcessMonitor interface {
+	// Subscribe returns a channel that receives process exit events.
+	Subscribe() chan runc.Exit
+	// Unsubscribe stops delivery of exit events to the given channel.
+	Unsubscribe(c chan runc.Exit)
+}
diff --git a/pkg/shim/v1/proc/utils.go b/pkg/shim/v1/proc/utils.go
new file mode 100644
index 000000000..716de2f59
--- /dev/null
+++ b/pkg/shim/v1/proc/utils.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "encoding/json"
+ "io"
+ "os"
+ "strings"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+)
+
+const (
+ internalErrorCode = 128
+ bufferSize = 32
+)
+
+// ExitCh is the exit events channel for containers and exec processes
+// inside the sandbox.
+var ExitCh = make(chan Exit, bufferSize)
+
+// TODO(mlaventure): move to runc package?
+func getLastRuntimeError(r *runsc.Runsc) (string, error) {
+ if r.Log == "" {
+ return "", nil
+ }
+
+	f, err := os.OpenFile(r.Log, os.O_RDONLY, 0400)
+	if err != nil {
+		return "", err
+	}
+	defer f.Close()
+
+ var (
+ errMsg string
+ log struct {
+ Level string
+ Msg string
+ Time time.Time
+ }
+ )
+
+ dec := json.NewDecoder(f)
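+	// Scan the whole log, remembering the most recent error-level message.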
+ for err = nil; err == nil; {
+ if err = dec.Decode(&log); err != nil && err != io.EOF {
+ return "", err
+ }
+ if log.Level == "error" {
+ errMsg = strings.TrimSpace(log.Msg)
+ }
+ }
+
+ return errMsg, nil
+}
+
+func copyFile(to, from string) error {
+ ff, err := os.Open(from)
+ if err != nil {
+ return err
+ }
+ defer ff.Close()
+ tt, err := os.Create(to)
+ if err != nil {
+ return err
+ }
+ defer tt.Close()
+
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ _, err = io.CopyBuffer(tt, ff, *p)
+ return err
+}
+
+func hasNoIO(r *CreateConfig) bool {
+ return r.Stdin == "" && r.Stdout == "" && r.Stderr == ""
+}
diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD
new file mode 100644
index 000000000..05c595bc9
--- /dev/null
+++ b/pkg/shim/v1/shim/BUILD
@@ -0,0 +1,40 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "shim",
+ srcs = [
+ "api.go",
+ "platform.go",
+ "service.go",
+ ],
+ visibility = [
+ "//pkg/shim:__subpackages__",
+ "//shim:__subpackages__",
+ ],
+ deps = [
+ "//pkg/shim/runsc",
+ "//pkg/shim/v1/proc",
+ "//pkg/shim/v1/utils",
+ "@com_github_containerd_console//:go_default_library",
+ "@com_github_containerd_containerd//api/events:go_default_library",
+ "@com_github_containerd_containerd//api/types/task:go_default_library",
+ "@com_github_containerd_containerd//errdefs:go_default_library",
+ "@com_github_containerd_containerd//events:go_default_library",
+ "@com_github_containerd_containerd//log:go_default_library",
+ "@com_github_containerd_containerd//mount:go_default_library",
+ "@com_github_containerd_containerd//namespaces:go_default_library",
+ "@com_github_containerd_containerd//pkg/process:go_default_library",
+ "@com_github_containerd_containerd//pkg/stdio:go_default_library",
+ "@com_github_containerd_containerd//runtime:go_default_library",
+ "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library",
+ "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library",
+ "@com_github_containerd_containerd//sys/reaper:go_default_library",
+ "@com_github_containerd_fifo//:go_default_library",
+ "@com_github_containerd_typeurl//:go_default_library",
+ "@com_github_gogo_protobuf//types:go_default_library",
+ "@org_golang_google_grpc//codes:go_default_library",
+ "@org_golang_google_grpc//status:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v1/shim/api.go b/pkg/shim/v1/shim/api.go
new file mode 100644
index 000000000..5dd8ff172
--- /dev/null
+++ b/pkg/shim/v1/shim/api.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+ "github.com/containerd/containerd/api/events"
+)
+
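+// Task lifecycle events, aliased from the containerd events package so shim
+// code can reference them without importing it directly.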
+type TaskCreate = events.TaskCreate
+type TaskStart = events.TaskStart
+type TaskOOM = events.TaskOOM
+type TaskExit = events.TaskExit
+type TaskDelete = events.TaskDelete
+type TaskExecAdded = events.TaskExecAdded
+type TaskExecStarted = events.TaskExecStarted
diff --git a/pkg/shim/v1/shim/platform.go b/pkg/shim/v1/shim/platform.go
new file mode 100644
index 000000000..f590f80ef
--- /dev/null
+++ b/pkg/shim/v1/shim/platform.go
@@ -0,0 +1,106 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "sync"
+ "syscall"
+
+ "github.com/containerd/console"
+ "github.com/containerd/fifo"
+)
+
+type linuxPlatform struct {
+ epoller *console.Epoller
+}
+
+func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) {
+ if p.epoller == nil {
+ return nil, fmt.Errorf("uninitialized epoller")
+ }
+
+ epollConsole, err := p.epoller.Add(console)
+ if err != nil {
+ return nil, err
+ }
+
+ if stdin != "" {
+ in, err := fifo.OpenFifo(ctx, stdin, syscall.O_RDONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(epollConsole, in, *p)
+ }()
+ }
+
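+	// Open the stdout FIFO for writing and also hold a read end open so the
+	// writer does not fail if the external reader goes away early.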
+ outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(outw, epollConsole, *p)
+ epollConsole.Close()
+ outr.Close()
+ outw.Close()
+ wg.Done()
+ }()
+ return epollConsole, nil
+}
+
+func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error {
+ if p.epoller == nil {
+ return fmt.Errorf("uninitialized epoller")
+ }
+ epollConsole, ok := cons.(*console.EpollConsole)
+ if !ok {
+ return fmt.Errorf("expected EpollConsole, got %#v", cons)
+ }
+ return epollConsole.Shutdown(p.epoller.CloseConsole)
+}
+
+func (p *linuxPlatform) Close() error {
+ return p.epoller.Close()
+}
+
+// initPlatform initializes a single epoll fd to manage consoles. It must only
+// be called once.
+func (s *Service) initPlatform() error {
+ if s.platform != nil {
+ return nil
+ }
+ epoller, err := console.NewEpoller()
+ if err != nil {
+ return fmt.Errorf("failed to initialize epoller: %w", err)
+ }
+ s.platform = &linuxPlatform{
+ epoller: epoller,
+ }
+ go epoller.Wait()
+ return nil
+}
diff --git a/pkg/shim/v1/shim/service.go b/pkg/shim/v1/shim/service.go
new file mode 100644
index 000000000..84a810cb2
--- /dev/null
+++ b/pkg/shim/v1/shim/service.go
@@ -0,0 +1,573 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "sync"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/api/types/task"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/events"
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/containerd/mount"
+ "github.com/containerd/containerd/namespaces"
+ "github.com/containerd/containerd/pkg/process"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/containerd/runtime"
+ "github.com/containerd/containerd/runtime/linux/runctypes"
+ shim "github.com/containerd/containerd/runtime/v1/shim/v1"
+ "github.com/containerd/containerd/sys/reaper"
+ "github.com/containerd/typeurl"
+ "github.com/gogo/protobuf/types"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+ "gvisor.dev/gvisor/pkg/shim/v1/proc"
+ "gvisor.dev/gvisor/pkg/shim/v1/utils"
+)
+
+var (
+ empty = &types.Empty{}
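+	// bufPool recycles 32KiB buffers used by the shim's I/O copy loops.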
+ bufPool = sync.Pool{
+ New: func() interface{} {
+ buffer := make([]byte, 32<<10)
+ return &buffer
+ },
+ }
+)
+
+// Config contains shim-specific configuration.
+type Config struct {
+ Path string
+ Namespace string
+ WorkDir string
+ RuntimeRoot string
+ RunscConfig map[string]string
+}
+
+// NewService returns a new shim service that can be used via GRPC.
+func NewService(config Config, publisher events.Publisher) (*Service, error) {
+ if config.Namespace == "" {
+ return nil, fmt.Errorf("shim namespace cannot be empty")
+ }
+ ctx := namespaces.WithNamespace(context.Background(), config.Namespace)
+ s := &Service{
+ config: config,
+ context: ctx,
+ processes: make(map[string]process.Process),
+ events: make(chan interface{}, 128),
+ ec: proc.ExitCh,
+ }
+ go s.processExits()
+ if err := s.initPlatform(); err != nil {
+		return nil, fmt.Errorf("failed to initialize platform behavior: %w", err)
+ }
+ go s.forward(publisher)
+ return s, nil
+}
+
+// Service is the shim implementation of a remote shim over GRPC.
+type Service struct {
+ mu sync.Mutex
+
+ config Config
+ context context.Context
+ processes map[string]process.Process
+ events chan interface{}
+ platform stdio.Platform
+ ec chan proc.Exit
+
+ // Filled by Create()
+ id string
+ bundle string
+}
+
+// Create creates a new initial process and container with the underlying OCI runtime.
+func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shim.CreateTaskResponse, err error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var mounts []proc.Mount
+ for _, m := range r.Rootfs {
+ mounts = append(mounts, proc.Mount{
+ Type: m.Type,
+ Source: m.Source,
+ Target: m.Target,
+ Options: m.Options,
+ })
+ }
+
+ rootfs := filepath.Join(r.Bundle, "rootfs")
+ if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) {
+ return nil, err
+ }
+
+ config := &proc.CreateConfig{
+ ID: r.ID,
+ Bundle: r.Bundle,
+ Runtime: r.Runtime,
+ Rootfs: mounts,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Options: r.Options,
+ }
+ defer func() {
+ if err != nil {
+ if err2 := mount.UnmountAll(rootfs, 0); err2 != nil {
+ log.G(ctx).WithError(err2).Warn("Failed to cleanup rootfs mount")
+ }
+ }
+ }()
+ for _, rm := range mounts {
+ m := &mount.Mount{
+ Type: rm.Type,
+ Source: rm.Source,
+ Options: rm.Options,
+ }
+ if err := m.Mount(rootfs); err != nil {
+ return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err)
+ }
+ }
+ process, err := newInit(
+ ctx,
+ s.config.Path,
+ s.config.WorkDir,
+ s.config.RuntimeRoot,
+ s.config.Namespace,
+ s.config.RunscConfig,
+ s.platform,
+ config,
+	)
+	if err != nil {
+		return nil, errdefs.ToGRPC(err)
+	}
+	if err := process.Create(ctx, config); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ // Save the main task id and bundle to the shim for additional
+ // requests.
+ s.id = r.ID
+ s.bundle = r.Bundle
+ pid := process.Pid()
+ s.processes[r.ID] = process
+ return &shim.CreateTaskResponse{
+ Pid: uint32(pid),
+ }, nil
+}
+
+// Start starts a process.
+func (s *Service) Start(ctx context.Context, r *shim.StartRequest) (*shim.StartResponse, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Start(ctx); err != nil {
+ return nil, err
+ }
+ return &shim.StartResponse{
+ ID: p.ID(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Delete deletes the initial process and container.
+func (s *Service) Delete(ctx context.Context, r *types.Empty) (*shim.DeleteResponse, error) {
+ p, err := s.getInitProcess()
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Delete(ctx); err != nil {
+ return nil, err
+ }
+ s.mu.Lock()
+ delete(s.processes, s.id)
+ s.mu.Unlock()
+ s.platform.Close()
+ return &shim.DeleteResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// DeleteProcess deletes an exec'd process.
+func (s *Service) DeleteProcess(ctx context.Context, r *shim.DeleteProcessRequest) (*shim.DeleteResponse, error) {
+ if r.ID == s.id {
+ return nil, status.Errorf(codes.InvalidArgument, "cannot delete init process with DeleteProcess")
+ }
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Delete(ctx); err != nil {
+ return nil, err
+ }
+ s.mu.Lock()
+ delete(s.processes, r.ID)
+ s.mu.Unlock()
+ return &shim.DeleteResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Exec spawns an additional process inside the container.
+func (s *Service) Exec(ctx context.Context, r *shim.ExecProcessRequest) (*types.Empty, error) {
+ s.mu.Lock()
+
+ if p := s.processes[r.ID]; p != nil {
+ s.mu.Unlock()
+ return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ID)
+ }
+
+ p := s.processes[s.id]
+ s.mu.Unlock()
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+
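+	// Create the exec process without holding the service lock, since
+	// runsc exec may take a while.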
+ process, err := p.(*proc.Init).Exec(ctx, s.config.Path, &proc.ExecConfig{
+ ID: r.ID,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Spec: r.Spec,
+ })
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ s.mu.Lock()
+ s.processes[r.ID] = process
+ s.mu.Unlock()
+ return empty, nil
+}
+
+// ResizePty resizes the terminal of a process.
+func (s *Service) ResizePty(ctx context.Context, r *shim.ResizePtyRequest) (*types.Empty, error) {
+ if r.ID == "" {
+ return nil, errdefs.ToGRPCf(errdefs.ErrInvalidArgument, "id not provided")
+ }
+ ws := console.WinSize{
+ Width: uint16(r.Width),
+ Height: uint16(r.Height),
+ }
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Resize(ws); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// State returns runtime state information for a process.
+func (s *Service) State(ctx context.Context, r *shim.StateRequest) (*shim.StateResponse, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ st, err := p.Status(ctx)
+ if err != nil {
+ return nil, err
+ }
+ status := task.StatusUnknown
+ switch st {
+ case "created":
+ status = task.StatusCreated
+ case "running":
+ status = task.StatusRunning
+ case "stopped":
+ status = task.StatusStopped
+ }
+ sio := p.Stdio()
+ return &shim.StateResponse{
+ ID: p.ID(),
+ Bundle: s.bundle,
+ Pid: uint32(p.Pid()),
+ Status: status,
+ Stdin: sio.Stdin,
+ Stdout: sio.Stdout,
+ Stderr: sio.Stderr,
+ Terminal: sio.Terminal,
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+// Pause pauses the container.
+func (s *Service) Pause(ctx context.Context, r *types.Empty) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Resume resumes the container.
+func (s *Service) Resume(ctx context.Context, r *types.Empty) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Kill kills a process with the provided signal.
+func (s *Service) Kill(ctx context.Context, r *shim.KillRequest) (*types.Empty, error) {
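+	// An empty ID targets the init process; otherwise kill the named exec process.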
+ if r.ID == "" {
+ p, err := s.getInitProcess()
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+ }
+
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// ListPids returns all pids inside the container.
+func (s *Service) ListPids(ctx context.Context, r *shim.ListPidsRequest) (*shim.ListPidsResponse, error) {
+ pids, err := s.getContainerPids(ctx, r.ID)
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ var processes []*task.ProcessInfo
+ for _, pid := range pids {
+ pInfo := task.ProcessInfo{
+ Pid: pid,
+ }
+ for _, p := range s.processes {
+ if p.Pid() == int(pid) {
+ d := &runctypes.ProcessDetails{
+ ExecID: p.ID(),
+ }
+ a, err := typeurl.MarshalAny(d)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err)
+ }
+ pInfo.Info = a
+ break
+ }
+ }
+ processes = append(processes, &pInfo)
+ }
+ return &shim.ListPidsResponse{
+ Processes: processes,
+ }, nil
+}
+
+// CloseIO closes the I/O context of a process.
+func (s *Service) CloseIO(ctx context.Context, r *shim.CloseIORequest) (*types.Empty, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if stdin := p.Stdin(); stdin != nil {
+ if err := stdin.Close(); err != nil {
+ return nil, fmt.Errorf("close stdin: %w", err)
+ }
+ }
+ return empty, nil
+}
+
+// Checkpoint checkpoints the container.
+func (s *Service) Checkpoint(ctx context.Context, r *shim.CheckpointTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// ShimInfo returns shim information such as the shim's pid.
+func (s *Service) ShimInfo(ctx context.Context, r *types.Empty) (*shim.ShimInfoResponse, error) {
+ return &shim.ShimInfoResponse{
+ ShimPid: uint32(os.Getpid()),
+ }, nil
+}
+
+// Update updates a running container.
+func (s *Service) Update(ctx context.Context, r *shim.UpdateTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Wait waits for a process to exit.
+func (s *Service) Wait(ctx context.Context, r *shim.WaitRequest) (*shim.WaitResponse, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ p.Wait()
+
+ return &shim.WaitResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+func (s *Service) processExits() {
+ for e := range s.ec {
+ s.checkProcesses(e)
+ }
+}
+
+func (s *Service) allProcesses() []process.Process {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ res := make([]process.Process, 0, len(s.processes))
+ for _, p := range s.processes {
+ res = append(res, p)
+ }
+ return res
+}
+
+func (s *Service) checkProcesses(e proc.Exit) {
+ for _, p := range s.allProcesses() {
+ if p.ID() == e.ID {
+ if ip, ok := p.(*proc.Init); ok {
+ // Ensure all children are killed.
+ if err := ip.KillAll(s.context); err != nil {
+ log.G(s.context).WithError(err).WithField("id", ip.ID()).
+ Error("failed to kill init's children")
+ }
+ }
+ p.SetExited(e.Status)
+ s.events <- &TaskExit{
+ ContainerID: s.id,
+ ID: p.ID(),
+ Pid: uint32(p.Pid()),
+ ExitStatus: uint32(e.Status),
+ ExitedAt: p.ExitedAt(),
+ }
+ return
+ }
+ }
+}
+
+func (s *Service) getContainerPids(ctx context.Context, id string) ([]uint32, error) {
+ p, err := s.getInitProcess()
+ if err != nil {
+ return nil, err
+ }
+
+ ps, err := p.(*proc.Init).Runtime().Ps(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ pids := make([]uint32, 0, len(ps))
+ for _, pid := range ps {
+ pids = append(pids, uint32(pid))
+ }
+ return pids, nil
+}
+
+func (s *Service) forward(publisher events.Publisher) {
+ for e := range s.events {
+ if err := publisher.Publish(s.context, getTopic(s.context, e), e); err != nil {
+ log.G(s.context).WithError(err).Error("post event")
+ }
+ }
+}
+
+// getInitProcess returns the init process.
+func (s *Service) getInitProcess() (process.Process, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ p := s.processes[s.id]
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ return p, nil
+}
+
+// getExecProcess returns the given exec process.
+func (s *Service) getExecProcess(id string) (process.Process, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ p := s.processes[id]
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process %s does not exist", id)
+ }
+ return p, nil
+}
+
+func getTopic(ctx context.Context, e interface{}) string {
+ switch e.(type) {
+ case *TaskCreate:
+ return runtime.TaskCreateEventTopic
+ case *TaskStart:
+ return runtime.TaskStartEventTopic
+ case *TaskOOM:
+ return runtime.TaskOOMEventTopic
+ case *TaskExit:
+ return runtime.TaskExitEventTopic
+ case *TaskDelete:
+ return runtime.TaskDeleteEventTopic
+ case *TaskExecAdded:
+ return runtime.TaskExecAddedEventTopic
+ case *TaskExecStarted:
+ return runtime.TaskExecStartedEventTopic
+ default:
+ log.L.Printf("no topic for type %#v", e)
+ }
+ return runtime.TaskUnknownTopic
+}
+
+func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, config map[string]string, platform stdio.Platform, r *proc.CreateConfig) (*proc.Init, error) {
+ var options runctypes.CreateOptions
+ if r.Options != nil {
+ v, err := typeurl.UnmarshalAny(r.Options)
+ if err != nil {
+ return nil, err
+ }
+ options = *v.(*runctypes.CreateOptions)
+ }
+
+ spec, err := utils.ReadSpec(r.Bundle)
+ if err != nil {
+ return nil, fmt.Errorf("read oci spec: %w", err)
+ }
+ if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil {
+ return nil, fmt.Errorf("update volume annotations: %w", err)
+ }
+
+ runsc.FormatLogPath(r.ID, config)
+ rootfs := filepath.Join(path, "rootfs")
+ runtime := proc.NewRunsc(runtimeRoot, path, namespace, r.Runtime, config)
+ p := proc.New(r.ID, runtime, stdio.Stdio{
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Terminal: r.Terminal,
+ })
+ p.Bundle = r.Bundle
+ p.Platform = platform
+ p.Rootfs = rootfs
+ p.WorkDir = workDir
+ p.IoUID = int(options.IoUid)
+ p.IoGID = int(options.IoGid)
+ p.Sandbox = utils.IsSandbox(spec)
+ p.UserLog = utils.UserLogPath(spec)
+ p.Monitor = reaper.Default
+ return p, nil
+}
diff --git a/pkg/shim/v1/utils/BUILD b/pkg/shim/v1/utils/BUILD
new file mode 100644
index 000000000..54a0aabb7
--- /dev/null
+++ b/pkg/shim/v1/utils/BUILD
@@ -0,0 +1,27 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "utils",
+ srcs = [
+ "annotations.go",
+ "utils.go",
+ "volumes.go",
+ ],
+ visibility = [
+ "//pkg/shim:__subpackages__",
+ "//shim:__subpackages__",
+ ],
+ deps = [
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ ],
+)
+
+go_test(
+ name = "utils_test",
+ size = "small",
+ srcs = ["volumes_test.go"],
+ library = ":utils",
+ deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"],
+)
diff --git a/pkg/shim/v1/utils/annotations.go b/pkg/shim/v1/utils/annotations.go
new file mode 100644
index 000000000..1e9d3f365
--- /dev/null
+++ b/pkg/shim/v1/utils/annotations.go
@@ -0,0 +1,25 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+// Annotations from the CRI annotations package.
+//
+// These are vendored here due to import conflicts.
+const (
+ sandboxLogDirAnnotation = "io.kubernetes.cri.sandbox-log-directory"
+ containerTypeAnnotation = "io.kubernetes.cri.container-type"
+ containerTypeSandbox = "sandbox"
+ containerTypeContainer = "container"
+)
diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go
new file mode 100644
index 000000000..07e346654
--- /dev/null
+++ b/pkg/shim/v1/utils/utils.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+ "encoding/json"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+// ReadSpec reads OCI spec from the bundle directory.
+func ReadSpec(bundle string) (*specs.Spec, error) {
+	f, err := os.Open(filepath.Join(bundle, "config.json"))
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+	b, err := ioutil.ReadAll(f)
+ if err != nil {
+ return nil, err
+ }
+ var spec specs.Spec
+ if err := json.Unmarshal(b, &spec); err != nil {
+ return nil, err
+ }
+ return &spec, nil
+}
+
+// IsSandbox checks whether a container is a sandbox container.
+func IsSandbox(spec *specs.Spec) bool {
+ t, ok := spec.Annotations[containerTypeAnnotation]
+ return !ok || t == containerTypeSandbox
+}
+
+// UserLogPath gets user log path from OCI annotation.
+func UserLogPath(spec *specs.Spec) string {
+ sandboxLogDir := spec.Annotations[sandboxLogDirAnnotation]
+ if sandboxLogDir == "" {
+ return ""
+ }
+ return filepath.Join(sandboxLogDir, "gvisor.log")
+}
diff --git a/pkg/shim/v1/utils/volumes.go b/pkg/shim/v1/utils/volumes.go
new file mode 100644
index 000000000..52a428179
--- /dev/null
+++ b/pkg/shim/v1/utils/volumes.go
@@ -0,0 +1,155 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "path/filepath"
+ "strings"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+const volumeKeyPrefix = "dev.gvisor.spec.mount."
+
+var kubeletPodsDir = "/var/lib/kubelet/pods"
+
+// volumeName gets the volume name from a volume annotation key, e.g.
+// the NAME in dev.gvisor.spec.mount.NAME.share.
+func volumeName(k string) string {
+ return strings.SplitN(strings.TrimPrefix(k, volumeKeyPrefix), ".", 2)[0]
+}
+
+// volumeFieldName gets the volume field name from a volume annotation key,
+// e.g. `type` is the field name in dev.gvisor.spec.mount.NAME.type.
+func volumeFieldName(k string) string {
+ parts := strings.Split(strings.TrimPrefix(k, volumeKeyPrefix), ".")
+ return parts[len(parts)-1]
+}
+
+// podUID gets pod UID from the pod log path.
+func podUID(s *specs.Spec) (string, error) {
+ sandboxLogDir := s.Annotations[sandboxLogDirAnnotation]
+ if sandboxLogDir == "" {
+ return "", fmt.Errorf("no sandbox log path annotation")
+ }
+ fields := strings.Split(filepath.Base(sandboxLogDir), "_")
+ switch len(fields) {
+ case 1: // This is the old CRI logging path.
+ return fields[0], nil
+ case 3: // This is the new CRI logging path.
+ return fields[2], nil
+ }
+ return "", fmt.Errorf("unexpected sandbox log path %q", sandboxLogDir)
+}
+
+// isVolumeKey checks whether an annotation key is for volume.
+func isVolumeKey(k string) bool {
+ return strings.HasPrefix(k, volumeKeyPrefix)
+}
+
+// volumeSourceKey constructs the annotation key for volume source.
+func volumeSourceKey(volume string) string {
+ return volumeKeyPrefix + volume + ".source"
+}
+
+// volumePath searches the volume path in the kubelet pod directory.
+func volumePath(volume, uid string) (string, error) {
+ // TODO: Support subpath when gvisor supports pod volume bind mount.
+ volumeSearchPath := fmt.Sprintf("%s/%s/volumes/*/%s", kubeletPodsDir, uid, volume)
+ dirs, err := filepath.Glob(volumeSearchPath)
+ if err != nil {
+ return "", err
+ }
+ if len(dirs) != 1 {
+ return "", fmt.Errorf("unexpected matched volume list %v", dirs)
+ }
+ return dirs[0], nil
+}
+
+// isVolumePath checks whether a string is the volume path.
+func isVolumePath(volume, path string) (bool, error) {
+ // TODO: Support subpath when gvisor supports pod volume bind mount.
+ volumeSearchPath := fmt.Sprintf("%s/*/volumes/*/%s", kubeletPodsDir, volume)
+ return filepath.Match(volumeSearchPath, path)
+}
+
+// UpdateVolumeAnnotations adds the OCI annotations needed for gVisor
+// volume optimization.
+func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
+ var (
+ uid string
+ err error
+ )
+ if IsSandbox(s) {
+ uid, err = podUID(s)
+ if err != nil {
+ // Skip if we can't get pod UID, because this doesn't work
+ // for containerd 1.1.
+ return nil
+ }
+ }
+ var updated bool
+ for k, v := range s.Annotations {
+ if !isVolumeKey(k) {
+ continue
+ }
+ if volumeFieldName(k) != "type" {
+ continue
+ }
+ volume := volumeName(k)
+ if uid != "" {
+ // This is a sandbox.
+ path, err := volumePath(volume, uid)
+ if err != nil {
+ return fmt.Errorf("get volume path for %q: %w", volume, err)
+ }
+ s.Annotations[volumeSourceKey(volume)] = path
+ updated = true
+ } else {
+ // This is a container.
+ for i := range s.Mounts {
+ // An error is returned for sandbox if source
+ // annotation is not successfully applied, so
+ // it is guaranteed that the source annotation
+ // for sandbox has already been successfully
+ // applied at this point.
+ //
+ // The volume name is unique inside a pod, so
+ // matching without podUID is fine here.
+ //
+ // TODO: Pass podUID down to shim for containers to do
+ // more accurate matching.
+ if yes, _ := isVolumePath(volume, s.Mounts[i].Source); yes {
+ // gVisor requires the container mount type to match
+ // sandbox mount type.
+ s.Mounts[i].Type = v
+ updated = true
+ }
+ }
+ }
+ }
+ if !updated {
+ return nil
+ }
+ // Update bundle.
+ b, err := json.Marshal(s)
+ if err != nil {
+ return err
+ }
+ return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666)
+}
diff --git a/pkg/shim/v1/utils/volumes_test.go b/pkg/shim/v1/utils/volumes_test.go
new file mode 100644
index 000000000..3e02c6151
--- /dev/null
+++ b/pkg/shim/v1/utils/volumes_test.go
@@ -0,0 +1,308 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "reflect"
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+func TestUpdateVolumeAnnotations(t *testing.T) {
+ dir, err := ioutil.TempDir("", "test-update-volume-annotations")
+ if err != nil {
+ t.Fatalf("create tempdir: %v", err)
+ }
+ defer os.RemoveAll(dir)
+ kubeletPodsDir = dir
+
+ const (
+ testPodUID = "testuid"
+ testVolumeName = "testvolume"
+ testLogDirPath = "/var/log/pods/testns_testname_" + testPodUID
+ testLegacyLogDirPath = "/var/log/pods/" + testPodUID
+ )
+ testVolumePath := fmt.Sprintf("%s/%s/volumes/kubernetes.io~empty-dir/%s", dir, testPodUID, testVolumeName)
+
+ if err := os.MkdirAll(testVolumePath, 0755); err != nil {
+ t.Fatalf("Create test volume: %v", err)
+ }
+
+ for _, test := range []struct {
+ desc string
+ spec *specs.Spec
+ expected *specs.Spec
+ expectErr bool
+ expectUpdate bool
+ }{
+ {
+ desc: "volume annotations for sandbox",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath,
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "volume annotations for sandbox with legacy log path",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLegacyLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLegacyLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath,
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "tmpfs: volume annotations for container",
+ spec: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "tmpfs",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "bind: volume annotations for container",
+ spec: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "container",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "container",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "should not return error without pod log directory",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ },
+ {
+ desc: "should return error if volume path does not exist",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount.notexist.share": "pod",
+ "dev.gvisor.spec.mount.notexist.type": "tmpfs",
+ "dev.gvisor.spec.mount.notexist.options": "ro",
+ },
+ },
+ expectErr: true,
+ },
+ {
+ desc: "no volume annotations for sandbox",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ },
+ },
+ },
+ {
+ desc: "no volume annotations for container",
+ spec: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: "/test",
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ },
+ },
+ expected: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: "/test",
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ },
+ },
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ bundle, err := ioutil.TempDir(dir, "test-bundle")
+ if err != nil {
+ t.Fatalf("Create test bundle: %v", err)
+ }
+ err = UpdateVolumeAnnotations(bundle, test.spec)
+ if test.expectErr {
+ if err == nil {
+ t.Fatal("Expected error, but got nil")
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ if !reflect.DeepEqual(test.expected, test.spec) {
+ t.Fatalf("Expected %+v, got %+v", test.expected, test.spec)
+ }
+ if test.expectUpdate {
+ b, err := ioutil.ReadFile(filepath.Join(bundle, "config.json"))
+ if err != nil {
+ t.Fatalf("Read spec from bundle: %v", err)
+ }
+ var spec specs.Spec
+ if err := json.Unmarshal(b, &spec); err != nil {
+ t.Fatalf("Unmarshal spec: %v", err)
+ }
+ if !reflect.DeepEqual(test.expected, &spec) {
+ t.Fatalf("Expected %+v, got %+v", test.expected, &spec)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD
new file mode 100644
index 000000000..7e0a114a0
--- /dev/null
+++ b/pkg/shim/v2/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "v2",
+ srcs = [
+ "api.go",
+ "epoll.go",
+ "service.go",
+ "service_linux.go",
+ ],
+ visibility = ["//shim:__subpackages__"],
+ deps = [
+ "//pkg/shim/runsc",
+ "//pkg/shim/v1/proc",
+ "//pkg/shim/v1/utils",
+ "//pkg/shim/v2/options",
+ "//pkg/shim/v2/runtimeoptions",
+ "//runsc/specutils",
+ "@com_github_burntsushi_toml//:go_default_library",
+ "@com_github_containerd_cgroups//:go_default_library",
+ "@com_github_containerd_console//:go_default_library",
+ "@com_github_containerd_containerd//api/events:go_default_library",
+ "@com_github_containerd_containerd//api/types/task:go_default_library",
+ "@com_github_containerd_containerd//errdefs:go_default_library",
+ "@com_github_containerd_containerd//events:go_default_library",
+ "@com_github_containerd_containerd//log:go_default_library",
+ "@com_github_containerd_containerd//mount:go_default_library",
+ "@com_github_containerd_containerd//namespaces:go_default_library",
+ "@com_github_containerd_containerd//pkg/process:go_default_library",
+ "@com_github_containerd_containerd//pkg/stdio:go_default_library",
+ "@com_github_containerd_containerd//runtime:go_default_library",
+ "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library",
+ "@com_github_containerd_containerd//runtime/v2/shim:go_default_library",
+ "@com_github_containerd_containerd//runtime/v2/task:go_default_library",
+ "@com_github_containerd_containerd//sys/reaper:go_default_library",
+ "@com_github_containerd_fifo//:go_default_library",
+ "@com_github_containerd_typeurl//:go_default_library",
+ "@com_github_gogo_protobuf//types:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v2/api.go b/pkg/shim/v2/api.go
new file mode 100644
index 000000000..dbe5c59f6
--- /dev/null
+++ b/pkg/shim/v2/api.go
@@ -0,0 +1,22 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package v2
+
+import (
+ "github.com/containerd/containerd/api/events"
+)
+
+type TaskOOM = events.TaskOOM
diff --git a/pkg/shim/v2/epoll.go b/pkg/shim/v2/epoll.go
new file mode 100644
index 000000000..41232cca8
--- /dev/null
+++ b/pkg/shim/v2/epoll.go
@@ -0,0 +1,129 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package v2
+
+import (
+ "context"
+ "fmt"
+ "sync"
+
+ "github.com/containerd/cgroups"
+ "github.com/containerd/containerd/events"
+ "github.com/containerd/containerd/runtime"
+ "golang.org/x/sys/unix"
+)
+
+func newOOMEpoller(publisher events.Publisher) (*epoller, error) {
+ fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
+ if err != nil {
+ return nil, err
+ }
+ return &epoller{
+ fd: fd,
+ publisher: publisher,
+ set: make(map[uintptr]*item),
+ }, nil
+}
+
+type epoller struct {
+ mu sync.Mutex
+
+ fd int
+ publisher events.Publisher
+ set map[uintptr]*item
+}
+
+type item struct {
+ id string
+ cg cgroups.Cgroup
+}
+
+func (e *epoller) Close() error {
+ return unix.Close(e.fd)
+}
+
+func (e *epoller) run(ctx context.Context) {
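+	// Deliver OOM notifications until the context is canceled.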
+ var events [128]unix.EpollEvent
+ for {
+ select {
+ case <-ctx.Done():
+ e.Close()
+ return
+ default:
+ n, err := unix.EpollWait(e.fd, events[:], -1)
+ if err != nil {
+ if err == unix.EINTR || err == unix.EAGAIN {
+ continue
+ }
+ // Should not happen.
+ panic(fmt.Errorf("cgroups: epoll wait: %w", err))
+ }
+ for i := 0; i < n; i++ {
+ e.process(ctx, uintptr(events[i].Fd))
+ }
+ }
+ }
+}
+
+func (e *epoller) add(id string, cg cgroups.Cgroup) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ fd, err := cg.OOMEventFD()
+ if err != nil {
+ return err
+ }
+ e.set[fd] = &item{
+ id: id,
+ cg: cg,
+ }
+ event := unix.EpollEvent{
+ Fd: int32(fd),
+ Events: unix.EPOLLHUP | unix.EPOLLIN | unix.EPOLLERR,
+ }
+ return unix.EpollCtl(e.fd, unix.EPOLL_CTL_ADD, int(fd), &event)
+}
+
+func (e *epoller) process(ctx context.Context, fd uintptr) {
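+	// Drain the eventfd counter first so the descriptor does not remain readable.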
+ flush(fd)
+ e.mu.Lock()
+ i, ok := e.set[fd]
+ if !ok {
+ e.mu.Unlock()
+ return
+ }
+ e.mu.Unlock()
+ if i.cg.State() == cgroups.Deleted {
+ e.mu.Lock()
+ delete(e.set, fd)
+ e.mu.Unlock()
+ unix.Close(int(fd))
+ return
+ }
+ if err := e.publisher.Publish(ctx, runtime.TaskOOMEventTopic, &TaskOOM{
+ ContainerID: i.id,
+ }); err != nil {
+ // Should not happen.
+ panic(fmt.Errorf("publish OOM event: %w", err))
+ }
+}
+
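+// flush drains the 8-byte counter value from an OOM eventfd.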
+func flush(fd uintptr) error {
+ var buf [8]byte
+ _, err := unix.Read(int(fd), buf[:])
+ return err
+}
diff --git a/pkg/shim/v2/options/BUILD b/pkg/shim/v2/options/BUILD
new file mode 100644
index 000000000..ca212e874
--- /dev/null
+++ b/pkg/shim/v2/options/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "options",
+ srcs = [
+ "options.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/shim/v2/options/options.go b/pkg/shim/v2/options/options.go
new file mode 100644
index 000000000..de09f2f79
--- /dev/null
+++ b/pkg/shim/v2/options/options.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package options
+
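+// OptionType is the type URL used to identify runsc runtime options.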
+const OptionType = "io.containerd.runsc.v1.options"
+
+// Options are the runtime options for io.containerd.runsc.v1.
+type Options struct {
+ // ShimCgroup is the cgroup the shim should be in.
+ ShimCgroup string `toml:"shim_cgroup"`
+	// IoUid is the UID that owns the shim's I/O pipes.
+	IoUid uint32 `toml:"io_uid"`
+	// IoGid is the GID that owns the shim's I/O pipes.
+	IoGid uint32 `toml:"io_gid"`
+ // BinaryName is the binary name of the runsc binary.
+ BinaryName string `toml:"binary_name"`
+ // Root is the runsc root directory.
+ Root string `toml:"root"`
+ // RunscConfig is a key/value map of all runsc flags.
+ RunscConfig map[string]string `toml:"runsc_config"`
+}
diff --git a/pkg/shim/v2/runtimeoptions/BUILD b/pkg/shim/v2/runtimeoptions/BUILD
new file mode 100644
index 000000000..01716034c
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library", "proto_library")
+
+package(licenses = ["notice"])
+
+proto_library(
+ name = "api",
+ srcs = [
+ "runtimeoptions.proto",
+ ],
+)
+
+go_library(
+ name = "runtimeoptions",
+ srcs = ["runtimeoptions.go"],
+ visibility = ["//pkg/shim/v2:__pkg__"],
+ deps = [
+ "//pkg/shim/v2/runtimeoptions:api_go_proto",
+ "@com_github_gogo_protobuf//proto:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.go b/pkg/shim/v2/runtimeoptions/runtimeoptions.go
new file mode 100644
index 000000000..1c1a0c5d1
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package runtimeoptions
+
+import (
+ proto "github.com/gogo/protobuf/proto"
+ pb "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions/api_go_proto"
+)
+
+type Options = pb.Options
+
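+// init registers the vendored Options message under the canonical CRI proto
+// name so typeurl can resolve it.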
+func init() {
+ proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options")
+}
diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.proto b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto
new file mode 100644
index 000000000..edb19020a
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package runtimeoptions;
+
+// This is a vendored copy of the runtimeoptions CRI API.
+//
+// Importing the full CRI package is a nightmare.
+message Options {
+ string type_url = 1;
+ string config_path = 2;
+}
diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go
new file mode 100644
index 000000000..1534152fc
--- /dev/null
+++ b/pkg/shim/v2/service.go
@@ -0,0 +1,824 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package v2
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/BurntSushi/toml"
+ "github.com/containerd/cgroups"
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/api/events"
+ "github.com/containerd/containerd/api/types/task"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/containerd/mount"
+ "github.com/containerd/containerd/namespaces"
+ "github.com/containerd/containerd/pkg/process"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/containerd/runtime"
+ "github.com/containerd/containerd/runtime/linux/runctypes"
+ "github.com/containerd/containerd/runtime/v2/shim"
+ taskAPI "github.com/containerd/containerd/runtime/v2/task"
+ "github.com/containerd/containerd/sys/reaper"
+ "github.com/containerd/typeurl"
+ "github.com/gogo/protobuf/types"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+ "gvisor.dev/gvisor/pkg/shim/v1/proc"
+ "gvisor.dev/gvisor/pkg/shim/v1/utils"
+ "gvisor.dev/gvisor/pkg/shim/v2/options"
+ "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+var (
+ empty = &types.Empty{}
+ bufPool = sync.Pool{
+ New: func() interface{} {
+ buffer := make([]byte, 32<<10)
+ return &buffer
+ },
+ }
+)
+
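+// Compile-time check that service implements the containerd TaskService API.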
+var _ = (taskAPI.TaskService)(&service{})
+
+// configFile is the default config file name. For containerd 1.2,
+// we assume that a config.toml should exist in the runtime root.
+const configFile = "config.toml"
+
+// New returns a new shim service that can be used via GRPC.
+func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) {
+ ep, err := newOOMEpoller(publisher)
+ if err != nil {
+ return nil, err
+ }
+ go ep.run(ctx)
+ s := &service{
+ id: id,
+ context: ctx,
+ processes: make(map[string]process.Process),
+ events: make(chan interface{}, 128),
+ ec: proc.ExitCh,
+ oomPoller: ep,
+ cancel: cancel,
+ }
+ go s.processExits()
+ runsc.Monitor = reaper.Default
+ if err := s.initPlatform(); err != nil {
+ cancel()
+		return nil, fmt.Errorf("failed to initialize platform behavior: %w", err)
+ }
+ go s.forward(publisher)
+ return s, nil
+}
+
+// service is the shim implementation of a remote shim over GRPC.
+type service struct {
+ mu sync.Mutex
+
+ context context.Context
+ task process.Process
+ processes map[string]process.Process
+ events chan interface{}
+ platform stdio.Platform
+ opts options.Options
+ ec chan proc.Exit
+ oomPoller *epoller
+
+ id string
+ bundle string
+ cancel func()
+}
+
+func newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) {
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, err
+ }
+ self, err := os.Executable()
+ if err != nil {
+ return nil, err
+ }
+ cwd, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+ args := []string{
+ "-namespace", ns,
+ "-address", containerdAddress,
+ "-publish-binary", containerdBinary,
+ }
+ cmd := exec.Command(self, args...)
+ cmd.Dir = cwd
+ cmd.Env = append(os.Environ(), "GOMAXPROCS=2")
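+	// Start the shim in its own process group so that signals delivered to
+	// containerd's group do not also hit the shim.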
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setpgid: true,
+ }
+ return cmd, nil
+}
+
+func (s *service) StartShim(ctx context.Context, id, containerdBinary, containerdAddress, containerdTTRPCAddress string) (string, error) {
+ cmd, err := newCommand(ctx, containerdBinary, containerdAddress)
+ if err != nil {
+ return "", err
+ }
+ address, err := shim.SocketAddress(ctx, id)
+ if err != nil {
+ return "", err
+ }
+ socket, err := shim.NewSocket(address)
+ if err != nil {
+ return "", err
+ }
+ defer socket.Close()
+ f, err := socket.File()
+ if err != nil {
+ return "", err
+ }
+ defer f.Close()
+
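+	// Pass the listening socket to the child shim as an inherited file.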
+ cmd.ExtraFiles = append(cmd.ExtraFiles, f)
+
+ if err := cmd.Start(); err != nil {
+ return "", err
+ }
+ defer func() {
+ if err != nil {
+ cmd.Process.Kill()
+ }
+ }()
+	// Make sure to wait after start so the child process is reaped.
+ go cmd.Wait()
+ if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil {
+ return "", err
+ }
+ if err := shim.WriteAddress("address", address); err != nil {
+ return "", err
+ }
+ if err := shim.SetScore(cmd.Process.Pid); err != nil {
+ return "", fmt.Errorf("failed to set OOM Score on shim: %w", err)
+ }
+ return address, nil
+}
+
+func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) {
+ path, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, err
+ }
+ runtime, err := s.readRuntime(path)
+ if err != nil {
+ return nil, err
+ }
+ r := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil)
+ if err := r.Delete(ctx, s.id, &runsc.DeleteOpts{
+ Force: true,
+ }); err != nil {
+ log.L.Printf("failed to remove runc container: %v", err)
+ }
+ if err := mount.UnmountAll(filepath.Join(path, "rootfs"), 0); err != nil {
+ log.L.Printf("failed to cleanup rootfs mount: %v", err)
+ }
+ return &taskAPI.DeleteResponse{
+ ExitedAt: time.Now(),
+ ExitStatus: 128 + uint32(unix.SIGKILL),
+ }, nil
+}
+
+func (s *service) readRuntime(path string) (string, error) {
+ data, err := ioutil.ReadFile(filepath.Join(path, "runtime"))
+ if err != nil {
+ return "", err
+ }
+ return string(data), nil
+}
+
+func (s *service) writeRuntime(path, runtime string) error {
+ return ioutil.WriteFile(filepath.Join(path, "runtime"), []byte(runtime), 0600)
+}
+
+// Create creates a new initial process and container with the underlying OCI
+// runtime.
+func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ *taskAPI.CreateTaskResponse, err error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("create namespace: %w", err)
+ }
+
+ // Read from root for now.
+ var opts options.Options
+ if r.Options != nil {
+ v, err := typeurl.UnmarshalAny(r.Options)
+ if err != nil {
+ return nil, err
+ }
+ var path string
+ switch o := v.(type) {
+ case *runctypes.CreateOptions: // containerd 1.2.x
+ opts.IoUid = o.IoUid
+ opts.IoGid = o.IoGid
+ opts.ShimCgroup = o.ShimCgroup
+ case *runctypes.RuncOptions: // containerd 1.2.x
+ root := proc.RunscRoot
+ if o.RuntimeRoot != "" {
+ root = o.RuntimeRoot
+ }
+
+ opts.BinaryName = o.Runtime
+
+ path = filepath.Join(root, configFile)
+ if _, err := os.Stat(path); err != nil {
+ if !os.IsNotExist(err) {
+ return nil, fmt.Errorf("stat config file %q: %w", path, err)
+ }
+ // A config file in runtime root is not required.
+ path = ""
+ }
+ case *runtimeoptions.Options: // containerd 1.3.x+
+ if o.ConfigPath == "" {
+ break
+ }
+ if o.TypeUrl != options.OptionType {
+ return nil, fmt.Errorf("unsupported option type %q", o.TypeUrl)
+ }
+ path = o.ConfigPath
+ default:
+ return nil, fmt.Errorf("unsupported option type %q", r.Options.TypeUrl)
+ }
+ if path != "" {
+ if _, err = toml.DecodeFile(path, &opts); err != nil {
+ return nil, fmt.Errorf("decode config file %q: %w", path, err)
+ }
+ }
+ }
+
+ var mounts []proc.Mount
+ for _, m := range r.Rootfs {
+ mounts = append(mounts, proc.Mount{
+ Type: m.Type,
+ Source: m.Source,
+ Target: m.Target,
+ Options: m.Options,
+ })
+ }
+
+ rootfs := filepath.Join(r.Bundle, "rootfs")
+ if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) {
+ return nil, err
+ }
+
+ config := &proc.CreateConfig{
+ ID: r.ID,
+ Bundle: r.Bundle,
+ Runtime: opts.BinaryName,
+ Rootfs: mounts,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Options: r.Options,
+ }
+ if err := s.writeRuntime(r.Bundle, opts.BinaryName); err != nil {
+ return nil, err
+ }
+ defer func() {
+ if err != nil {
+ if err := mount.UnmountAll(rootfs, 0); err != nil {
+ log.L.Printf("failed to cleanup rootfs mount: %v", err)
+ }
+ }
+ }()
+ for _, rm := range mounts {
+ m := &mount.Mount{
+ Type: rm.Type,
+ Source: rm.Source,
+ Options: rm.Options,
+ }
+ if err := m.Mount(rootfs); err != nil {
+ return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err)
+ }
+ }
+ process, err := newInit(
+ ctx,
+ r.Bundle,
+ filepath.Join(r.Bundle, "work"),
+ ns,
+ s.platform,
+ config,
+ &opts,
+ rootfs,
+ )
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ if err := process.Create(ctx, config); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ // Save the main task id and bundle to the shim for additional
+ // requests.
+ s.id = r.ID
+ s.bundle = r.Bundle
+
+ // Set up OOM notification on the sandbox's cgroup. This is done on
+ // sandbox create since the sandbox process will be created here.
+ pid := process.Pid()
+ if pid > 0 {
+ cg, err := cgroups.Load(cgroups.V1, cgroups.PidPath(pid))
+ if err != nil {
+ return nil, fmt.Errorf("loading cgroup for %d: %w", pid, err)
+ }
+ if err := s.oomPoller.add(s.id, cg); err != nil {
+ return nil, fmt.Errorf("add cg to OOM monitor: %w", err)
+ }
+ }
+ s.task = process
+ s.opts = opts
+ return &taskAPI.CreateTaskResponse{
+ Pid: uint32(process.Pid()),
+ }, nil
+}
+
+// Start starts a process.
+func (s *service) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI.StartResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Start(ctx); err != nil {
+ return nil, err
+ }
+ // TODO: Set the cgroup and oom notifications on restore.
+ // https://github.com/google/gvisor-containerd-shim/issues/58
+ return &taskAPI.StartResponse{
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Delete deletes the initial process and container.
+func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAPI.DeleteResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ if err := p.Delete(ctx); err != nil {
+ return nil, err
+ }
+ isTask := r.ExecID == ""
+ if !isTask {
+ s.mu.Lock()
+ delete(s.processes, r.ExecID)
+ s.mu.Unlock()
+ }
+ if isTask && s.platform != nil {
+ s.platform.Close()
+ }
+ return &taskAPI.DeleteResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Exec spawns an additional process inside the container.
+func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*types.Empty, error) {
+ s.mu.Lock()
+ p := s.processes[r.ExecID]
+ s.mu.Unlock()
+ if p != nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID)
+ }
+ p = s.task
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ process, err := p.(*proc.Init).Exec(ctx, s.bundle, &proc.ExecConfig{
+ ID: r.ExecID,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Spec: r.Spec,
+ })
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ s.mu.Lock()
+ s.processes[r.ExecID] = process
+ s.mu.Unlock()
+ return empty, nil
+}
+
+// ResizePty resizes the terminal of a process.
+func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (*types.Empty, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ ws := console.WinSize{
+ Width: uint16(r.Width),
+ Height: uint16(r.Height),
+ }
+ if err := p.Resize(ws); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// State returns runtime state information for a process.
+func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.StateResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ st, err := p.Status(ctx)
+ if err != nil {
+ return nil, err
+ }
+ status := task.StatusUnknown
+ switch st {
+ case "created":
+ status = task.StatusCreated
+ case "running":
+ status = task.StatusRunning
+ case "stopped":
+ status = task.StatusStopped
+ }
+ sio := p.Stdio()
+ return &taskAPI.StateResponse{
+ ID: p.ID(),
+ Bundle: s.bundle,
+ Pid: uint32(p.Pid()),
+ Status: status,
+ Stdin: sio.Stdin,
+ Stdout: sio.Stdout,
+ Stderr: sio.Stderr,
+ Terminal: sio.Terminal,
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+// Pause the container.
+func (s *service) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Resume the container.
+func (s *service) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Kill a process with the provided signal.
+func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empty, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// Pids returns all pids inside the container.
+func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.PidsResponse, error) {
+ pids, err := s.getContainerPids(ctx, r.ID)
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ var processes []*task.ProcessInfo
+ for _, pid := range pids {
+ pInfo := task.ProcessInfo{
+ Pid: pid,
+ }
+ for _, p := range s.processes {
+ if p.Pid() == int(pid) {
+ d := &runctypes.ProcessDetails{
+ ExecID: p.ID(),
+ }
+ a, err := typeurl.MarshalAny(d)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err)
+ }
+ pInfo.Info = a
+ break
+ }
+ }
+ processes = append(processes, &pInfo)
+ }
+ return &taskAPI.PidsResponse{
+ Processes: processes,
+ }, nil
+}
+
+// CloseIO closes the I/O context of a process.
+func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*types.Empty, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if stdin := p.Stdin(); stdin != nil {
+ if err := stdin.Close(); err != nil {
+ return nil, fmt.Errorf("close stdin: %w", err)
+ }
+ }
+ return empty, nil
+}
+
+// Checkpoint checkpoints the container.
+func (s *service) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Connect returns shim information such as the shim's pid.
+func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*taskAPI.ConnectResponse, error) {
+ var pid int
+ if s.task != nil {
+ pid = s.task.Pid()
+ }
+ return &taskAPI.ConnectResponse{
+ ShimPid: uint32(os.Getpid()),
+ TaskPid: uint32(pid),
+ }, nil
+}
+
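+// Shutdown stops the shim: it cancels the service context and exits the
+// process, so its final return is never reached.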
+func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) {
+ s.cancel()
+ os.Exit(0)
+ return empty, nil
+}
+
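+// Stats returns statistics for the container, gathered from runsc and
+// translated into the cgroups.Metrics type that runc returns.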
+func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) {
+ path, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, err
+ }
+ runtime, err := s.readRuntime(path)
+ if err != nil {
+ return nil, err
+ }
+ rs := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil)
+ stats, err := rs.Stats(ctx, s.id)
+ if err != nil {
+ return nil, err
+ }
+
+ // gvisor currently (as of 2020-03-03) only returns the total memory
+ // usage and current PID value[0]. However, we copy the common fields here
+ // so that future updates will propagate correct information. We're
+ // using the cgroups.Metrics structure so we're returning the same type
+ // as runc.
+ //
+ // [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81
+ data, err := typeurl.MarshalAny(&cgroups.Metrics{
+ CPU: &cgroups.CPUStat{
+ Usage: &cgroups.CPUUsage{
+ Total: stats.Cpu.Usage.Total,
+ Kernel: stats.Cpu.Usage.Kernel,
+ User: stats.Cpu.Usage.User,
+ PerCPU: stats.Cpu.Usage.Percpu,
+ },
+ Throttling: &cgroups.Throttle{
+ Periods: stats.Cpu.Throttling.Periods,
+ ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods,
+ ThrottledTime: stats.Cpu.Throttling.ThrottledTime,
+ },
+ },
+ Memory: &cgroups.MemoryStat{
+ Cache: stats.Memory.Cache,
+ Usage: &cgroups.MemoryEntry{
+ Limit: stats.Memory.Usage.Limit,
+ Usage: stats.Memory.Usage.Usage,
+ Max: stats.Memory.Usage.Max,
+ Failcnt: stats.Memory.Usage.Failcnt,
+ },
+ Swap: &cgroups.MemoryEntry{
+ Limit: stats.Memory.Swap.Limit,
+ Usage: stats.Memory.Swap.Usage,
+ Max: stats.Memory.Swap.Max,
+ Failcnt: stats.Memory.Swap.Failcnt,
+ },
+ Kernel: &cgroups.MemoryEntry{
+ Limit: stats.Memory.Kernel.Limit,
+ Usage: stats.Memory.Kernel.Usage,
+ Max: stats.Memory.Kernel.Max,
+ Failcnt: stats.Memory.Kernel.Failcnt,
+ },
+ KernelTCP: &cgroups.MemoryEntry{
+ Limit: stats.Memory.KernelTCP.Limit,
+ Usage: stats.Memory.KernelTCP.Usage,
+ Max: stats.Memory.KernelTCP.Max,
+ Failcnt: stats.Memory.KernelTCP.Failcnt,
+ },
+ },
+ Pids: &cgroups.PidsStat{
+ Current: stats.Pids.Current,
+ Limit: stats.Pids.Limit,
+ },
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &taskAPI.StatsResponse{
+ Stats: data,
+ }, nil
+}
+
+// Update updates a running container.
+func (s *service) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Wait waits for a process to exit.
+func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.WaitResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ p.Wait()
+
+ return &taskAPI.WaitResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+func (s *service) processExits() {
+ for e := range s.ec {
+ s.checkProcesses(e)
+ }
+}
+
+func (s *service) checkProcesses(e proc.Exit) {
+ // TODO(random-liu): Add `shouldKillAll` logic if container pid
+ // namespace is supported.
+ for _, p := range s.allProcesses() {
+ if p.ID() == e.ID {
+ if ip, ok := p.(*proc.Init); ok {
+ // Ensure all children are killed.
+ if err := ip.KillAll(s.context); err != nil {
+ log.G(s.context).WithError(err).WithField("id", ip.ID()).
+ Error("failed to kill init's children")
+ }
+ }
+ p.SetExited(e.Status)
+ s.events <- &events.TaskExit{
+ ContainerID: s.id,
+ ID: p.ID(),
+ Pid: uint32(p.Pid()),
+ ExitStatus: uint32(e.Status),
+ ExitedAt: p.ExitedAt(),
+ }
+ return
+ }
+ }
+}
+
+func (s *service) allProcesses() (o []process.Process) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for _, p := range s.processes {
+ o = append(o, p)
+ }
+ if s.task != nil {
+ o = append(o, s.task)
+ }
+ return o
+}
+
+func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, error) {
+ s.mu.Lock()
+ p := s.task
+ s.mu.Unlock()
+ if p == nil {
+ return nil, fmt.Errorf("container must be created: %w", errdefs.ErrFailedPrecondition)
+ }
+ ps, err := p.(*proc.Init).Runtime().Ps(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ pids := make([]uint32, 0, len(ps))
+ for _, pid := range ps {
+ pids = append(pids, uint32(pid))
+ }
+ return pids, nil
+}
+
+func (s *service) forward(publisher shim.Publisher) {
+ for e := range s.events {
+ ctx, cancel := context.WithTimeout(s.context, 5*time.Second)
+ err := publisher.Publish(ctx, getTopic(e), e)
+ cancel()
+ if err != nil {
+ // Should not happen.
+ panic(fmt.Errorf("post event: %w", err))
+ }
+ }
+}
+
+func (s *service) getProcess(execID string) (process.Process, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if execID == "" {
+ return s.task, nil
+ }
+ p := s.processes[execID]
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID)
+ }
+ return p, nil
+}
+
+func getTopic(e interface{}) string {
+ switch e.(type) {
+ case *events.TaskCreate:
+ return runtime.TaskCreateEventTopic
+ case *events.TaskStart:
+ return runtime.TaskStartEventTopic
+ case *events.TaskOOM:
+ return runtime.TaskOOMEventTopic
+ case *events.TaskExit:
+ return runtime.TaskExitEventTopic
+ case *events.TaskDelete:
+ return runtime.TaskDeleteEventTopic
+ case *events.TaskExecAdded:
+ return runtime.TaskExecAddedEventTopic
+ case *events.TaskExecStarted:
+ return runtime.TaskExecStartedEventTopic
+ default:
+ log.L.Printf("no topic for type %#v", e)
+ }
+ return runtime.TaskUnknownTopic
+}
+
+func newInit(ctx context.Context, path, workDir, namespace string, platform stdio.Platform, r *proc.CreateConfig, options *options.Options, rootfs string) (*proc.Init, error) {
+ spec, err := utils.ReadSpec(r.Bundle)
+ if err != nil {
+ return nil, fmt.Errorf("read oci spec: %w", err)
+ }
+ if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil {
+ return nil, fmt.Errorf("update volume annotations: %w", err)
+ }
+ runsc.FormatLogPath(r.ID, options.RunscConfig)
+ runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig)
+ p := proc.New(r.ID, runtime, stdio.Stdio{
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Terminal: r.Terminal,
+ })
+ p.Bundle = r.Bundle
+ p.Platform = platform
+ p.Rootfs = rootfs
+ p.WorkDir = workDir
+ p.IoUID = int(options.IoUid)
+ p.IoGID = int(options.IoGid)
+ p.Sandbox = specutils.SpecContainerType(spec) == specutils.ContainerTypeSandbox
+ p.UserLog = utils.UserLogPath(spec)
+ p.Monitor = reaper.Default
+ return p, nil
+}
diff --git a/pkg/shim/v2/service_linux.go b/pkg/shim/v2/service_linux.go
new file mode 100644
index 000000000..1800ab90b
--- /dev/null
+++ b/pkg/shim/v2/service_linux.go
@@ -0,0 +1,108 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package v2
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "sync"
+ "syscall"
+
+ "github.com/containerd/console"
+ "github.com/containerd/fifo"
+)
+
+type linuxPlatform struct {
+ epoller *console.Epoller
+}
+
+func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) {
+ if p.epoller == nil {
+ return nil, fmt.Errorf("uninitialized epoller")
+ }
+
+ epollConsole, err := p.epoller.Add(console)
+ if err != nil {
+ return nil, err
+ }
+
+ if stdin != "" {
+ in, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return nil, err
+ }
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(epollConsole, in, *p)
+ }()
+ }
+
+ outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(outw, epollConsole, *p)
+ epollConsole.Close()
+ outr.Close()
+ outw.Close()
+ wg.Done()
+ }()
+ return epollConsole, nil
+}
+
+func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error {
+ if p.epoller == nil {
+ return fmt.Errorf("uninitialized epoller")
+ }
+ epollConsole, ok := cons.(*console.EpollConsole)
+ if !ok {
+ return fmt.Errorf("expected EpollConsole, got %#v", cons)
+ }
+ return epollConsole.Shutdown(p.epoller.CloseConsole)
+}
+
+func (p *linuxPlatform) Close() error {
+ return p.epoller.Close()
+}
+
+// initPlatform initializes a single epoll fd to manage our consoles. It
+// should only be called once.
+func (s *service) initPlatform() error {
+ if s.platform != nil {
+ return nil
+ }
+ epoller, err := console.NewEpoller()
+ if err != nil {
+ return fmt.Errorf("failed to initialize epoller: %w", err)
+ }
+ s.platform = &linuxPlatform{
+ epoller: epoller,
+ }
+ go epoller.Wait()
+ return nil
+}
diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD
index e131455f7..ae0fe1522 100644
--- a/pkg/sleep/BUILD
+++ b/pkg/sleep/BUILD
@@ -12,6 +12,7 @@ go_library(
"sleep_unsafe.go",
],
visibility = ["//:sandbox"],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go
index af47e2ba1..1dd11707d 100644
--- a/pkg/sleep/sleep_test.go
+++ b/pkg/sleep/sleep_test.go
@@ -379,10 +379,7 @@ func TestRace(t *testing.T) {
// TestRaceInOrder tests that multiple wakers can continuously send wake requests to
// the sleeper and that the wakers are retrieved in the order asserted.
func TestRaceInOrder(t *testing.T) {
- const wakers = 100
- const wakeRequests = 10000
-
- w := make([]Waker, wakers)
+ w := make([]Waker, 10000)
s := Sleeper{}
// Associate each waker and start goroutines that will assert them.
@@ -390,19 +387,16 @@ func TestRaceInOrder(t *testing.T) {
s.AddWaker(&w[i], i)
}
go func() {
- n := 0
- for n < wakeRequests {
- wk := w[n%len(w)]
- wk.Assert()
- n++
+ for i := range w {
+ w[i].Assert()
}
}()
// Wait for all wake up notifications from all wakers.
- for i := 0; i < wakeRequests; i++ {
- v, _ := s.Fetch(true)
- if got, want := v, i%wakers; got != want {
- t.Fatalf("got %d want %d", got, want)
+ for want := range w {
+ got, _ := s.Fetch(true)
+ if got != want {
+ t.Fatalf("got %d want %d", got, want)
}
}
}
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
index 65bfcf778..118805492 100644
--- a/pkg/sleep/sleep_unsafe.go
+++ b/pkg/sleep/sleep_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.11
-// +build !go1.15
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
@@ -75,6 +75,8 @@ package sleep
import (
"sync/atomic"
"unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
const (
@@ -323,7 +325,12 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
//
// This struct is thread-safe, that is, its methods can be called concurrently
// by multiple goroutines.
+//
+// Note, it is not safe to copy a Waker as its fields are modified by value
+// (the pointer fields are individually modified with atomic operations).
type Waker struct {
+ _ sync.NoCopy
+
// s is the sleeper that this waker can wake up. Only one sleeper at a
// time is allowed. This field can have three classes of values:
// nil -- the waker is not asserted: it either is not associated with
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index 2b1350135..089b3bbef 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -1,9 +1,47 @@
-load("//tools:defs.bzl", "go_library", "go_test", "proto_library")
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
go_template_instance(
+ name = "pending_list",
+ out = "pending_list.go",
+ package = "state",
+ prefix = "pending",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectEncodeState",
+ "ElementMapper": "pendingMapper",
+ "Linker": "*pendingEntry",
+ },
+)
+
+go_template_instance(
+ name = "deferred_list",
+ out = "deferred_list.go",
+ package = "state",
+ prefix = "deferred",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectEncodeState",
+ "ElementMapper": "deferredMapper",
+ "Linker": "*deferredEntry",
+ },
+)
+
+go_template_instance(
+ name = "complete_list",
+ out = "complete_list.go",
+ package = "state",
+ prefix = "complete",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectDecodeState",
+ "Linker": "*objectDecodeState",
+ },
+)
+
+go_template_instance(
name = "addr_range",
out = "addr_range.go",
package = "state",
@@ -29,7 +67,7 @@ go_template_instance(
types = {
"Key": "uintptr",
"Range": "addrRange",
- "Value": "reflect.Value",
+ "Value": "*objectEncodeState",
"Functions": "addrSetFunctions",
},
)
@@ -39,32 +77,24 @@ go_library(
srcs = [
"addr_range.go",
"addr_set.go",
+ "complete_list.go",
"decode.go",
+ "decode_unsafe.go",
+ "deferred_list.go",
"encode.go",
"encode_unsafe.go",
- "map.go",
- "printer.go",
+ "pending_list.go",
"state.go",
+ "state_norace.go",
+ "state_race.go",
"stats.go",
+ "types.go",
],
marshal = False,
stateify = False,
visibility = ["//:sandbox"],
deps = [
- ":object_go_proto",
- "@com_github_golang_protobuf//proto:go_default_library",
+ "//pkg/log",
+ "//pkg/state/wire",
],
)
-
-proto_library(
- name = "object",
- srcs = ["object.proto"],
- visibility = ["//:sandbox"],
-)
-
-go_test(
- name = "state_test",
- timeout = "long",
- srcs = ["state_test.go"],
- library = ":state",
-)
diff --git a/pkg/state/README.md b/pkg/state/README.md
new file mode 100644
index 000000000..1aa401193
--- /dev/null
+++ b/pkg/state/README.md
@@ -0,0 +1,158 @@
+# State Encoding and Decoding
+
+The state package implements the encoding and decoding of data structures for
+`go_stateify`. This package is designed for use cases that the standard
+encoding packages, e.g. `gob` and `json`, do not handle. Principally:
+
+* This package operates on complex object graphs and accurately serializes and
+ restores all relationships. That is, you can have things like intrusive
+ pointers, cycles, and pointer chains of arbitrary depth. These are not
+ handled appropriately by existing encoders. This is not an implementation
+ flaw: the formats themselves cannot represent these graphs, as they can only
+ generate directed trees.
+
+* This package allows installing order-dependent load callbacks and then
+ resolves that graph at load time, with cycle detection. No analogous feature
+ is possible in the standard encoders.
+
+* This package handles the resolution of interfaces, based on a registered
+ type name. For interface objects, type information is saved in the
+ serialized format. This is generally true for `gob` as well, but it works
+ differently.
+
+Here's an overview of how encoding and decoding work.
+
+## Encoding
+
+Encoding produces a `statefile`, which contains a list of chunks of the form
+`(header, payload)`. The payload can be either raw data or a series of
+encoded wire objects representing some object graph. All encoded objects are
+defined in the `wire` subpackage.
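+
+Schematically (a sketch of the layout only, not a byte-accurate format):
+
+```
+statefile := chunk ...
+chunk     := (header, payload)
+payload   := raw data | encoded wire objects
+```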
+
+Encoding of an object graph begins with `encodeState.Save`.
+
+### 1. Memory Map & Encoding
+
+To discover relationships between potentially interdependent data structures
+(for example, a struct may contain pointers to members of other data
+structures), the encoder first walks the object graph and constructs a memory
+map of the objects in the input graph. As this walk progresses, objects are
+queued in the `pending` list and items are placed on the `deferred` list as
+they are discovered. No single object will be encoded multiple times, but the
+discovered relationships between objects may change as more of the overall
+object graph is explored.
+
+The encoder starts at the root object and recursively visits all reachable
+objects, recording the address ranges containing the underlying data for each
+object. This is stored as a segment set (`addrSet`), mapping each address
+range to the object occupying it; see `encodeState.values`. Note that there
+is special handling for zero-sized types and map objects during this process.
+
+Additionally, the encoder assigns each object a unique identifier which is used
+to indicate relationships between objects in the statefile; see `objectID` in
+`encode.go`.
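+
+A minimal sketch of the idea behind this map (illustrative names only, not the
+real `encodeState` API): each visited object registers the address range
+backing its data, so that later pointers into an already-registered range can
+be resolved to an existing object instead of being encoded twice.
+
+```go
+// addrRange is a half-open range of addresses backing one object.
+type addrRange struct{ start, end uintptr }
+
+// memoryMap is a toy stand-in for the encoder's segment set.
+type memoryMap struct {
+ ranges map[addrRange]uint64 // backing range -> assigned object ID
+ nextID uint64
+}
+
+// record assigns an ID to the object occupying [start, start+size).
+func (m *memoryMap) record(start, size uintptr) uint64 {
+ m.nextID++
+ m.ranges[addrRange{start, start + size}] = m.nextID
+ return m.nextID
+}
+```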
+
+### 2. Type Serialization
+
+The encoder will subsequently serialize all information about discovered types,
+including field names. These are used during decoding to reconcile these types
+with other internally registered types.
+
+### 3. Object Serialization
+
+With a full address map and all objects correctly encoded, all object encodings
+are serialized. The assigned `objectID`s aren't explicitly encoded in the
+statefile; the order of object messages in the stream determines their IDs.
+
+### Example
+
+Given the following data structure definitions:
+
+```go
+type system struct {
+ o *outer
+ i *inner
+}
+
+type outer struct {
+ a int64
+ cn *container
+}
+
+type container struct {
+ n uint64
+ elem *inner
+}
+
+type inner struct {
+ c container
+ x, y uint64
+}
+```
+
+Initialized like this:
+
+```go
+o := outer{
+ a: 10,
+ cn: nil,
+}
+i := inner{
+ x: 20,
+ y: 30,
+ c: container{},
+}
+s := system{
+ o: &o,
+ i: &i,
+}
+
+o.cn = &i.c
+o.cn.elem = &i
+
+```
+
+Encoding will produce an object stream like this:
+
+```
+g0r1 = struct{
+ i: g0r3,
+ o: g0r2,
+}
+g0r2 = struct{
+ a: 10,
+ cn: g0r3.c,
+}
+g0r3 = struct{
+ c: struct{
+ elem: g0r3,
+ n: 0u,
+ },
+ x: 20u,
+ y: 30u,
+}
+```
+
+Note how `g0r3.c` is correctly encoded as the underlying `container` object for
+`inner.c`, and how the pointer from `outer.cn` points to it, despite `system.i`
+being discovered after the pointer to it in `system.o.cn`. Also note that
+decoding isn't strictly reliant on the order of the encoded object stream, as
+long as the relationships between objects are correctly encoded.
+
+## Decoding
+
+Decoding reads the statefile and reconstructs the object graph. Decoding begins
+in `decodeState.Load`. Decoding is performed in a single pass over the object
+stream in the statefile, followed by a pass over all deserialized objects to
+fire their loading callbacks in the correct order. Cycles among these
+callbacks are possible, but they are detected and reported as an error.
+
+Decoding is relatively straightforward. For most primitive values, the decoder
+constructs an appropriate object and fills it with the values encoded in the
+statefile. Pointers need special handling, as they must point to a value
+allocated elsewhere. When values are constructed, the decoder indexes them by
+their `objectID`s in `decodeState.objectsByID`. Pointer targets are resolved
+by looking them up in this index by `objectID`; see
+`decodeState.register`. For pointers to values inside another value (fields in a
+pointer, elements of an array), the decoder uses the accessor path to walk to
+the appropriate location; see `walkChild`.
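+
+As a rough sketch of that resolution step (this mirrors `register` and
+`walkChild`, but is not the literal code path): a pointer is decoded by
+looking up its target by `objectID` and then walking the recorded accessor
+path down to the addressed child.
+
+```go
+// resolve returns the value a decoded pointer should point to. objectsByID
+// is indexed by objectID-1, as in decodeState.lookup.
+func resolve(objectsByID []*objectDecodeState, id objectID, path []wire.Dot) reflect.Value {
+ root := objectsByID[id-1].obj // The registered target object.
+ return walkChild(path, root)  // Walk fields/indices to the child.
+}
+```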
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
index 590c241a3..c9971cdf6 100644
--- a/pkg/state/decode.go
+++ b/pkg/state/decode.go
@@ -17,28 +17,49 @@ package state
import (
"bytes"
"context"
- "encoding/binary"
- "errors"
"fmt"
- "io"
+ "math"
"reflect"
- "sort"
- "github.com/golang/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
-// objectState represents an object that may be in the process of being
+// internalCallback is an interface called on object completion.
+//
+// There are two implementations: objectDecodeState & userCallback.
+type internalCallback interface {
+ // source returns the dependent object. May be nil.
+ source() *objectDecodeState
+
+ // callbackRun executes the callback.
+ callbackRun()
+}
+
+// userCallback is an implementation of internalCallback.
+type userCallback func()
+
+// source implements internalCallback.source.
+func (userCallback) source() *objectDecodeState {
+ return nil
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (uc userCallback) callbackRun() {
+ uc()
+}
+
+// objectDecodeState represents an object that may be in the process of being
// decoded. Specifically, it represents either a decoded object, or an
// interest in a future object that will be decoded. When that interest is
// registered (via register), the storage for the object will be created, but
// it will not be decoded until the object is encountered in the stream.
-type objectState struct {
+type objectDecodeState struct {
// id is the id for this object.
- //
- // If this field is zero, then this is an anonymous (unregistered,
- // non-reference primitive) object. This is immutable.
- id uint64
+ id objectID
+
+ // typ is the id for this typeID. This may be zero if this is not a
+ // type-registered structure.
+ typ typeID
// obj is the object. This may or may not be valid yet, depending on
// whether complete returns true. However, regardless of whether the
@@ -57,69 +78,52 @@ type objectState struct {
// blockedBy is the number of dependencies this object has.
blockedBy int
- // blocking is a list of the objects blocked by this one.
- blocking []*objectState
+ // callbacksInline is inline storage for callbacks.
+ callbacksInline [2]internalCallback
// callbacks is a set of callbacks to execute on load.
- callbacks []func()
-
- // path is the decoding path to the object.
- path recoverable
-}
-
-// complete indicates the object is complete.
-func (os *objectState) complete() bool {
- return os.blockedBy == 0 && len(os.callbacks) == 0
-}
-
-// checkComplete checks for completion. If the object is complete, pending
-// callbacks will be executed and checkComplete will be called on downstream
-// objects (those depending on this one).
-func (os *objectState) checkComplete(stats *Stats) {
- if os.blockedBy > 0 {
- return
- }
- stats.Start(os.obj)
+ callbacks []internalCallback
- // Fire all callbacks.
- for _, fn := range os.callbacks {
- fn()
- }
- os.callbacks = nil
-
- // Clear all blocked objects.
- for _, other := range os.blocking {
- other.blockedBy--
- other.checkComplete(stats)
- }
- os.blocking = nil
- stats.Done()
+ completeEntry
}
-// waitFor queues a dependency on the given object.
-func (os *objectState) waitFor(other *objectState, callback func()) {
- os.blockedBy++
- other.blocking = append(other.blocking, os)
- if callback != nil {
- other.callbacks = append(other.callbacks, callback)
+// addCallback adds a callback to the objectDecodeState.
+func (ods *objectDecodeState) addCallback(ic internalCallback) {
+ if ods.callbacks == nil {
+ ods.callbacks = ods.callbacksInline[:0]
}
+ ods.callbacks = append(ods.callbacks, ic)
}
// findCycleFor returns when the given object is found in the blocking set.
-func (os *objectState) findCycleFor(target *objectState) []*objectState {
- for _, other := range os.blocking {
- if other == target {
- return []*objectState{target}
+func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState {
+ for _, ic := range ods.callbacks {
+ other := ic.source()
+ if other != nil && other == target {
+ return []*objectDecodeState{target}
} else if childList := other.findCycleFor(target); childList != nil {
return append(childList, other)
}
}
- return nil
+
+ // This should not occur.
+ Failf("no deadlock found?")
+ panic("unreachable")
}
// findCycle finds a dependency cycle.
-func (os *objectState) findCycle() []*objectState {
- return append(os.findCycleFor(os), os)
+func (ods *objectDecodeState) findCycle() []*objectDecodeState {
+ return append(ods.findCycleFor(ods), ods)
+}
+
+// source implements internalCallback.source.
+func (ods *objectDecodeState) source() *objectDecodeState {
+ return ods
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (ods *objectDecodeState) callbackRun() {
+ ods.blockedBy--
}
// decodeState is a graph of objects in the process of being decoded.
@@ -137,30 +141,66 @@ type decodeState struct {
// ctx is the decode context.
ctx context.Context
+ // r is the input stream.
+ r wire.Reader
+
+ // types is the type database.
+ types typeDecodeDatabase
+
// objectByID is the set of objects in progress.
- objectsByID map[uint64]*objectState
+ objectsByID []*objectDecodeState
// deferred are objects that have been read, but no interest has been
// registered yet. These will be decoded once interest is registered.
- deferred map[uint64]*pb.Object
+ deferred map[objectID]wire.Object
- // outstanding is the number of outstanding objects.
- outstanding uint32
+ // pending is the set of objects that are not yet complete.
+ pending completeList
- // r is the input stream.
- r io.Reader
-
- // stats is the passed stats object.
- stats *Stats
-
- // recoverable is the panic recover facility.
- recoverable
+ // stats tracks time data.
+ stats Stats
}
// lookup looks up an object in decodeState or returns nil if no such object
// has been previously registered.
-func (ds *decodeState) lookup(id uint64) *objectState {
- return ds.objectsByID[id]
+func (ds *decodeState) lookup(id objectID) *objectDecodeState {
+ if len(ds.objectsByID) < int(id) {
+ return nil
+ }
+ return ds.objectsByID[id-1]
+}
+
+// checkComplete checks for completion.
+func (ds *decodeState) checkComplete(ods *objectDecodeState) bool {
+ // Still blocked?
+ if ods.blockedBy > 0 {
+ return false
+ }
+
+ // Track stats if relevant.
+ if ods.callbacks != nil && ods.typ != 0 {
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ }
+
+ // Fire all callbacks.
+ for _, ic := range ods.callbacks {
+ ic.callbackRun()
+ }
+
+ // Mark completed.
+ cbs := ods.callbacks
+ ods.callbacks = nil
+ ds.pending.Remove(ods)
+
+ // Recursively check others.
+ for _, ic := range cbs {
+ if other := ic.source(); other != nil && other.blockedBy == 0 {
+ ds.checkComplete(other)
+ }
+ }
+
+ return true // All set.
}
// wait registers a dependency on an object.
@@ -168,11 +208,8 @@ func (ds *decodeState) lookup(id uint64) *objectState {
// As a special case, we always allow _useable_ references back to the first
// decoding object because it may have fields that are already decoded. We also
// allow trivial self reference, since they can be handled internally.
-func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
+func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) {
switch id {
- case 0:
- // Nil pointer; nothing to wait for.
- fallthrough
case waiter.id:
// Trivial self reference.
fallthrough
@@ -184,107 +221,188 @@ func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
return
}
+ // Mark as blocked.
+ waiter.blockedBy++
+
// No nil can be returned here.
- waiter.waitFor(ds.lookup(id), callback)
+ other := ds.lookup(id)
+ if callback != nil {
+ // Add the additional user callback.
+ other.addCallback(userCallback(callback))
+ }
+
+ // Mark waiter as unblocked.
+ other.addCallback(waiter)
}
// waitObject notes a blocking relationship.
-func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) {
- if rv, ok := p.Value.(*pb.Object_RefValue); ok {
+func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) {
+ if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 {
// Refs can encode pointers and maps.
- ds.wait(os, rv.RefValue, callback)
- } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok {
+ ds.wait(ods, objectID(rv.Root), callback)
+ } else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 {
// See decodeObject; we need to wait for the array (if non-nil).
- ds.wait(os, sv.SliceValue.RefValue, callback)
- } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok {
+ ds.wait(ods, objectID(sv.Ref.Root), callback)
+ } else if iv, ok := encoded.(*wire.Interface); ok {
// It's an interface (wait recursively).
- ds.waitObject(os, iv.InterfaceValue.Value, callback)
+ ds.waitObject(ods, iv.Value, callback)
} else if callback != nil {
// Nothing to wait for: execute the callback immediately.
callback()
}
}
+// walkChild returns a child object from obj, given an accessor path. This is
+// the decode-side equivalent to traverse in encode.go.
+//
+// For the purposes of this function, a child object is either a field within a
+// struct or an array element, with one such indirection per element in
+// path. The returned value may be an unexported field, so it may not be
+// directly assignable. See unsafePointerTo.
+func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
+ // See wire.Ref.Dots. The path here is specified in reverse order.
+ for i := len(path) - 1; i >= 0; i-- {
+ switch pc := path[i].(type) {
+ case *wire.FieldName: // Must be a pointer.
+ if obj.Kind() != reflect.Struct {
+ Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.FieldByName(string(*pc))
+ case wire.Index: // Embedded.
+ if obj.Kind() != reflect.Array {
+ Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.Index(int(pc))
+ default:
+ panic("unreachable: switch should be exhaustive")
+ }
+ }
+ return obj
+}
+
// register registers a decode with a type.
//
// This type is only used to instantiate a new object if it has not been
-// registered previously.
-func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState {
- os, ok := ds.objectsByID[id]
- if ok {
- return os
+// registered previously. This depends on the type provided if none is
+// available in the object itself.
+func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value {
+ // Grow the objectsByID slice.
+ id := objectID(r.Root)
+ if len(ds.objectsByID) < int(id) {
+ ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...)
+ }
+
+ // Does this object already exist?
+ ods := ds.objectsByID[id-1]
+ if ods != nil {
+ return walkChild(r.Dots, ods.obj)
+ }
+
+ // Create the object.
+ if len(r.Dots) != 0 {
+ typ = ds.findType(r.Type)
}
+ v := reflect.New(typ)
+ ods = &objectDecodeState{
+ id: id,
+ obj: v.Elem(),
+ }
+ ds.objectsByID[id-1] = ods
+ ds.pending.PushBack(ods)
- // Record in the object index.
- if typ.Kind() == reflect.Map {
- os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()}
- } else {
- os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()}
+ // Process any deferred objects & callbacks.
+ if encoded, ok := ds.deferred[id]; ok {
+ delete(ds.deferred, id)
+ ds.decodeObject(ods, ods.obj, encoded)
}
- ds.objectsByID[id] = os
- if o, ok := ds.deferred[id]; ok {
- // There is a deferred object.
- delete(ds.deferred, id) // Free memory.
- ds.decodeObject(os, os.obj, o, "", nil)
- } else {
- // There is no deferred object.
- ds.outstanding++
+ return walkChild(r.Dots, ods.obj)
+}
+
+// objectDecoder is for decoding structs.
+type objectDecoder struct {
+ // ds is decodeState.
+ ds *decodeState
+
+ // ods is the current object being decoded.
+ ods *objectDecodeState
+
+ // reconciledTypeEntry is the reconciled type information.
+ rte *reconciledTypeEntry
+
+ // encoded is the encoded object state.
+ encoded *wire.Struct
+}
+
+// load is a helper for the public methods on Source.
+func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) {
+ // Note that we have reconciled the type and may remap the fields here
+ // to match what's expected by the decoder. The "slot" parameter here
+ // is in terms of the local type, where the fields in the encoded
+ // object are in terms of the wire object's type, which might be in a
+ // different order (but will have the same fields).
+ v := *od.encoded.Field(od.rte.FieldOrder[slot])
+ od.ds.decodeObject(od.ods, objPtr.Elem(), v)
+ if wait {
+ // Mark this individual object a blocker.
+ od.ds.waitObject(od.ods, v, fn)
}
+}
- return os
+// afterLoad implements Source.AfterLoad.
+func (od *objectDecoder) afterLoad(fn func()) {
+ // Queue the local callback; this will execute when all of the above
+ // data dependencies have been cleared.
+ od.ods.addCallback(userCallback(fn))
}
// decodeStruct decodes a struct value.
-func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) {
- // Set the fields.
- m := Map{newInternalMap(nil, ds, os)}
- defer internalMapPool.Put(m.internalMap)
- for _, field := range s.Fields {
- m.data = append(m.data, entry{
- name: field.Name,
- object: field.Value,
- })
- }
-
- // Sort the fields for efficient searching.
- //
- // Technically, these should already appear in sorted order in the
- // state ordering, so this cost is effectively a single scan to ensure
- // that the order is correct.
- if len(m.data) > 1 {
- sort.Slice(m.data, func(i, j int) bool {
- return m.data[i].name < m.data[j].name
- })
- }
-
- // Invoke the load; this will recursively decode other objects.
- fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
- if ok {
- // Invoke the loader.
- fns.invokeLoad(obj.Addr(), m)
- } else if obj.NumField() == 0 {
- // Allow anonymous empty structs.
- return
- } else {
+func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) {
+ if encoded.TypeID == 0 {
+ // Allow anonymous empty structs, but only if the encoded
+ // object also has no fields.
+ if encoded.Fields() == 0 && obj.NumField() == 0 {
+ return
+ }
+
// Propagate an error.
- panic(fmt.Errorf("unregistered type %s", obj.Type()))
+ Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name())
+ }
+
+ // Lookup the object type.
+ rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type())
+ ods.typ = typeID(encoded.TypeID)
+
+ // Invoke the loader.
+ od := objectDecoder{
+ ds: ds,
+ ods: ods,
+ rte: rte,
+ encoded: encoded,
+ }
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateLoad(Source{internal: od})
}
}
// decodeMap decodes a map value.
-func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) {
+func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) {
if obj.IsNil() {
+ // See pointerTo.
obj.Set(reflect.MakeMap(obj.Type()))
}
- for i := 0; i < len(m.Keys); i++ {
+ for i := 0; i < len(encoded.Keys); i++ {
// Decode the objects.
kv := reflect.New(obj.Type().Key()).Elem()
vv := reflect.New(obj.Type().Elem()).Elem()
- ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i)
- ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface())
- ds.waitObject(os, m.Keys[i], nil)
- ds.waitObject(os, m.Values[i], nil)
+ ds.decodeObject(ods, kv, encoded.Keys[i])
+ ds.decodeObject(ods, vv, encoded.Values[i])
+ ds.waitObject(ods, encoded.Keys[i], nil)
+ ds.waitObject(ods, encoded.Values[i], nil)
// Set in the map.
obj.SetMapIndex(kv, vv)
@@ -292,271 +410,294 @@ func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map)
}
// decodeArray decodes an array value.
-func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) {
- if len(a.Contents) != obj.Len() {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents)))
+func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) {
+ if len(encoded.Contents) != obj.Len() {
+ Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents))
}
// Decode the contents into the array.
- for i := 0; i < len(a.Contents); i++ {
- ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i)
- ds.waitObject(os, a.Contents[i], nil)
+ for i := 0; i < len(encoded.Contents); i++ {
+ ds.decodeObject(ods, obj.Index(i), encoded.Contents[i])
+ ds.waitObject(ods, encoded.Contents[i], nil)
}
}
-// decodeInterface decodes an interface value.
-func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) {
- // Is this a nil value?
- if i.Type == "" {
- return // Just leave obj alone.
+// findType finds the type for the given wire.TypeSpec.
+func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type {
+ switch x := t.(type) {
+ case wire.TypeID:
+ typ := ds.types.LookupType(typeID(x))
+ rte := ds.types.Lookup(typeID(x), typ)
+ return rte.LocalType
+ case *wire.TypeSpecPointer:
+ return reflect.PtrTo(ds.findType(x.Type))
+ case *wire.TypeSpecArray:
+ return reflect.ArrayOf(int(x.Count), ds.findType(x.Type))
+ case *wire.TypeSpecSlice:
+ return reflect.SliceOf(ds.findType(x.Type))
+ case *wire.TypeSpecMap:
+ return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value))
+ default:
+ // Should not happen.
+ Failf("unknown type %#v", t)
}
+ panic("unreachable")
+}
- // Get the dispatchable type. This may not be used if the given
- // reference has already been resolved, but if not we need to know the
- // type to create.
- t, ok := registeredTypes.lookupType(i.Type)
- if !ok {
- panic(fmt.Errorf("no valid type for %q", i.Type))
+// decodeInterface decodes an interface value.
+func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) {
+ if _, ok := encoded.Type.(wire.TypeSpecNil); ok {
+ // Special case; the nil object. Just decode directly, which
+ // will read nil from the wire (if encoded correctly).
+ ds.decodeObject(ods, obj, encoded.Value)
+ return
}
- if obj.Kind() != reflect.Map {
- // Set the obj to be the given typed value; this actually sets
- // obj to be a non-zero value -- namely, it inserts type
- // information. There's no need to do this for maps.
- obj.Set(reflect.Zero(t))
+ // We now need to resolve the actual type.
+ typ := ds.findType(encoded.Type)
+
+ // We need to imbue type information here, then we can proceed to
+ // decode normally. In order to avoid issues with setting value-types,
+ // we create a new non-interface version of this object. We will then
+ // set the interface object to be equal to whatever we decode.
+ origObj := obj
+ obj = reflect.New(typ).Elem()
+ defer origObj.Set(obj)
+
+ // With the object now having sufficient type information to actually
+ // have Set called on it, we can proceed to decode the value.
+ ds.decodeObject(ods, obj, encoded.Value)
+}
+
+// isFloatEq determines if x and y represent the same value.
+func isFloatEq(x float64, y float64) bool {
+ switch {
+ case math.IsNaN(x):
+ return math.IsNaN(y)
+ case math.IsInf(x, 1):
+ return math.IsInf(y, 1)
+ case math.IsInf(x, -1):
+ return math.IsInf(y, -1)
+ default:
+ return x == y
}
+}
- // Decode the dereferenced element; there is no need to wait here, as
- // the interface object shares the current object state.
- ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type)
+// isComplexEq determines if x and y represent the same value.
+func isComplexEq(x complex128, y complex128) bool {
+ return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y))
}
// decodeObject decodes a object value.
-func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) {
- ds.push(false, format, param)
- ds.stats.Add(obj)
- ds.stats.Start(obj)
-
- switch x := object.GetValue().(type) {
- case *pb.Object_BoolValue:
- obj.SetBool(x.BoolValue)
- case *pb.Object_StringValue:
- obj.SetString(string(x.StringValue))
- case *pb.Object_Int64Value:
- obj.SetInt(x.Int64Value)
- if obj.Int() != x.Int64Value {
- panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type()))
+func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) {
+ switch x := encoded.(type) {
+ case wire.Nil: // Fast path: first.
+ // We leave obj alone here. That's because if obj represents an
+ // interface, it may have been imbued with type information in
+ // decodeInterface, and we don't want to destroy that.
+ case *wire.Ref:
+ // Nil pointers may be encoded in a "forceValue" context. For
+ // those we just leave it alone as the value will already be
+ // correct (nil).
+ if id := objectID(x.Root); id == 0 {
+ return
}
- case *pb.Object_Uint64Value:
- obj.SetUint(x.Uint64Value)
- if obj.Uint() != x.Uint64Value {
- panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type()))
- }
- case *pb.Object_DoubleValue:
- obj.SetFloat(x.DoubleValue)
- if obj.Float() != x.DoubleValue {
- panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type()))
- }
- case *pb.Object_RefValue:
- // Resolve the pointer itself, even though the object may not
- // be decoded yet. You need to use wait() in order to ensure
- // that is the case. See wait above, and Map.Barrier.
- if id := x.RefValue; id != 0 {
- // Decoding the interface should have imparted type
- // information, so from this point it's safe to resolve
- // and use this dynamic information for actually
- // creating the object in register.
- //
- // (For non-interfaces this is a no-op).
- dyntyp := reflect.TypeOf(obj.Interface())
- if dyntyp.Kind() == reflect.Map {
- // Remove the map object count here to avoid
- // double counting, as this object will be
- // counted again when it gets processed later.
- // We do not add a reference count as the
- // reference is artificial.
- ds.stats.Remove(obj)
- obj.Set(ds.register(id, dyntyp).obj)
- } else if dyntyp.Kind() == reflect.Ptr {
- ds.push(true /* dereference */, "", nil)
- obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
- ds.pop()
- } else {
- obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
+
+ // Note that if this is a map type, we go through a level of
+ // indirection to allow for map aliasing.
+ if obj.Kind() == reflect.Map {
+ v := ds.register(x, obj.Type())
+ if v.IsNil() {
+ // Note that we don't want to clobber the map
+ // if it has already been decoded by decodeMap. We
+ // just make it so that we have a consistent
+ // reference when that eventually does happen.
+ v.Set(reflect.MakeMap(v.Type()))
}
- } else {
- // We leave obj alone here. That's because if obj
- // represents an interface, it may have been embued
- // with type information in decodeInterface, and we
- // don't want to destroy that information.
+ obj.Set(v)
+ return
}
- case *pb.Object_SliceValue:
- // It's okay to slice the array here, since the contents will
- // still be provided later on. These semantics are a bit
- // strange but they are handled in the Map.Barrier properly.
- //
- // The special semantics of zero ref apply here too.
- if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 {
- v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem())
- obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity)))
+
+ // Normal assignment: authoritative only if no dots.
+ v := ds.register(x, obj.Type().Elem())
+ if v.IsValid() {
+ obj.Set(unsafePointerTo(v))
}
- case *pb.Object_ArrayValue:
- ds.decodeArray(os, obj, x.ArrayValue)
- case *pb.Object_StructValue:
- ds.decodeStruct(os, obj, x.StructValue)
- case *pb.Object_MapValue:
- ds.decodeMap(os, obj, x.MapValue)
- case *pb.Object_InterfaceValue:
- ds.decodeInterface(os, obj, x.InterfaceValue)
- case *pb.Object_ByteArrayValue:
- copyArray(obj, reflect.ValueOf(x.ByteArrayValue))
- case *pb.Object_Uint16ArrayValue:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := x.Uint16ArrayValue.Values
- t := obj.Slice(0, obj.Len()).Interface().([]uint16)
- if len(t) != len(s) {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ case wire.Bool:
+ obj.SetBool(bool(x))
+ case wire.Int:
+ obj.SetInt(int64(x))
+ if obj.Int() != int64(x) {
+ Failf("signed integer truncated from %v to %v", int64(x), obj.Int())
}
- for i := range s {
- t[i] = uint16(s[i])
+ case wire.Uint:
+ obj.SetUint(uint64(x))
+ if obj.Uint() != uint64(x) {
+ Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint())
}
- case *pb.Object_Uint32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values))
- case *pb.Object_Uint64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values))
- case *pb.Object_UintptrArrayValue:
- copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
- case *pb.Object_Int8ArrayValue:
- copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
- case *pb.Object_Int16ArrayValue:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := x.Int16ArrayValue.Values
- t := obj.Slice(0, obj.Len()).Interface().([]int16)
- if len(t) != len(s) {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ case wire.Float32:
+ obj.SetFloat(float64(x))
+ case wire.Float64:
+ obj.SetFloat(float64(x))
+ if !isFloatEq(obj.Float(), float64(x)) {
+ Failf("floating point number truncated from %v to %v", float64(x), obj.Float())
}
- for i := range s {
- t[i] = int16(s[i])
+ case *wire.Complex64:
+ obj.SetComplex(complex128(*x))
+ case *wire.Complex128:
+ obj.SetComplex(complex128(*x))
+ if !isComplexEq(obj.Complex(), complex128(*x)) {
+ Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex())
}
- case *pb.Object_Int32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values))
- case *pb.Object_Int64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values))
- case *pb.Object_BoolArrayValue:
- copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values))
- case *pb.Object_Float64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values))
- case *pb.Object_Float32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values))
+ case *wire.String:
+ obj.SetString(string(*x))
+ case *wire.Slice:
+ // See *wire.Ref above; same applies.
+ if id := objectID(x.Ref.Root); id == 0 {
+ return
+ }
+ // Note that it's fine to slice the array here and assume that
+ // contents will still be filled in later on.
+ typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
+ v := ds.register(&x.Ref, typ)
+ obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity)))
+ case *wire.Array:
+ ds.decodeArray(ods, obj, x)
+ case *wire.Struct:
+ ds.decodeStruct(ods, obj, x)
+ case *wire.Map:
+ ds.decodeMap(ods, obj, x)
+ case *wire.Interface:
+ ds.decodeInterface(ods, obj, x)
default:
	// Should not happen; not propagated as an error.
- panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type()))
- }
-
- ds.stats.Done()
- ds.pop()
-}
-
-func copyArray(dest reflect.Value, src reflect.Value) {
- if dest.Len() != src.Len() {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len()))
+ Failf("unknown object %#v for %q", encoded, obj.Type().Name())
}
- reflect.Copy(dest, castSlice(src, dest.Type().Elem()))
}
-// Deserialize deserializes the object state.
+// Load deserializes the object graph rooted at obj.
//
// This function may panic and should be run in safely().
-func (ds *decodeState) Deserialize(obj reflect.Value) {
- ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()}
- ds.outstanding = 1 // The root object.
+func (ds *decodeState) Load(obj reflect.Value) {
+ ds.stats.init()
+ defer ds.stats.fini(func(id typeID) string {
+ return ds.types.LookupName(id)
+ })
+
+ // Create the root object.
+ ds.objectsByID = append(ds.objectsByID, &objectDecodeState{
+ id: 1,
+ obj: obj,
+ })
+
+ // Read the number of objects.
+ lastID, object, err := ReadHeader(ds.r)
+ if err != nil {
+ Failf("header error: %w", err)
+ }
+ if !object {
+ Failf("object missing")
+ }
+
+ // Decode all objects.
+ var (
+ encoded wire.Object
+ ods *objectDecodeState
+ id = objectID(1)
+ tid = typeID(1)
+ )
+ if err := safely(func() {
+ // Decode all objects in the stream.
+ //
+ // Note that the structure of this decoding loop should match
+ // the raw decoding loop in printer.go.
+ for id <= objectID(lastID) {
+ // Unmarshal the object.
+ encoded = wire.Load(ds.r)
+
+ // Is this a type object? Handle inline.
+ if wt, ok := encoded.(*wire.Type); ok {
+ ds.types.Register(wt)
+ tid++
+ encoded = nil
+ continue
+ }
- // Decode all objects in the stream.
- //
- // See above, we never process objects while we have no outstanding
- // interests (other than the very first object).
- for id := uint64(1); ds.outstanding > 0; id++ {
- os := ds.lookup(id)
- ds.stats.Start(os.obj)
-
- o, err := ds.readObject()
- if err != nil {
- panic(err)
- }
+ // Actually resolve the object.
+ ods = ds.lookup(id)
+ if ods != nil {
+ // Decode the object.
+ ds.decodeObject(ods, ods.obj, encoded)
+ } else {
+ // If an object hasn't had interest registered
+				// previously or isn't yet valid, we defer
+ // decoding until interest is registered.
+ ds.deferred[id] = encoded
+ }
- if os != nil {
- // Decode the object.
- ds.from = &os.path
- ds.decodeObject(os, os.obj, o, "", nil)
- ds.outstanding--
+ // For error handling.
+ ods = nil
+ encoded = nil
+ id++
+ }
+ }); err != nil {
+ // Include as much information as we can, taking into account
+ // the possible state transitions above.
+ if ods != nil {
+ Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
+ } else if encoded != nil {
+ Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err)
} else {
- // If an object hasn't had interest registered
- // previously, we deferred decoding until interest is
- // registered.
- ds.deferred[id] = o
+ Failf("general decoding error: %w", err)
}
-
- ds.stats.Done()
- }
-
- // Check the zero-length header at the end.
- length, object, err := ReadHeader(ds.r)
- if err != nil {
- panic(err)
- }
- if length != 0 {
- panic(fmt.Sprintf("expected zero-length terminal, got %d", length))
- }
- if object {
- panic("expected non-object terminal")
}
// Check if we have any deferred objects.
- if count := len(ds.deferred); count > 0 {
- // Shoud not happen, not propagated as an error.
- panic(fmt.Sprintf("still have %d deferred objects", count))
- }
-
- // Scan and fire all callbacks.
- for _, os := range ds.objectsByID {
- os.checkComplete(ds.stats)
+ for id, encoded := range ds.deferred {
+		// Should never happen; the graph was bogus.
+ Failf("still have deferred objects: one is ID %d, %#v", id, encoded)
}
- // Check if we have any remaining dependency cycles.
- for _, os := range ds.objectsByID {
- if !os.complete() {
- // This must be the result of a dependency cycle.
- cycle := os.findCycle()
- var buf bytes.Buffer
- buf.WriteString("dependency cycle: {")
- for i, cycleOS := range cycle {
- if i > 0 {
- buf.WriteString(" => ")
+ // Scan and fire all callbacks. We iterate over the list of incomplete
+ // objects until all have been finished. We stop iterating if no
+ // objects become complete (there is a dependency cycle).
+ //
+ // Note that we iterate backwards here, because there will be a strong
+	// tendency for blocking relationships to go from earlier objects to
+ // later (deeper) objects in the graph. This will reduce the number of
+ // iterations required to finish all objects.
+ if err := safely(func() {
+ for ds.pending.Back() != nil {
+ thisCycle := false
+ for ods = ds.pending.Back(); ods != nil; {
+ if ds.checkComplete(ods) {
+ thisCycle = true
+ break
}
- buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type()))
+ ods = ods.Prev()
+ }
+ if !thisCycle {
+ break
}
- buf.WriteString("}")
- // Panic as an error; propagate to the caller.
- panic(errors.New(string(buf.Bytes())))
}
- }
-}
-
-type byteReader struct {
- io.Reader
-}
-
-// ReadByte implements io.ByteReader.
-func (br byteReader) ReadByte() (byte, error) {
- var b [1]byte
- n, err := br.Reader.Read(b[:])
- if n > 0 {
- return b[0], nil
- } else if err != nil {
- return 0, err
- } else {
- return 0, io.ErrUnexpectedEOF
+ }); err != nil {
+ Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err)
+ }
+
+ // Check if we have any remaining dependency cycles. If there are any
+ // objects left in the pending list, then it must be due to a cycle.
+ if ods := ds.pending.Front(); ods != nil {
+ // This must be the result of a dependency cycle.
+ cycle := ods.findCycle()
+ var buf bytes.Buffer
+ buf.WriteString("dependency cycle: {")
+ for i, cycleOS := range cycle {
+ if i > 0 {
+ buf.WriteString(" => ")
+ }
+ fmt.Fprintf(&buf, "%q", cycleOS.obj.Type())
+ }
+ buf.WriteString("}")
+ Failf("incomplete graph: %s", string(buf.Bytes()))
}
}
@@ -565,45 +706,20 @@ func (br byteReader) ReadByte() (byte, error) {
// Each object written to the statefile is prefixed with a header. See
// WriteHeader for more information; these functions are exported to allow
// non-state writes to the file to play nice with debugging tools.
-func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
+func ReadHeader(r wire.Reader) (length uint64, object bool, err error) {
// Read the header.
- length, err = binary.ReadUvarint(byteReader{r})
+ err = safely(func() {
+ length = wire.LoadUint(r)
+ })
if err != nil {
- return
+ // On the header, pass raw I/O errors.
+ if sErr, ok := err.(*ErrState); ok {
+ return 0, false, sErr.Unwrap()
+ }
}
// Decode whether the object is valid.
- object = length&0x1 != 0
- length = length >> 1
+ object = length&objectFlag != 0
+ length &^= objectFlag
return
}
-
-// readObject reads an object from the stream.
-func (ds *decodeState) readObject() (*pb.Object, error) {
- // Read the header.
- length, object, err := ReadHeader(ds.r)
- if err != nil {
- return nil, err
- }
- if !object {
- return nil, fmt.Errorf("invalid object header")
- }
-
- // Read the object.
- buf := make([]byte, length)
- for done := 0; done < len(buf); {
- n, err := ds.r.Read(buf[done:])
- done += n
- if n == 0 && err != nil {
- return nil, err
- }
- }
-
- // Unmarshal.
- obj := new(pb.Object)
- if err := proto.Unmarshal(buf, obj); err != nil {
- return nil, err
- }
-
- return obj, nil
-}
diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go
new file mode 100644
index 000000000..d048f61a1
--- /dev/null
+++ b/pkg/state/decode_unsafe.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on
+// values representing unexported fields. This bypasses visibility, but not
+// type safety.
+func unsafePointerTo(obj reflect.Value) reflect.Value {
+ return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr()))
+}
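
For illustration, a minimal self-contained sketch of the same trick (the type
T and its field name are hypothetical): reflect.NewAt re-derives a settable
pointer to a field that reflect would otherwise refuse to address.

    package main

    import (
    	"fmt"
    	"reflect"
    	"unsafe"
    )

    type T struct{ hidden int } // Hypothetical type with an unexported field.

    func main() {
    	var t T
    	f := reflect.ValueOf(&t).Elem().Field(0) // f.CanSet() is false.
    	// Rebuild a pointer to the same storage; the new value is settable.
    	p := reflect.NewAt(f.Type(), unsafe.Pointer(f.UnsafeAddr()))
    	p.Elem().SetInt(7)
    	fmt.Println(t.hidden) // Prints 7.
    }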
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
index c5118d3a9..92fcad4e9 100644
--- a/pkg/state/encode.go
+++ b/pkg/state/encode.go
@@ -15,437 +15,797 @@
package state
import (
- "container/list"
"context"
- "encoding/binary"
- "fmt"
- "io"
"reflect"
- "sort"
- "github.com/golang/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
-// queuedObject is an object queued for encoding.
-type queuedObject struct {
- id uint64
- obj reflect.Value
- path recoverable
+// objectEncodeState records the type and identity of an object occupying a memory
+// address range. This is the value type for addrSet, and the intrusive entry
+// for the pending and deferred lists.
+type objectEncodeState struct {
+ // id is the assigned ID for this object.
+ id objectID
+
+ // obj is the object value. Note that this may be replaced if we
+ // encounter an object that contains this object. When this happens (in
+	// resolve), we will update existing references appropriately, below,
+ // and defer a re-encoding of the object.
+ obj reflect.Value
+
+ // encoded is the encoded value of this object. Note that this may not
+ // be up to date if this object is still in the deferred list.
+ encoded wire.Object
+
+ // how indicates whether this object should be encoded as a value. This
+ // is used only for deferred encoding.
+ how encodeStrategy
+
+ // refs are the list of reference objects used by other objects
+ // referring to this object. When the object is updated, these
+ // references may be updated directly and automatically.
+ refs []*wire.Ref
+
+ pendingEntry
+ deferredEntry
}
// encodeState is state used for encoding.
//
-// The encoding process is a breadth-first traversal of the object graph. The
-// inherent races and dependencies are much simpler than the decode case.
+// The encoding process constructs a representation of the in-memory graph of
+// objects before a single object is serialized. This is done to ensure that
+// all references can be fully disambiguated. See resolve for more details.
type encodeState struct {
// ctx is the encode context.
ctx context.Context
- // lastID is the last object ID.
- //
- // See idsByObject for context. Because of the special zero encoding
- // used for reference values, the first ID must be 1.
- lastID uint64
+ // w is the output stream.
+ w wire.Writer
- // idsByObject is a set of objects, indexed via:
- //
- // reflect.ValueOf(x).UnsafeAddr
- //
- // This provides IDs for objects.
- idsByObject map[uintptr]uint64
+ // types is the type database.
+ types typeEncodeDatabase
+
+ // lastID is the last allocated object ID.
+ lastID objectID
- // values stores values that span the addresses.
+ // values tracks the address ranges occupied by objects, along with the
+ // types of these objects. This is used to locate pointer targets,
+ // including pointers to fields within another type.
//
- // addrSet is a a generated type which efficiently stores ranges of
- // addresses. When encoding pointers, these ranges are filled in and
- // used to check for overlapping or conflicting pointers. This would
- // indicate a pointer to an field, or a non-type safe value, neither of
- // which are currently decodable.
+ // Multiple objects may overlap in memory iff the larger object fully
+ // contains the smaller one, and the type of the smaller object matches
+ // a field or array element's type at the appropriate offset. An
+ // arbitrary number of objects may be nested in this manner.
//
- // See the usage of values below for more context.
+ // Note that this does not track zero-sized objects, those are tracked
+ // by zeroValues below.
values addrSet
- // w is the output stream.
- w io.Writer
+ // zeroValues tracks zero-sized objects.
+ zeroValues map[reflect.Type]*objectEncodeState
- // pending is the list of objects to be serialized.
- //
- // This is a set of queuedObjects.
- pending list.List
+ // deferred is the list of objects to be encoded.
+ deferred deferredList
- // done is the a list of finished objects.
- //
- // This is kept to prevent garbage collection and address reuse.
- done list.List
+ // pendingTypes is the list of types to be serialized. Serialization
+ // will occur when all objects have been encoded, but before pending is
+ // serialized.
+ pendingTypes []wire.Type
- // stats is the passed stats object.
- stats *Stats
+ // pending is the list of objects to be serialized. Serialization does
+ // not actually occur until the full object graph is computed.
+ pending pendingList
- // recoverable is the panic recover facility.
- recoverable
+ // stats tracks time data.
+ stats Stats
}
-// register looks up an ID, registering if necessary.
+// isSameSizeParent returns true if child is a field value or element within
+// parent. Only a struct or array can have a child value.
+//
+// isSameSizeParent deals with objects like this:
+//
+// struct child {
+// // fields..
+// }
//
-// If the object was not previously registered, it is enqueued to be serialized.
-// See the documentation for idsByObject for more information.
-func (es *encodeState) register(obj reflect.Value) uint64 {
- // It is not legal to call register for any non-pointer objects (see
- // below), so we panic with a recoverable error if this is a mismatch.
- if obj.Kind() != reflect.Ptr && obj.Kind() != reflect.Map {
- panic(fmt.Errorf("non-pointer %#v registered", obj.Interface()))
+// struct parent {
+// c child
+// }
+//
+// var p parent
+// record(&p.c)
+//
+// Here, &p and &p.c occupy the exact same address range.
+//
+// Or like this:
+//
+// struct child {
+// // fields
+// }
+//
+// var arr [1]parent
+// record(&arr[0])
+//
+// Similarly, &arr[0] and &arr[0].c have the exact same address range.
+//
+// Precondition: parent and child must occupy the same memory.
+func isSameSizeParent(parent reflect.Value, childType reflect.Type) bool {
+ switch parent.Kind() {
+ case reflect.Struct:
+ for i := 0; i < parent.NumField(); i++ {
+ field := parent.Field(i)
+ if field.Type() == childType {
+ return true
+ }
+ // Recurse through any intermediate types.
+ if isSameSizeParent(field, childType) {
+ return true
+ }
+ // Does it make sense to keep going if the first field
+ // doesn't match? Yes, because there might be an
+ // arbitrary number of zero-sized fields before we get
+ // a match, and childType itself can be zero-sized.
+ }
+ return false
+ case reflect.Array:
+		// The only case where an array with more than one element can
+ // return true is if childType is zero-sized. In such cases,
+ // it's ambiguous which element contains the match since a
+ // zero-sized child object fully fits in any of the zero-sized
+ // elements in an array... However since all elements are of
+ // the same type, we only need to check one element.
+ //
+ // For non-zero-sized childTypes, parent.Len() must be 1, but a
+ // combination of the precondition and an implicit comparison
+ // between the array element size and childType ensures this.
+ return parent.Len() > 0 && isSameSizeParent(parent.Index(0), childType)
+ default:
+ return false
}
+}
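
A runnable illustration of the aliasing this function disambiguates (child and
parent are hypothetical names): a pointer to a struct and a pointer to its
first field compare equal as raw addresses, so only type information can tell
them apart.

    package main

    import (
    	"fmt"
    	"unsafe"
    )

    type child struct{ n int }
    type parent struct{ c child }

    func main() {
    	var p parent
    	// &p and &p.c occupy the same address; only the types differ.
    	fmt.Println(unsafe.Pointer(&p) == unsafe.Pointer(&p.c)) // true
    }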
- addr := obj.Pointer()
- if obj.Kind() == reflect.Ptr && obj.Elem().Type().Size() == 0 {
- // For zero-sized objects, we always provide a unique ID.
- // That's because the runtime internally multiplexes pointers
- // to the same address. We can't be certain what the intent is
- // with pointers to zero-sized objects, so we just give them
- // all unique identities.
- } else if id, ok := es.idsByObject[addr]; ok {
- // Already registered.
- return id
- }
-
- // Ensure that the first ID given out is one. See note on lastID. The
- // ID zero is used to indicate nil values.
+// nextID returns the next valid ID.
+func (es *encodeState) nextID() objectID {
es.lastID++
- id := es.lastID
- es.idsByObject[addr] = id
- if obj.Kind() == reflect.Ptr {
- // Dereference and treat as a pointer.
- es.pending.PushBack(queuedObject{id: id, obj: obj.Elem(), path: es.recoverable.copy()})
-
- // Register this object at all addresses.
- typ := obj.Elem().Type()
- if size := typ.Size(); size > 0 {
- r := addrRange{addr, addr + size}
- if !es.values.IsEmptyRange(r) {
- old := es.values.LowerBoundSegment(addr).Value().Interface().(recoverable)
- panic(fmt.Errorf("overlapping objects: [new object] %#v [existing object path] %s", obj.Interface(), old.path()))
+ return objectID(es.lastID)
+}
+
+// dummyAddr points to the dummy zero-sized address.
+var dummyAddr = reflect.ValueOf(new(struct{})).Pointer()
+
+// resolve records the address range occupied by an object.
+func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
+ addr := obj.Pointer()
+
+ // Is this a map pointer? Just record the single address. It is not
+ // possible to take any pointers into the map internals.
+ if obj.Kind() == reflect.Map {
+ if addr == 0 {
+ // Just leave the nil reference alone. This is fine, we
+ // may need to encode as a reference in this way. We
+ // return nil for our objectEncodeState so that anyone
+ // depending on this value knows there's nothing there.
+ return
+ }
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ // Ensure the map types match.
+ existing := seg.Value()
+ if existing.obj.Type() != obj.Type() {
+ Failf("overlapping map objects at 0x%x: [new object] %#v [existing object type] %s", addr, obj, existing.obj)
}
- es.values.Add(r, reflect.ValueOf(es.recoverable.copy()))
+
+ // No sense recording refs, maps may not be replaced by
+ // covering objects, they are maximal.
+ ref.Root = wire.Uint(existing.id)
+ return
}
+
+ // Record the map.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ how: encodeMapAsValue,
+ }
+ es.values.Add(addrRange{addr, addr + 1}, oes)
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+
+ // See above: no ref recording.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+
+ // If not a map, then the object must be a pointer.
+ if obj.Kind() != reflect.Ptr {
+ Failf("attempt to record non-map and non-pointer object %#v", obj)
+ }
+
+ obj = obj.Elem() // Value from here.
+
+ // Is this a zero-sized type?
+ typ := obj.Type()
+ size := typ.Size()
+ if size == 0 {
+ if addr == dummyAddr {
+ // Zero-sized objects point to a dummy byte within the
+ // runtime. There's no sense recording this in the
+ // address map. We add this to the dedicated
+ // zeroValues.
+ //
+ // Note that zero-sized objects must be *true*
+ // zero-sized objects. They cannot be part of some
+ // larger object. In that case, they are assigned a
+ // 1-byte address at the end of the object.
+ oes, ok := es.zeroValues[typ]
+ if !ok {
+ oes = &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ es.zeroValues[typ] = oes
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ }
+
+ // There's also no sense tracking back references. We
+ // know that this is a true zero-sized object, and not
+ // part of a larger container, so it will not change.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+ size = 1 // See above.
+ }
+
+ // Calculate the container.
+ end := addr + size
+ r := addrRange{addr, end}
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ existing := seg.Value()
+ switch {
+ case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type():
+ // The object is a perfect match. Happy path. Avoid the
+ // traversal and just return directly. We don't need to
+ // encode the type information or any dots here.
+ ref.Root = wire.Uint(existing.id)
+ existing.refs = append(existing.refs, ref)
+ return
+
+ case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end):
+ // The previously registered object is larger than
+ // this, no need to update. But we expect some
+ // traversal below.
+
+ case seg.Start() == addr && seg.End() == end:
+ if !isSameSizeParent(obj, existing.obj.Type()) {
+ break // Needs traversal.
+ }
+ fallthrough // Needs update.
+
+ case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end):
+ // Update the object and redo the encoding.
+ old := existing.obj
+ existing.obj = obj
+ es.deferred.Remove(existing)
+ es.deferred.PushBack(existing)
+
+ // The previously registered object is superseded by
+ // this new object. We are guaranteed to not have any
+ // mergeable neighbours in this segment set.
+ if !raceEnabled {
+ seg.SetRangeUnchecked(r)
+ } else {
+				// Be extra paranoid. This will be statically
+				// removed at compile time unless this is a race build.
+ es.values.Remove(seg)
+ es.values.Add(r, existing)
+ seg = es.values.LowerBoundSegment(addr)
+ }
+
+ // Compute the traversal required & update references.
+ dots := traverse(obj.Type(), old.Type(), addr, seg.Start())
+ wt := es.findType(obj.Type())
+ for _, ref := range existing.refs {
+ ref.Dots = append(ref.Dots, dots...)
+ ref.Type = wt
+ }
+ default:
+ // There is a non-sensical overlap.
+ Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj)
+ }
+
+ // Compute the new reference, record and return it.
+ ref.Root = wire.Uint(existing.id)
+ ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr)
+ ref.Type = es.findType(obj.Type())
+ existing.refs = append(existing.refs, ref)
+ return
+ }
+
+ // The only remaining case is a pointer value that doesn't overlap with
+ // any registered addresses. Create a new entry for it, and start
+ // tracking the first reference we just created.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ if !raceEnabled {
+ es.values.AddWithoutMerging(r, oes)
} else {
- // Push back the map itself; when maps are encoded from the
- // top-level, forceMap will be equal to true.
- es.pending.PushBack(queuedObject{id: id, obj: obj, path: es.recoverable.copy()})
+		// Merges should never happen. This just enables extra
+		// sanity checks, since the Merge function below will panic.
+ es.values.Add(r, oes)
+ }
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ ref.Root = wire.Uint(oes.id)
+ oes.refs = append(oes.refs, ref)
+}
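
At its core, resolve deduplicates pointers by address so that aliases share a
single object ID. A toy sketch of that identity map, ignoring the
overlapping-range, zero-size, and map handling above (register is a
hypothetical stand-in):

    package main

    import (
    	"fmt"
    	"reflect"
    )

    func main() {
    	ids := make(map[uintptr]uint64)
    	var lastID uint64
    	register := func(obj interface{}) uint64 {
    		addr := reflect.ValueOf(obj).Pointer()
    		if id, ok := ids[addr]; ok {
    			return id // Alias of a known object.
    		}
    		lastID++
    		ids[addr] = lastID
    		return lastID
    	}
    	x := 42
    	p, q := &x, &x
    	fmt.Println(register(p), register(q)) // 1 1: both refs share one ID.
    }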
+
+// traverse searches for a target object within a root object, where the target
+// object is a struct field or array element within root, with potentially
+// multiple intervening types. traverse returns the set of field or element
+// traversals required to reach the target.
+//
+// Note that for efficiency, traverse returns the dots in the reverse order.
+// That is, the first traversal required will be the last element of the list.
+//
+// Precondition: The target object must lie completely within the range defined
+// by [rootAddr, rootAddr + sizeof(rootType)].
+func traverse(rootType, targetType reflect.Type, rootAddr, targetAddr uintptr) []wire.Dot {
+ // Recursion base case: the types actually match.
+ if targetType == rootType && targetAddr == rootAddr {
+ return nil
}
- return id
+ switch rootType.Kind() {
+ case reflect.Struct:
+ offset := targetAddr - rootAddr
+ for i := rootType.NumField(); i > 0; i-- {
+ field := rootType.Field(i - 1)
+ // The first field from the end with an offset that is
+ // smaller than or equal to our address offset is where
+ // the target is located. Traverse from there.
+ if field.Offset <= offset {
+ dots := traverse(field.Type, targetType, rootAddr+field.Offset, targetAddr)
+ fieldName := wire.FieldName(field.Name)
+ return append(dots, &fieldName)
+ }
+ }
+ // Should never happen; the target should be reachable.
+ Failf("no field in root type %v contains target type %v", rootType, targetType)
+
+ case reflect.Array:
+		// Since arrays have homogeneous types, all elements have the
+ // same size and we can compute where the target lives. This
+ // does not matter for the purpose of typing, but matters for
+ // the purpose of computing the address of the given index.
+ elemSize := int(rootType.Elem().Size())
+ n := int(targetAddr-rootAddr) / elemSize // Relies on integer division rounding down.
+ if rootType.Len() < n {
+ Failf("traversal target of type %v @%x is beyond the end of the array type %v @%x with %v elements",
+ targetType, targetAddr, rootType, rootAddr, rootType.Len())
+ }
+ dots := traverse(rootType.Elem(), targetType, rootAddr+uintptr(n*elemSize), targetAddr)
+ return append(dots, wire.Index(n))
+
+ default:
+ // For any other type, there's no possibility of aliasing so if
+		// the types didn't match earlier then we have an address
+ // collision which shouldn't be possible at this point.
+ Failf("traverse failed for root type %v and target type %v", rootType, targetType)
+ }
+ panic("unreachable")
}
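
A self-contained sketch of the traversal idea, restricted to struct fields
(outer, inner, and fieldPath are hypothetical): it walks fields from the last
one backwards, exactly as above, and returns the dots innermost-first.

    package main

    import (
    	"fmt"
    	"reflect"
    	"unsafe"
    )

    type inner struct{ X int }
    type outer struct {
    	A int
    	B inner
    }

    // fieldPath returns field names from root to target, innermost first.
    func fieldPath(rootType, targetType reflect.Type, rootAddr, targetAddr uintptr) []string {
    	if rootType == targetType && rootAddr == targetAddr {
    		return nil // Base case: exact match.
    	}
    	offset := targetAddr - rootAddr
    	for i := rootType.NumField(); i > 0; i-- {
    		f := rootType.Field(i - 1)
    		if f.Offset <= offset {
    			return append(fieldPath(f.Type, targetType, rootAddr+f.Offset, targetAddr), f.Name)
    		}
    	}
    	panic("target not reachable from root")
    }

    func main() {
    	var o outer
    	dots := fieldPath(reflect.TypeOf(o), reflect.TypeOf(0),
    		uintptr(unsafe.Pointer(&o)), uintptr(unsafe.Pointer(&o.B.X)))
    	fmt.Println(dots) // [X B]: reversed, as documented above.
    }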
// encodeMap encodes a map.
-func (es *encodeState) encodeMap(obj reflect.Value) *pb.Map {
- var (
- keys []*pb.Object
- values []*pb.Object
- )
+func (es *encodeState) encodeMap(obj reflect.Value, dest *wire.Object) {
+ if obj.IsNil() {
+		// Because there is a difference between a nil map and an empty
+		// map, a truly nil map must be encoded distinctly (as wire.Nil).
+ *dest = wire.Nil{}
+ return
+ }
+ l := obj.Len()
+ m := &wire.Map{
+ Keys: make([]wire.Object, l),
+ Values: make([]wire.Object, l),
+ }
+ *dest = m
for i, k := range obj.MapKeys() {
v := obj.MapIndex(k)
- kp := es.encodeObject(k, false, ".(key %d)", i)
- vp := es.encodeObject(v, false, "[%#v]", k.Interface())
- keys = append(keys, kp)
- values = append(values, vp)
+ // Map keys must be encoded using the full value because the
+ // type will be omitted after the first key.
+ es.encodeObject(k, encodeAsValue, &m.Keys[i])
+ es.encodeObject(v, encodeAsValue, &m.Values[i])
}
- return &pb.Map{Keys: keys, Values: values}
+}
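
The nil check above preserves a real semantic difference; a short
demonstration of why a nil map cannot be conflated with an empty one:

    package main

    import "fmt"

    func main() {
    	var nilMap map[string]int
    	emptyMap := map[string]int{}
    	fmt.Println(nilMap == nil, emptyMap == nil) // true false
    	emptyMap["k"] = 1 // Fine.
    	// nilMap["k"] = 1 would panic: assignment to entry in nil map.
    }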
+
+// objectEncoder is for encoding structs.
+type objectEncoder struct {
+ // es is encodeState.
+ es *encodeState
+
+ // encoded is the encoded struct.
+ encoded *wire.Struct
+}
+
+// save is called by the public methods on Sink.
+func (oe *objectEncoder) save(slot int, obj reflect.Value) {
+ fieldValue := oe.encoded.Field(slot)
+ oe.es.encodeObject(obj, encodeDefault, fieldValue)
}
// encodeStruct encodes a composite object.
-func (es *encodeState) encodeStruct(obj reflect.Value) *pb.Struct {
- // Invoke the save.
- m := Map{newInternalMap(es, nil, nil)}
- defer internalMapPool.Put(m.internalMap)
+func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
+	// Ensure that the obj is addressable. There are two cases when it is
+	// not: first, when this is dispatched via SaveValue; second, when
+	// this is a struct used as a map key. Either way, we need to make a
+	// copy to obtain an addressable value.
if !obj.CanAddr() {
- // Force it to a * type of the above; this involves a copy.
localObj := reflect.New(obj.Type())
localObj.Elem().Set(obj)
obj = localObj.Elem()
}
- fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
- if ok {
- // Invoke the provided saver.
- fns.invokeSave(obj.Addr(), m)
- } else if obj.NumField() == 0 {
- // Allow unregistered anonymous, empty structs.
- return &pb.Struct{}
- } else {
- // Propagate an error.
- panic(fmt.Errorf("unregistered type %T", obj.Interface()))
- }
-
- // Sort the underlying slice, and check for duplicates. This is done
- // once instead of on each add, because performing this sort once is
- // far more efficient.
- if len(m.data) > 1 {
- sort.Slice(m.data, func(i, j int) bool {
- return m.data[i].name < m.data[j].name
- })
- for i := range m.data {
- if i > 0 && m.data[i-1].name == m.data[i].name {
- panic(fmt.Errorf("duplicate name %s", m.data[i].name))
- }
+
+ // Prepare the value.
+ s := &wire.Struct{}
+ *dest = s
+
+ // Look the type up in the database.
+ te, ok := es.types.Lookup(obj.Type())
+ if te == nil {
+ if obj.NumField() == 0 {
+ // Allow unregistered anonymous, empty structs. This
+ // will just return success without ever invoking the
+ // passed function. This uses the immutable EmptyStruct
+ // variable to prevent an allocation in this case.
+ //
+ // Note that this mechanism does *not* work for
+ // interfaces in general. So you can't dispatch
+ // non-registered empty structs via interfaces because
+ // then they can't be restored.
+ s.Alloc(0)
+ return
}
+ // We need a SaverLoader for struct types.
+ Failf("struct %T does not implement SaverLoader", obj.Interface())
}
-
- // Encode the resulting fields.
- fields := make([]*pb.Field, 0, len(m.data))
- for _, e := range m.data {
- fields = append(fields, &pb.Field{
- Name: e.name,
- Value: e.object,
- })
+ if !ok {
+ // Queue the type to be serialized.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
}
- // Return the encoded object.
- return &pb.Struct{Fields: fields}
+ // Invoke the provided saver.
+ s.TypeID = wire.TypeID(te.ID)
+ s.Alloc(len(te.Fields))
+ oe := objectEncoder{
+ es: es,
+ encoded: s,
+ }
+ es.stats.start(te.ID)
+ defer es.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateSave(Sink{internal: oe})
+ }
}
// encodeArray encodes an array.
-func (es *encodeState) encodeArray(obj reflect.Value) *pb.Array {
- var (
- contents []*pb.Object
- )
- for i := 0; i < obj.Len(); i++ {
- entry := es.encodeObject(obj.Index(i), false, "[%d]", i)
- contents = append(contents, entry)
- }
- return &pb.Array{Contents: contents}
+func (es *encodeState) encodeArray(obj reflect.Value, dest *wire.Object) {
+ l := obj.Len()
+ a := &wire.Array{
+ Contents: make([]wire.Object, l),
+ }
+ *dest = a
+ for i := 0; i < l; i++ {
+ // We need to encode the full value because arrays are encoded
+ // using the type information from only the first element.
+ es.encodeObject(obj.Index(i), encodeAsValue, &a.Contents[i])
+ }
+}
+
+// findType recursively finds type information.
+func (es *encodeState) findType(typ reflect.Type) wire.TypeSpec {
+ // First: check if this is a proper type. It's possible for pointers,
+ // slices, arrays, maps, etc to all have some different type.
+ te, ok := es.types.Lookup(typ)
+ if te != nil {
+ if !ok {
+ // See encodeStruct.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
+ }
+ return wire.TypeID(te.ID)
+ }
+
+ switch typ.Kind() {
+ case reflect.Ptr:
+ return &wire.TypeSpecPointer{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Slice:
+ return &wire.TypeSpecSlice{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Array:
+ return &wire.TypeSpecArray{
+ Count: wire.Uint(typ.Len()),
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Map:
+ return &wire.TypeSpecMap{
+ Key: es.findType(typ.Key()),
+ Value: es.findType(typ.Elem()),
+ }
+ default:
+ // After potentially chasing many pointers, the
+ // ultimate type of the object is not known.
+ Failf("type %q is not known", typ)
+ }
+ panic("unreachable")
}
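
The recursion in findType mirrors the structure of the Go type itself. A toy,
runnable analogue (describe is a hypothetical stand-in for building the
wire.TypeSpec variants):

    package main

    import (
    	"fmt"
    	"reflect"
    )

    func describe(typ reflect.Type) string {
    	switch typ.Kind() {
    	case reflect.Ptr:
    		return "pointer(" + describe(typ.Elem()) + ")"
    	case reflect.Slice:
    		return "slice(" + describe(typ.Elem()) + ")"
    	case reflect.Array:
    		return fmt.Sprintf("array(%d, %s)", typ.Len(), describe(typ.Elem()))
    	case reflect.Map:
    		return "map(" + describe(typ.Key()) + ", " + describe(typ.Elem()) + ")"
    	default:
    		return typ.String() // A registered leaf type in the real encoder.
    	}
    }

    func main() {
    	fmt.Println(describe(reflect.TypeOf(&[]map[string][4]int{})))
    	// Output: pointer(slice(map(string, array(4, int))))
    }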
// encodeInterface encodes an interface.
-//
-// Precondition: the value is not nil.
-func (es *encodeState) encodeInterface(obj reflect.Value) *pb.Interface {
- // Check for the nil interface.
- obj = reflect.ValueOf(obj.Interface())
+func (es *encodeState) encodeInterface(obj reflect.Value, dest *wire.Object) {
+ // Dereference the object.
+ obj = obj.Elem()
if !obj.IsValid() {
- return &pb.Interface{
- Type: "", // left alone in decode.
- Value: &pb.Object{Value: &pb.Object_RefValue{0}},
+ // Special case: the nil object.
+ *dest = &wire.Interface{
+ Type: wire.TypeSpecNil{},
+ Value: wire.Nil{},
}
+ return
}
- // We have an interface value here. How do we save that? We
- // resolve the underlying type and save it as a dispatchable.
- typName, ok := registeredTypes.lookupName(obj.Type())
- if !ok {
- panic(fmt.Errorf("type %s is not registered", obj.Type()))
+
+ // Encode underlying object.
+ i := &wire.Interface{
+ Type: es.findType(obj.Type()),
}
+ *dest = i
+ es.encodeObject(obj, encodeAsValue, &i.Value)
+}
- // Encode the object again.
- return &pb.Interface{
- Type: typName,
- Value: es.encodeObject(obj, false, ".(%s)", typName),
+// isPrimitiveZero returns true if this (zero-valued) object is a primitive,
+// or a composite object composed entirely of primitives.
+func isPrimitiveZero(typ reflect.Type) bool {
+ switch typ.Kind() {
+ case reflect.Ptr:
+ // Pointers are always treated as primitive types because we
+ // won't encode directly from here. Returning true here won't
+ // prevent the object from being encoded correctly.
+ return true
+ case reflect.Bool:
+ return true
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return true
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return true
+ case reflect.Float32, reflect.Float64:
+ return true
+ case reflect.Complex64, reflect.Complex128:
+ return true
+ case reflect.String:
+ return true
+ case reflect.Slice:
+		// The slice itself is a primitive, but not necessarily the array
+		// that it points to. This is similar to a pointer.
+ return true
+ case reflect.Array:
+ // We cannot treat an array as a primitive, because it may be
+ // composed of structures or other things with side-effects.
+ return isPrimitiveZero(typ.Elem())
+ case reflect.Interface:
+		// Since we know that this type is the zero type, the interface
+		// value must be zero. Therefore this is a primitive.
+ return true
+ case reflect.Struct:
+ return false
+ case reflect.Map:
+ // The isPrimitiveZero function is called only on zero-types to
+ // see if it's safe to serialize. Since a zero map has no
+ // elements, it is safe to treat as a primitive.
+ return true
+ default:
+ Failf("unknown type %q", typ.Name())
}
+ panic("unreachable")
}
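
This predicate pairs with the IsZero short-circuit in encodeObject below: zero
values of these kinds can be elided entirely, since a freshly allocated
destination is already zero. A quick check of the reflect behavior relied
upon:

    package main

    import (
    	"fmt"
    	"reflect"
    )

    func main() {
    	var s []int
    	var m map[string]int
    	n := 3
    	fmt.Println(reflect.ValueOf(s).IsZero()) // true: elided as wire.Nil{}
    	fmt.Println(reflect.ValueOf(m).IsZero()) // true
    	fmt.Println(reflect.ValueOf(n).IsZero()) // false: encoded normally
    }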
-// encodeObject encodes an object.
-//
-// If mapAsValue is true, then a map will be encoded directly.
-func (es *encodeState) encodeObject(obj reflect.Value, mapAsValue bool, format string, param interface{}) (object *pb.Object) {
- es.push(false, format, param)
- es.stats.Add(obj)
- es.stats.Start(obj)
+// encodeStrategy is the strategy used for encodeObject.
+type encodeStrategy int
+const (
+ // encodeDefault means types are encoded normally as references.
+ encodeDefault encodeStrategy = iota
+
+	// encodeAsValue means that types will never be short-circuited and
+ // will always be encoded as a normal value.
+ encodeAsValue
+
+ // encodeMapAsValue means that even maps will be fully encoded.
+ encodeMapAsValue
+)
+
+// encodeObject encodes an object.
+func (es *encodeState) encodeObject(obj reflect.Value, how encodeStrategy, dest *wire.Object) {
+ if how == encodeDefault && isPrimitiveZero(obj.Type()) && obj.IsZero() {
+ *dest = wire.Nil{}
+ return
+ }
switch obj.Kind() {
+ case reflect.Ptr: // Fast path: first.
+ r := new(wire.Ref)
+ *dest = r
+ if obj.IsNil() {
+ // May be in an array or elsewhere such that a value is
+ // required. So we encode as a reference to the zero
+ // object, which does not exist. Note that this has to
+ // be handled correctly in the decode path as well.
+ return
+ }
+ es.resolve(obj, r)
case reflect.Bool:
- object = &pb.Object{Value: &pb.Object_BoolValue{obj.Bool()}}
+ *dest = wire.Bool(obj.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- object = &pb.Object{Value: &pb.Object_Int64Value{obj.Int()}}
+ *dest = wire.Int(obj.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- object = &pb.Object{Value: &pb.Object_Uint64Value{obj.Uint()}}
- case reflect.Float32, reflect.Float64:
- object = &pb.Object{Value: &pb.Object_DoubleValue{obj.Float()}}
+ *dest = wire.Uint(obj.Uint())
+ case reflect.Float32:
+ *dest = wire.Float32(obj.Float())
+ case reflect.Float64:
+ *dest = wire.Float64(obj.Float())
+ case reflect.Complex64:
+ c := wire.Complex64(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.Complex128:
+ c := wire.Complex128(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.String:
+ s := wire.String(obj.String())
+ *dest = &s // Needs alloc.
case reflect.Array:
- switch obj.Type().Elem().Kind() {
- case reflect.Uint8:
- object = &pb.Object{Value: &pb.Object_ByteArrayValue{pbSlice(obj).Interface().([]byte)}}
- case reflect.Uint16:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := pbSlice(obj).Interface().([]uint16)
- t := make([]uint32, len(s))
- for i := range s {
- t[i] = uint32(s[i])
- }
- object = &pb.Object{Value: &pb.Object_Uint16ArrayValue{&pb.Uint16S{Values: t}}}
- case reflect.Uint32:
- object = &pb.Object{Value: &pb.Object_Uint32ArrayValue{&pb.Uint32S{Values: pbSlice(obj).Interface().([]uint32)}}}
- case reflect.Uint64:
- object = &pb.Object{Value: &pb.Object_Uint64ArrayValue{&pb.Uint64S{Values: pbSlice(obj).Interface().([]uint64)}}}
- case reflect.Uintptr:
- object = &pb.Object{Value: &pb.Object_UintptrArrayValue{&pb.Uintptrs{Values: pbSlice(obj).Interface().([]uint64)}}}
- case reflect.Int8:
- object = &pb.Object{Value: &pb.Object_Int8ArrayValue{&pb.Int8S{Values: pbSlice(obj).Interface().([]byte)}}}
- case reflect.Int16:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := pbSlice(obj).Interface().([]int16)
- t := make([]int32, len(s))
- for i := range s {
- t[i] = int32(s[i])
- }
- object = &pb.Object{Value: &pb.Object_Int16ArrayValue{&pb.Int16S{Values: t}}}
- case reflect.Int32:
- object = &pb.Object{Value: &pb.Object_Int32ArrayValue{&pb.Int32S{Values: pbSlice(obj).Interface().([]int32)}}}
- case reflect.Int64:
- object = &pb.Object{Value: &pb.Object_Int64ArrayValue{&pb.Int64S{Values: pbSlice(obj).Interface().([]int64)}}}
- case reflect.Bool:
- object = &pb.Object{Value: &pb.Object_BoolArrayValue{&pb.Bools{Values: pbSlice(obj).Interface().([]bool)}}}
- case reflect.Float32:
- object = &pb.Object{Value: &pb.Object_Float32ArrayValue{&pb.Float32S{Values: pbSlice(obj).Interface().([]float32)}}}
- case reflect.Float64:
- object = &pb.Object{Value: &pb.Object_Float64ArrayValue{&pb.Float64S{Values: pbSlice(obj).Interface().([]float64)}}}
- default:
- object = &pb.Object{Value: &pb.Object_ArrayValue{es.encodeArray(obj)}}
- }
+ es.encodeArray(obj, dest)
case reflect.Slice:
- if obj.IsNil() || obj.Cap() == 0 {
- // Handled specially in decode; store as nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else {
- // Serialize a slice as the array plus length and capacity.
- object = &pb.Object{Value: &pb.Object_SliceValue{&pb.Slice{
- Capacity: uint32(obj.Cap()),
- Length: uint32(obj.Len()),
- RefValue: es.register(arrayFromSlice(obj)),
- }}}
+ s := &wire.Slice{
+ Capacity: wire.Uint(obj.Cap()),
+ Length: wire.Uint(obj.Len()),
}
- case reflect.String:
- object = &pb.Object{Value: &pb.Object_StringValue{[]byte(obj.String())}}
- case reflect.Ptr:
+ *dest = s
+		// Note that we do need to provide a wire.Slice value here even
+		// for a nil slice, since how is not encodeDefault. If it were,
+		// this case would have been caught by the IsZero check above
+		// and we would have just used wire.Nil{}.
if obj.IsNil() {
- // Handled specially in decode; store as a nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else {
- es.push(true /* dereference */, "", nil)
- object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
- es.pop()
+ return
}
+ // Slices need pointer resolution.
+ es.resolve(arrayFromSlice(obj), &s.Ref)
case reflect.Interface:
- // We don't check for IsNil here, as we want to encode type
- // information. The case of the empty interface (no type, no
- // value) is handled by encodeInteface.
- object = &pb.Object{Value: &pb.Object_InterfaceValue{es.encodeInterface(obj)}}
+ es.encodeInterface(obj, dest)
case reflect.Struct:
- object = &pb.Object{Value: &pb.Object_StructValue{es.encodeStruct(obj)}}
+ es.encodeStruct(obj, dest)
case reflect.Map:
- if obj.IsNil() {
- // Handled specially in decode; store as a nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else if mapAsValue {
- // Encode the map directly.
- object = &pb.Object{Value: &pb.Object_MapValue{es.encodeMap(obj)}}
- } else {
- // Encode a reference to the map.
- //
- // Remove the map object count here to avoid double
- // counting, as this object will be counted again when
- // it gets processed later. We do not add a reference
- // count as the reference is artificial.
- es.stats.Remove(obj)
- object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
+ if how == encodeMapAsValue {
+ es.encodeMap(obj, dest)
+ return
}
+ r := new(wire.Ref)
+ *dest = r
+ es.resolve(obj, r)
default:
- panic(fmt.Errorf("unknown primitive %#v", obj.Interface()))
+ Failf("unknown object %#v", obj.Interface())
+ panic("unreachable")
}
-
- es.stats.Done()
- es.pop()
- return
}
-// Serialize serializes the object state.
-//
-// This function may panic and should be run in safely().
-func (es *encodeState) Serialize(obj reflect.Value) {
- es.register(obj.Addr())
-
- // Pop off the list until we're done.
- for es.pending.Len() > 0 {
- e := es.pending.Front()
-
- // Extract the queued object.
- qo := e.Value.(queuedObject)
- es.stats.Start(qo.obj)
+// Save serializes the object graph rooted at obj.
+func (es *encodeState) Save(obj reflect.Value) {
+ es.stats.init()
+ defer es.stats.fini(func(id typeID) string {
+ return es.pendingTypes[id-1].Name
+ })
+
+ // Resolve the first object, which should queue a pile of additional
+ // objects on the pending list. All queued objects should be fully
+ // resolved, and we should be able to serialize after this call.
+ var root wire.Ref
+ es.resolve(obj.Addr(), &root)
+
+ // Encode the graph.
+ var oes *objectEncodeState
+ if err := safely(func() {
+ for oes = es.deferred.Front(); oes != nil; oes = es.deferred.Front() {
+ // Remove and encode the object. Note that as a result
+ // of this encoding, the object may be enqueued on the
+ // deferred list yet again. That's expected, and why it
+ // is removed first.
+ es.deferred.Remove(oes)
+ es.encodeObject(oes.obj, oes.how, &oes.encoded)
+ }
+ }); err != nil {
+ // Include the object in the error message.
+ Failf("encoding error at object %#v: %w", oes.obj.Interface(), err)
+ }
- es.pending.Remove(e)
+ // Check that items are pending.
+ if es.pending.Front() == nil {
+ Failf("pending is empty?")
+ }
- es.from = &qo.path
- o := es.encodeObject(qo.obj, true, "", nil)
+ // Write the header with the number of objects. Note that there is no
+	// way that es.lastID could conflict with objectFlag; that would
+	// indicate an impossibly large encoding.
+ if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil {
+ Failf("error writing header: %w", err)
+ }
- // Emit to our output stream.
- if err := es.writeObject(qo.id, o); err != nil {
- panic(err)
+ // Serialize all pending types and pending objects. Note that we don't
+ // bother removing from this list as we walk it because that just
+ // wastes time. It will not change after this point.
+ var id objectID
+ if err := safely(func() {
+ for _, wt := range es.pendingTypes {
+ // Encode the type.
+ wire.Save(es.w, &wt)
}
+ for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() {
+ id++ // First object is 1.
+ if oes.id != id {
+ Failf("expected id %d, got %d", id, oes.id)
+ }
- // Mark as done.
- es.done.PushBack(e)
- es.stats.Done()
+			// Marshal the object.
+ wire.Save(es.w, oes.encoded)
+ }
+ }); err != nil {
+ // Include the object and the error.
+ Failf("error serializing object %#v: %w", oes.encoded, err)
}
- // Write a zero-length terminal at the end; this is a sanity check
- // applied at decode time as well (see decode.go).
- if err := WriteHeader(es.w, 0, false); err != nil {
- panic(err)
+ // Check what we wrote.
+ if id != es.lastID {
+ Failf("expected %d objects, wrote %d", es.lastID, id)
}
}
+// objectFlag indicates that the length is a count of objects, rather than a raw
+// byte length. When this is set on a length header in the stream, it may be
+// decoded appropriately.
+const objectFlag uint64 = 1 << 63
+
// WriteHeader writes a header.
//
// Each object written to the statefile should be prefixed with a header. In
// order to generate statefiles that play nicely with debugging tools, raw
// writes should be prefixed with a header with object set to false and the
// appropriate length. This will allow tools to skip these regions.
-func WriteHeader(w io.Writer, length uint64, object bool) error {
- // The lowest-order bit encodes whether this is a valid object. This is
- // a purely internal convention, but allows the object flag to be
- // returned from ReadHeader.
- length = length << 1
+func WriteHeader(w wire.Writer, length uint64, object bool) error {
+ // Sanity check the length.
+ if length&objectFlag != 0 {
+ Failf("impossibly huge length: %d", length)
+ }
if object {
- length |= 0x1
+ length |= objectFlag
}
// Write a header.
- var hdr [32]byte
- encodedLen := binary.PutUvarint(hdr[:], length)
- for done := 0; done < encodedLen; {
- n, err := w.Write(hdr[done:encodedLen])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
-
- return nil
+ return safely(func() {
+ wire.SaveUint(w, length)
+ })
}
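
A minimal sketch of the length/object packing shared by WriteHeader and
ReadHeader, assuming only that the value travels as a uint64 with the top bit
reserved:

    package main

    import "fmt"

    const objectFlag uint64 = 1 << 63

    // packHeader mirrors WriteHeader's length/object packing.
    func packHeader(length uint64, object bool) uint64 {
    	if object {
    		return length | objectFlag
    	}
    	return length
    }

    // unpackHeader mirrors ReadHeader's decoding of the packed value.
    func unpackHeader(v uint64) (length uint64, object bool) {
    	return v &^ objectFlag, v&objectFlag != 0
    }

    func main() {
    	length, object := unpackHeader(packHeader(42, true))
    	fmt.Println(length, object) // 42 true
    }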
-// writeObject writes an object to the stream.
-func (es *encodeState) writeObject(id uint64, obj *pb.Object) error {
- // Marshal the proto.
- buf, err := proto.Marshal(obj)
- if err != nil {
- return err
- }
+// pendingMapper is for the pending list.
+type pendingMapper struct{}
- // Write the object header.
- if err := WriteHeader(es.w, uint64(len(buf)), true); err != nil {
- return err
- }
+func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry }
- // Write the object.
- for done := 0; done < len(buf); {
- n, err := es.w.Write(buf[done:])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
+// deferredMapper is for the deferred list.
+type deferredMapper struct{}
- return nil
-}
+func (deferredMapper) linkerFor(oes *objectEncodeState) *deferredEntry { return &oes.deferredEntry }
// addrSetFunctions is used by addrSet.
type addrSetFunctions struct{}
@@ -458,13 +818,24 @@ func (addrSetFunctions) MaxKey() uintptr {
return ^uintptr(0)
}
-func (addrSetFunctions) ClearValue(val *reflect.Value) {
+func (addrSetFunctions) ClearValue(val **objectEncodeState) {
+ *val = nil
}
-func (addrSetFunctions) Merge(_ addrRange, val1 reflect.Value, _ addrRange, val2 reflect.Value) (reflect.Value, bool) {
- return val1, val1 == val2
+func (addrSetFunctions) Merge(r1 addrRange, val1 *objectEncodeState, r2 addrRange, val2 *objectEncodeState) (*objectEncodeState, bool) {
+ if val1.obj == val2.obj {
+		// This should never happen. It would indicate that the same
+ // object exists in two non-contiguous address ranges. Note
+ // that this assertion can only be triggered if the race
+ // detector is enabled.
+ Failf("unexpected merge in addrSet @ %v and %v: %#v and %#v", r1, r2, val1.obj, val2.obj)
+ }
+ // Reject the merge.
+ return val1, false
}
-func (addrSetFunctions) Split(_ addrRange, val reflect.Value, _ uintptr) (reflect.Value, reflect.Value) {
- return val, val
+func (addrSetFunctions) Split(r addrRange, val *objectEncodeState, _ uintptr) (*objectEncodeState, *objectEncodeState) {
+ // A split should never happen: we don't remove ranges.
+ Failf("unexpected split in addrSet @ %v: %#v", r, val.obj)
+ panic("unreachable")
}
diff --git a/pkg/state/encode_unsafe.go b/pkg/state/encode_unsafe.go
index 457e6dbb7..e0dad83b4 100644
--- a/pkg/state/encode_unsafe.go
+++ b/pkg/state/encode_unsafe.go
@@ -31,51 +31,3 @@ func arrayFromSlice(obj reflect.Value) reflect.Value {
reflect.ArrayOf(obj.Cap(), obj.Type().Elem()),
unsafe.Pointer(obj.Pointer()))
}
-
-// pbSlice returns a protobuf-supported slice of the array and erase the
-// original element type (which could be a defined type or non-supported type).
-func pbSlice(obj reflect.Value) reflect.Value {
- var typ reflect.Type
- switch obj.Type().Elem().Kind() {
- case reflect.Uint8:
- typ = reflect.TypeOf(byte(0))
- case reflect.Uint16:
- typ = reflect.TypeOf(uint16(0))
- case reflect.Uint32:
- typ = reflect.TypeOf(uint32(0))
- case reflect.Uint64:
- typ = reflect.TypeOf(uint64(0))
- case reflect.Uintptr:
- typ = reflect.TypeOf(uint64(0))
- case reflect.Int8:
- typ = reflect.TypeOf(byte(0))
- case reflect.Int16:
- typ = reflect.TypeOf(int16(0))
- case reflect.Int32:
- typ = reflect.TypeOf(int32(0))
- case reflect.Int64:
- typ = reflect.TypeOf(int64(0))
- case reflect.Bool:
- typ = reflect.TypeOf(bool(false))
- case reflect.Float32:
- typ = reflect.TypeOf(float32(0))
- case reflect.Float64:
- typ = reflect.TypeOf(float64(0))
- default:
- panic("slice element is not of basic value type")
- }
- return reflect.NewAt(
- reflect.ArrayOf(obj.Len(), typ),
- unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()),
- ).Elem().Slice(0, obj.Len())
-}
-
-func castSlice(obj reflect.Value, elemTyp reflect.Type) reflect.Value {
- if obj.Type().Elem().Size() != elemTyp.Size() {
- panic("cannot cast slice into other element type of different size")
- }
- return reflect.NewAt(
- reflect.ArrayOf(obj.Len(), elemTyp),
- unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()),
- ).Elem()
-}
diff --git a/pkg/state/map.go b/pkg/state/map.go
deleted file mode 100644
index 4f3ebb0da..000000000
--- a/pkg/state/map.go
+++ /dev/null
@@ -1,232 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package state
-
-import (
- "context"
- "fmt"
- "reflect"
- "sort"
- "sync"
-
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
-)
-
-// entry is a single map entry.
-type entry struct {
- name string
- object *pb.Object
-}
-
-// internalMap is the internal Map state.
-//
-// These are recycled via a pool to avoid churn.
-type internalMap struct {
- // es is encodeState.
- es *encodeState
-
- // ds is decodeState.
- ds *decodeState
-
- // os is current object being decoded.
- //
- // This will always be nil during encode.
- os *objectState
-
- // data stores the encoded values.
- data []entry
-}
-
-var internalMapPool = sync.Pool{
- New: func() interface{} {
- return new(internalMap)
- },
-}
-
-// newInternalMap returns a cached map.
-func newInternalMap(es *encodeState, ds *decodeState, os *objectState) *internalMap {
- m := internalMapPool.Get().(*internalMap)
- m.es = es
- m.ds = ds
- m.os = os
- if m.data != nil {
- m.data = m.data[:0]
- }
- return m
-}
-
-// Map is a generic state container.
-//
-// This is the object passed to Save and Load in order to store their state.
-//
-// Detailed documentation is available in individual methods.
-type Map struct {
- *internalMap
-}
-
-// Save adds the given object to the map.
-//
-// You should pass always pointers to the object you are saving. For example:
-//
-// type X struct {
-// A int
-// B *int
-// }
-//
-// func (x *X) Save(m Map) {
-// m.Save("A", &x.A)
-// m.Save("B", &x.B)
-// }
-//
-// func (x *X) Load(m Map) {
-// m.Load("A", &x.A)
-// m.Load("B", &x.B)
-// }
-func (m Map) Save(name string, objPtr interface{}) {
- m.save(name, reflect.ValueOf(objPtr).Elem(), ".%s")
-}
-
-// SaveValue adds the given object value to the map.
-//
-// This should be used for values where pointers are not available, or casts
-// are required during Save/Load.
-//
-// For example, if we want to cast external package type P.Foo to int64:
-//
-// type X struct {
-// A P.Foo
-// }
-//
-// func (x *X) Save(m Map) {
-// m.SaveValue("A", int64(x.A))
-// }
-//
-// func (x *X) Load(m Map) {
-// m.LoadValue("A", new(int64), func(x interface{}) {
-// x.A = P.Foo(x.(int64))
-// })
-// }
-func (m Map) SaveValue(name string, obj interface{}) {
- m.save(name, reflect.ValueOf(obj), ".(value %s)")
-}
-
-// save is helper for the above. It takes the name of value to save the field
-// to, the field object (obj), and a format string that specifies how the
-// field's saving logic is dispatched from the struct (normal, value, etc.). The
-// format string should expect one string parameter, which is the name of the
-// field.
-func (m Map) save(name string, obj reflect.Value, format string) {
- if m.es == nil {
- // Not currently encoding.
- m.Failf("no encode state for %q", name)
- }
-
- // Attempt the encode.
- //
- // These are sorted at the end, after all objects are added and will be
- // sorted and checked for duplicates (see encodeStruct).
- m.data = append(m.data, entry{
- name: name,
- object: m.es.encodeObject(obj, false, format, name),
- })
-}
-
-// Load loads the given object from the map.
-//
-// See Save for an example.
-func (m Map) Load(name string, objPtr interface{}) {
- m.load(name, reflect.ValueOf(objPtr), false, nil, ".%s")
-}
-
-// LoadWait loads the given objects from the map, and marks it as requiring all
-// AfterLoad executions to complete prior to running this object's AfterLoad.
-//
-// See Save for an example.
-func (m Map) LoadWait(name string, objPtr interface{}) {
- m.load(name, reflect.ValueOf(objPtr), true, nil, ".(wait %s)")
-}
-
-// LoadValue loads the given object value from the map.
-//
-// See SaveValue for an example.
-func (m Map) LoadValue(name string, objPtr interface{}, fn func(interface{})) {
- o := reflect.ValueOf(objPtr)
- m.load(name, o, true, func() { fn(o.Elem().Interface()) }, ".(value %s)")
-}
-
-// load is helper for the above. It takes the name of value to load the field
-// from, the target field pointer (objPtr), whether load completion of the
-// struct depends on the field's load completion (wait), the load completion
-// logic (fn), and a format string that specifies how the field's loading logic
-// is dispatched from the struct (normal, wait, value, etc.). The format string
-// should expect one string parameter, which is the name of the field.
-func (m Map) load(name string, objPtr reflect.Value, wait bool, fn func(), format string) {
- if m.ds == nil {
- // Not currently decoding.
- m.Failf("no decode state for %q", name)
- }
-
- // Find the object.
- //
- // These are sorted up front (and should appear in the state file
- // sorted as well), so we can do a binary search here to ensure that
- // large structs don't behave badly.
- i := sort.Search(len(m.data), func(i int) bool {
- return m.data[i].name >= name
- })
- if i >= len(m.data) || m.data[i].name != name {
- // There is no data for this name?
- m.Failf("no data found for %q", name)
- }
-
- // Perform the decode.
- m.ds.decodeObject(m.os, objPtr.Elem(), m.data[i].object, format, name)
- if wait {
- // Mark this individual object a blocker.
- m.ds.waitObject(m.os, m.data[i].object, fn)
- }
-}
-
-// Failf fails the save or restore with the provided message. Processing will
-// stop after calling Failf, as the state package uses a panic & recover
-// mechanism for state errors. You should defer any cleanup required.
-func (m Map) Failf(format string, args ...interface{}) {
- panic(fmt.Errorf(format, args...))
-}
-
-// AfterLoad schedules a function execution when all objects have been allocated
-// and their automated loading and customized load logic have been executed. fn
-// will not be executed until all of current object's dependencies' AfterLoad()
-// logic, if exist, have been executed.
-func (m Map) AfterLoad(fn func()) {
- if m.ds == nil {
- // Not currently decoding.
- m.Failf("not decoding")
- }
-
- // Queue the local callback; this will execute when all of the above
- // data dependencies have been cleared.
- m.os.callbacks = append(m.os.callbacks, fn)
-}
-
-// Context returns the current context object.
-func (m Map) Context() context.Context {
- if m.es != nil {
- return m.es.ctx
- } else if m.ds != nil {
- return m.ds.ctx
- }
- return context.Background() // No context.
-}
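For reference, the Map methods removed above were driven by per-type save/load hooks keyed by field name. A minimal sketch of that older usage, reconstructed from the deleted code in this change (smartStruct mirrors a type deleted from state_test.go below):

    // Old name-keyed API (removed by this change).
    type smartStruct struct {
    	A int
    	B int
    }

    func (s *smartStruct) save(m Map) {
    	m.Save("A", &s.A) // Entries are sorted and checked for duplicates at encode time.
    	m.Save("B", &s.B)
    }

    func (s *smartStruct) load(m Map) {
    	m.Load("A", &s.A) // Lookup is a binary search over the sorted entries.
    	m.Load("B", &s.B)
    }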
diff --git a/pkg/state/object.proto b/pkg/state/object.proto
deleted file mode 100644
index 5ebcfb151..000000000
--- a/pkg/state/object.proto
+++ /dev/null
@@ -1,140 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package gvisor.state.statefile;
-
-// Slice is a slice value.
-message Slice {
- uint32 length = 1;
- uint32 capacity = 2;
- uint64 ref_value = 3;
-}
-
-// Array is an array value.
-message Array {
- repeated Object contents = 1;
-}
-
-// Map is a map value.
-message Map {
- repeated Object keys = 1;
- repeated Object values = 2;
-}
-
-// Interface is an interface value.
-message Interface {
- string type = 1;
- Object value = 2;
-}
-
-// Struct is a basic composite value.
-message Struct {
- repeated Field fields = 1;
-}
-
-// Field encodes a single field.
-message Field {
- string name = 1;
- Object value = 2;
-}
-
-// Uint16s encodes an uint16 array. To be used inside oneof structure.
-message Uint16s {
- // There is no 16-bit type in protobuf so we use variable length 32-bit here.
- repeated uint32 values = 1;
-}
-
-// Uint32s encodes an uint32 array. To be used inside oneof structure.
-message Uint32s {
- repeated fixed32 values = 1;
-}
-
-// Uint64s encodes an uint64 array. To be used inside oneof structure.
-message Uint64s {
- repeated fixed64 values = 1;
-}
-
-// Uintptrs encodes an uintptr array. To be used inside oneof structure.
-message Uintptrs {
- repeated fixed64 values = 1;
-}
-
-// Int8s encodes an int8 array. To be used inside oneof structure.
-message Int8s {
- bytes values = 1;
-}
-
-// Int16s encodes an int16 array. To be used inside oneof structure.
-message Int16s {
- // There is no 16-bit type in protobuf so we use variable length 32-bit here.
- repeated int32 values = 1;
-}
-
-// Int32s encodes an int32 array. To be used inside oneof structure.
-message Int32s {
- repeated sfixed32 values = 1;
-}
-
-// Int64s encodes an int64 array. To be used inside oneof structure.
-message Int64s {
- repeated sfixed64 values = 1;
-}
-
-// Bools encodes a boolean array. To be used inside oneof structure.
-message Bools {
- repeated bool values = 1;
-}
-
-// Float64s encodes a float64 array. To be used inside oneof structure.
-message Float64s {
- repeated double values = 1;
-}
-
-// Float32s encodes a float32 array. To be used inside oneof structure.
-message Float32s {
- repeated float values = 1;
-}
-
-// Object are primitive encodings.
-//
-// Note that ref_value references an Object.id, below.
-message Object {
- oneof value {
- bool bool_value = 1;
- bytes string_value = 2;
- int64 int64_value = 3;
- uint64 uint64_value = 4;
- double double_value = 5;
- uint64 ref_value = 6;
- Slice slice_value = 7;
- Array array_value = 8;
- Interface interface_value = 9;
- Struct struct_value = 10;
- Map map_value = 11;
- bytes byte_array_value = 12;
- Uint16s uint16_array_value = 13;
- Uint32s uint32_array_value = 14;
- Uint64s uint64_array_value = 15;
- Uintptrs uintptr_array_value = 16;
- Int8s int8_array_value = 17;
- Int16s int16_array_value = 18;
- Int32s int32_array_value = 19;
- Int64s int64_array_value = 20;
- Bools bool_array_value = 21;
- Float64s float64_array_value = 22;
- Float32s float32_array_value = 23;
- }
-}
diff --git a/pkg/state/pretty/BUILD b/pkg/state/pretty/BUILD
new file mode 100644
index 000000000..d053802f7
--- /dev/null
+++ b/pkg/state/pretty/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pretty",
+ srcs = ["pretty.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/state",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go
new file mode 100644
index 000000000..cf37aaa49
--- /dev/null
+++ b/pkg/state/pretty/pretty.go
@@ -0,0 +1,273 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pretty is a pretty-printer for state streams.
+package pretty
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "reflect"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+func formatRef(x *wire.Ref, graph uint64, html bool) string {
+ baseRef := fmt.Sprintf("g%dr%d", graph, x.Root)
+ fullRef := baseRef
+ if len(x.Dots) > 0 {
+ // See wire.Ref; Type is only valid if Dots is non-empty.
+ typ, _ := formatType(x.Type, graph, html)
+ var buf strings.Builder
+ buf.WriteString("(*")
+ buf.WriteString(typ)
+ buf.WriteString(")(")
+ buf.WriteString(baseRef)
+ for _, component := range x.Dots {
+ switch v := component.(type) {
+ case *wire.FieldName:
+ buf.WriteString(".")
+ buf.WriteString(string(*v))
+ case wire.Index:
+ buf.WriteString(fmt.Sprintf("[%d]", v))
+ default:
+ panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component)))
+ }
+ }
+ buf.WriteString(")")
+ fullRef = buf.String()
+ }
+ if html {
+ return fmt.Sprintf("<a href=\"#%s\">%s</a>", baseRef, fullRef)
+ }
+ return fullRef
+}
+
+func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) {
+ switch x := t.(type) {
+ case wire.TypeID:
+ base := fmt.Sprintf("g%dt%d", graph, x)
+ if html {
+ return fmt.Sprintf("<a href=\"#%s\">%s</a>", base, base), true
+ }
+ return base, true
+ case wire.TypeSpecNil:
+ return "", false // Only nil type.
+ case *wire.TypeSpecPointer:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("(*%s)", element), true
+ case *wire.TypeSpecArray:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("[%d](%s)", x.Count, element), true
+ case *wire.TypeSpecSlice:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("([]%s)", element), true
+ case *wire.TypeSpecMap:
+ key, _ := formatType(x.Key, graph, html)
+ value, _ := formatType(x.Value, graph, html)
+ return fmt.Sprintf("(map[%s]%s)", key, value), true
+ default:
+ panic(fmt.Sprintf("unreachable: unknown type %T", t))
+ }
+}
+
+// format formats a single object, for pretty-printing. It also returns whether
+// the value is a non-zero value.
+func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bool) {
+ switch x := encoded.(type) {
+ case wire.Nil:
+ return "nil", false
+ case *wire.String:
+ return fmt.Sprintf("%q", *x), *x != ""
+ case *wire.Complex64:
+ return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0
+ case *wire.Complex128:
+ return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0
+ case *wire.Ref:
+ return formatRef(x, graph, html), x.Root != 0
+ case *wire.Type:
+ tabs := "\n" + strings.Repeat("\t", depth)
+ items := make([]string, 0, len(x.Fields)+2)
+ items = append(items, fmt.Sprintf("type %s {", x.Name))
+ for i := 0; i < len(x.Fields); i++ {
+ items = append(items, fmt.Sprintf("\t%d: %s,", i, x.Fields[i]))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true // No zero value.
+ case *wire.Slice:
+ return fmt.Sprintf("%s{len:%d,cap:%d}", formatRef(&x.Ref, graph, html), x.Length, x.Capacity), x.Capacity != 0
+ case *wire.Array:
+ if len(x.Contents) == 0 {
+ return "[]", false
+ }
+ items := make([]string, 0, len(x.Contents)+2)
+ zeros := make([]string, 0) // used to eliminate zero entries.
+ items = append(items, "[")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.Contents); i++ {
+ item, ok := format(graph, depth+1, x.Contents[i], html)
+ if !ok {
+ zeros = append(zeros, fmt.Sprintf("\t%s,", item))
+ continue
+ }
+ if len(zeros) > 0 {
+ items = append(items, zeros...)
+ zeros = nil
+ }
+ items = append(items, fmt.Sprintf("\t%s,", item))
+ }
+ if len(zeros) > 0 {
+ items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros)))
+ }
+ items = append(items, "]")
+ return strings.Join(items, tabs), len(zeros) < len(x.Contents)
+ case *wire.Struct:
+ typ, _ := formatType(x.TypeID, graph, html)
+ if x.Fields() == 0 {
+ return fmt.Sprintf("struct[%s]{}", typ), false
+ }
+ items := make([]string, 0, 2)
+ items = append(items, fmt.Sprintf("struct[%s]{", typ))
+ tabs := "\n" + strings.Repeat("\t", depth)
+ allZero := true
+ for i := 0; i < x.Fields(); i++ {
+ element, ok := format(graph, depth+1, *x.Field(i), html)
+ allZero = allZero && !ok
+ items = append(items, fmt.Sprintf("\t%d: %s,", i, element))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), !allZero
+ case *wire.Map:
+ if len(x.Keys) == 0 {
+ return "map{}", false
+ }
+ items := make([]string, 0, len(x.Keys)+2)
+ items = append(items, "map{")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.Keys); i++ {
+ key, _ := format(graph, depth+1, x.Keys[i], html)
+ value, _ := format(graph, depth+1, x.Values[i], html)
+ items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true
+ case *wire.Interface:
+ typ, typOk := formatType(x.Type, graph, html)
+ element, elementOk := format(graph, depth+1, x.Value, html)
+ return fmt.Sprintf("interface[%s]{%s}", typ, element), typOk || elementOk
+ default:
+ // Must be a primitive; use reflection.
+ return fmt.Sprintf("%v", encoded), true
+ }
+}
+
+// printStream is the basic print implementation.
+func printStream(w io.Writer, r wire.Reader, html bool) (err error) {
+ // current graph ID.
+ var graph uint64
+
+ if html {
+ fmt.Fprintf(w, "<pre>")
+ defer fmt.Fprintf(w, "</pre>")
+ }
+
+ defer func() {
+ if r := recover(); r != nil {
+ if rErr, ok := r.(error); ok {
+ err = rErr // Override return.
+ return
+ }
+ panic(r) // Propagate.
+ }
+ }()
+
+ for {
+ // Find the first object to begin generation.
+ length, object, err := state.ReadHeader(r)
+ if err == io.EOF {
+ // Nothing else to do.
+ break
+ } else if err != nil {
+ return err
+ }
+ if !object {
+ graph++ // Increment the graph.
+ if length > 0 {
+ fmt.Fprintf(w, "(%d bytes non-object data)\n", length)
+ io.Copy(ioutil.Discard, &io.LimitedReader{
+ R: r,
+ N: int64(length),
+ })
+ }
+ continue
+ }
+
+ // Read & unmarshal the object.
+ //
+ // Note that this loop must match the general structure of the
+ // loop in decode.go. But we don't register type information,
+ // etc. and just print the raw structures.
+ var (
+ oid uint64 = 1
+ tid uint64 = 1
+ )
+ for oid <= length {
+ // Unmarshal the object.
+ encoded := wire.Load(r)
+
+ // Is this a type?
+ if _, ok := encoded.(*wire.Type); ok {
+ str, _ := format(graph, 0, encoded, html)
+ tag := fmt.Sprintf("g%dt%d", graph, tid)
+ if html {
+ // See below.
+ tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
+ }
+ if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+ return err
+ }
+ tid++
+ continue
+ }
+
+ // Format the node.
+ str, _ := format(graph, 0, encoded, html)
+ tag := fmt.Sprintf("g%dr%d", graph, oid)
+ if html {
+ // Create a little tag with an anchor next to it for linking.
+ tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
+ }
+ if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+ return err
+ }
+ oid++
+ }
+ }
+
+ return nil
+}
+
+// PrintText reads the stream from r and prints text to w.
+func PrintText(w io.Writer, r wire.Reader) error {
+ return printStream(w, r, false /* html */)
+}
+
+// PrintHTML reads the stream from r and prints html to w.
+func PrintHTML(w io.Writer, r wire.Reader) error {
+ return printStream(w, r, true /* html */)
+}
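A short usage sketch for the new pretty-printer. It is assumed here that *bytes.Buffer satisfies wire.Writer and wire.Reader (both are small byte-oriented interfaces); the int64 root is an arbitrary example value:

    package main

    import (
    	"bytes"
    	"context"
    	"log"
    	"os"

    	"gvisor.dev/gvisor/pkg/state"
    	"gvisor.dev/gvisor/pkg/state/pretty"
    )

    func main() {
    	x := int64(42)
    	var buf bytes.Buffer // Assumption: satisfies wire.Writer and wire.Reader.
    	if _, err := state.Save(context.Background(), &buf, &x); err != nil {
    		log.Fatal(err)
    	}
    	// Render the stream as text, one "gNrM = ..." line per object; use
    	// PrintHTML for a linked version with anchors.
    	if err := pretty.PrintText(os.Stdout, &buf); err != nil {
    		log.Fatal(err)
    	}
    }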
diff --git a/pkg/state/printer.go b/pkg/state/printer.go
deleted file mode 100644
index 3ce18242f..000000000
--- a/pkg/state/printer.go
+++ /dev/null
@@ -1,251 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package state
-
-import (
- "fmt"
- "io"
- "io/ioutil"
- "reflect"
- "strings"
-
- "github.com/golang/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
-)
-
-// format formats a single object, for pretty-printing. It also returns whether
-// the value is a non-zero value.
-func format(graph uint64, depth int, object *pb.Object, html bool) (string, bool) {
- switch x := object.GetValue().(type) {
- case *pb.Object_BoolValue:
- return fmt.Sprintf("%t", x.BoolValue), x.BoolValue != false
- case *pb.Object_StringValue:
- return fmt.Sprintf("\"%s\"", string(x.StringValue)), len(x.StringValue) != 0
- case *pb.Object_Int64Value:
- return fmt.Sprintf("%d", x.Int64Value), x.Int64Value != 0
- case *pb.Object_Uint64Value:
- return fmt.Sprintf("%du", x.Uint64Value), x.Uint64Value != 0
- case *pb.Object_DoubleValue:
- return fmt.Sprintf("%f", x.DoubleValue), x.DoubleValue != 0.0
- case *pb.Object_RefValue:
- if x.RefValue == 0 {
- return "nil", false
- }
- ref := fmt.Sprintf("g%dr%d", graph, x.RefValue)
- if html {
- ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref)
- }
- return ref, true
- case *pb.Object_SliceValue:
- if x.SliceValue.RefValue == 0 {
- return "nil", false
- }
- ref := fmt.Sprintf("g%dr%d", graph, x.SliceValue.RefValue)
- if html {
- ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref)
- }
- return fmt.Sprintf("%s[:%d:%d]", ref, x.SliceValue.Length, x.SliceValue.Capacity), true
- case *pb.Object_ArrayValue:
- if len(x.ArrayValue.Contents) == 0 {
- return "[]", false
- }
- items := make([]string, 0, len(x.ArrayValue.Contents)+2)
- zeros := make([]string, 0) // used to eliminate zero entries.
- items = append(items, "[")
- tabs := "\n" + strings.Repeat("\t", depth)
- for i := 0; i < len(x.ArrayValue.Contents); i++ {
- item, ok := format(graph, depth+1, x.ArrayValue.Contents[i], html)
- if ok {
- if len(zeros) > 0 {
- items = append(items, zeros...)
- zeros = nil
- }
- items = append(items, fmt.Sprintf("\t%s,", item))
- } else {
- zeros = append(zeros, fmt.Sprintf("\t%s,", item))
- }
- }
- if len(zeros) > 0 {
- items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros)))
- }
- items = append(items, "]")
- return strings.Join(items, tabs), len(zeros) < len(x.ArrayValue.Contents)
- case *pb.Object_StructValue:
- if len(x.StructValue.Fields) == 0 {
- return "struct{}", false
- }
- items := make([]string, 0, len(x.StructValue.Fields)+2)
- items = append(items, "struct{")
- tabs := "\n" + strings.Repeat("\t", depth)
- allZero := true
- for _, field := range x.StructValue.Fields {
- element, ok := format(graph, depth+1, field.Value, html)
- allZero = allZero && !ok
- items = append(items, fmt.Sprintf("\t%s: %s,", field.Name, element))
- }
- items = append(items, "}")
- return strings.Join(items, tabs), !allZero
- case *pb.Object_MapValue:
- if len(x.MapValue.Keys) == 0 {
- return "map{}", false
- }
- items := make([]string, 0, len(x.MapValue.Keys)+2)
- items = append(items, "map{")
- tabs := "\n" + strings.Repeat("\t", depth)
- for i := 0; i < len(x.MapValue.Keys); i++ {
- key, _ := format(graph, depth+1, x.MapValue.Keys[i], html)
- value, _ := format(graph, depth+1, x.MapValue.Values[i], html)
- items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
- }
- items = append(items, "}")
- return strings.Join(items, tabs), true
- case *pb.Object_InterfaceValue:
- if x.InterfaceValue.Type == "" {
- return "interface(nil){}", false
- }
- element, _ := format(graph, depth+1, x.InterfaceValue.Value, html)
- return fmt.Sprintf("interface(\"%s\"){%s}", x.InterfaceValue.Type, element), true
- case *pb.Object_ByteArrayValue:
- return printArray(reflect.ValueOf(x.ByteArrayValue))
- case *pb.Object_Uint16ArrayValue:
- return printArray(reflect.ValueOf(x.Uint16ArrayValue.Values))
- case *pb.Object_Uint32ArrayValue:
- return printArray(reflect.ValueOf(x.Uint32ArrayValue.Values))
- case *pb.Object_Uint64ArrayValue:
- return printArray(reflect.ValueOf(x.Uint64ArrayValue.Values))
- case *pb.Object_UintptrArrayValue:
- return printArray(castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
- case *pb.Object_Int8ArrayValue:
- return printArray(castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
- case *pb.Object_Int16ArrayValue:
- return printArray(reflect.ValueOf(x.Int16ArrayValue.Values))
- case *pb.Object_Int32ArrayValue:
- return printArray(reflect.ValueOf(x.Int32ArrayValue.Values))
- case *pb.Object_Int64ArrayValue:
- return printArray(reflect.ValueOf(x.Int64ArrayValue.Values))
- case *pb.Object_BoolArrayValue:
- return printArray(reflect.ValueOf(x.BoolArrayValue.Values))
- case *pb.Object_Float64ArrayValue:
- return printArray(reflect.ValueOf(x.Float64ArrayValue.Values))
- case *pb.Object_Float32ArrayValue:
- return printArray(reflect.ValueOf(x.Float32ArrayValue.Values))
- }
-
- // Should not happen, but tolerate.
- return fmt.Sprintf("(unknown proto type: %T)", object.GetValue()), true
-}
-
-// PrettyPrint reads the state stream from r, and pretty prints to w.
-func PrettyPrint(w io.Writer, r io.Reader, html bool) error {
- var (
- // current graph ID.
- graph uint64
-
- // current object ID.
- id uint64
- )
-
- if html {
- fmt.Fprintf(w, "<pre>")
- defer fmt.Fprintf(w, "</pre>")
- }
-
- for {
- // Find the first object to begin generation.
- length, object, err := ReadHeader(r)
- if err == io.EOF {
- // Nothing else to do.
- break
- } else if err != nil {
- return err
- }
- if !object {
- // Increment the graph number & reset the ID.
- graph++
- id = 0
- if length > 0 {
- fmt.Fprintf(w, "(%d bytes non-object data)\n", length)
- io.Copy(ioutil.Discard, &io.LimitedReader{
- R: r,
- N: int64(length),
- })
- }
- continue
- }
-
- // Read & unmarshal the object.
- buf := make([]byte, length)
- for done := 0; done < len(buf); {
- n, err := r.Read(buf[done:])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
- obj := new(pb.Object)
- if err := proto.Unmarshal(buf, obj); err != nil {
- return err
- }
-
- id++ // First object must be one.
- str, _ := format(graph, 0, obj, html)
- tag := fmt.Sprintf("g%dr%d", graph, id)
- if html {
- tag = fmt.Sprintf("<a name=%s>%s</a>", tag, tag)
- }
- if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func printArray(s reflect.Value) (string, bool) {
- zero := reflect.Zero(s.Type().Elem()).Interface()
- z := "0"
- switch s.Type().Elem().Kind() {
- case reflect.Bool:
- z = "false"
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- case reflect.Float32, reflect.Float64:
- default:
- return fmt.Sprintf("unexpected non-primitive type array: %#v", s.Interface()), true
- }
-
- zeros := 0
- items := make([]string, 0, s.Len())
- for i := 0; i <= s.Len(); i++ {
- if i < s.Len() && reflect.DeepEqual(s.Index(i).Interface(), zero) {
- zeros++
- continue
- }
- if zeros > 0 {
- if zeros <= 4 {
- for ; zeros > 0; zeros-- {
- items = append(items, z)
- }
- } else {
- items = append(items, fmt.Sprintf("(%d %ss)", zeros, z))
- zeros = 0
- }
- }
- if i < s.Len() {
- items = append(items, fmt.Sprintf("%v", s.Index(i).Interface()))
- }
- }
- return "[" + strings.Join(items, ",") + "]", zeros < s.Len()
-}
diff --git a/pkg/state/state.go b/pkg/state/state.go
index 03ae2dbb0..acb629969 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -31,210 +31,226 @@
// Uint64 default
// Float32 default
// Float64 default
-// Complex64 custom
-// Complex128 custom
+// Complex64 default
+// Complex128 default
// Array default
// Chan custom
// Func custom
-// Interface custom
-// Map default (*)
+// Interface default
+// Map default
// Ptr default
// Slice default
// String default
-// Struct custom
+// Struct custom (*) Unless zero-sized.
// UnsafePointer custom
//
-// (*) Maps are treated as value types by this package, even if they are
-// pointers internally. If you want to save two independent references
-// to the same map value, you must explicitly use a pointer to a map.
+// See README.md for an overview of how encoding and decoding works.
package state
import (
"context"
"fmt"
- "io"
"reflect"
"runtime"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
+// objectID is a unique identifier assigned to each object to be serialized.
+// Each instance of an object is considered separately, i.e. if there are two
+// objects of the same type in the object graph being serialized, they'll be
+// assigned unique objectIDs.
+type objectID uint32
+
+// typeID is the identifier for a type. Types are serialized and tracked
+// alongside objects in order to avoid the overhead of encoding field names in
+// all objects.
+type typeID uint32
+
// ErrState is returned when an error is encountered during encode/decode.
type ErrState struct {
// err is the underlying error.
err error
- // path is the visit path from root to the current object.
- path string
-
// trace is the stack trace.
trace string
}
// Error returns a sensible description of the state error.
func (e *ErrState) Error() string {
- return fmt.Sprintf("%v:\nstate path: %s\n%s", e.err, e.path, e.trace)
+ return fmt.Sprintf("%v:\n%s", e.err, e.trace)
}
-// UnwrapErrState returns the underlying error in ErrState.
-//
-// If err is not *ErrState, err is returned directly.
-func UnwrapErrState(err error) error {
- if e, ok := err.(*ErrState); ok {
- return e.err
- }
- return err
+// Unwrap implements standard unwrapping.
+func (e *ErrState) Unwrap() error {
+ return e.err
}
// Save saves the given object state.
-func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error {
+func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) {
// Create the encoding state.
- es := &encodeState{
- ctx: ctx,
- idsByObject: make(map[uintptr]uint64),
- w: w,
- stats: stats,
+ es := encodeState{
+ ctx: ctx,
+ w: w,
+ types: makeTypeEncodeDatabase(),
+ zeroValues: make(map[reflect.Type]*objectEncodeState),
}
// Perform the encoding.
- return es.safely(func() {
- es.Serialize(reflect.ValueOf(rootPtr).Elem())
+ err := safely(func() {
+ es.Save(reflect.ValueOf(rootPtr).Elem())
})
+ return es.stats, err
}
// Load loads a checkpoint.
-func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error {
+func Load(ctx context.Context, r wire.Reader, rootPtr interface{}) (Stats, error) {
// Create the decoding state.
- ds := &decodeState{
- ctx: ctx,
- objectsByID: make(map[uint64]*objectState),
- deferred: make(map[uint64]*pb.Object),
- r: r,
- stats: stats,
+ ds := decodeState{
+ ctx: ctx,
+ r: r,
+ types: makeTypeDecodeDatabase(),
+ deferred: make(map[objectID]wire.Object),
}
// Attempt our decode.
- return ds.safely(func() {
- ds.Deserialize(reflect.ValueOf(rootPtr).Elem())
+ err := safely(func() {
+ ds.Load(reflect.ValueOf(rootPtr).Elem())
})
+ return ds.stats, err
}
-// Fns are the state dispatch functions.
-type Fns struct {
- // Save is a function like Save(concreteType, Map).
- Save interface{}
-
- // Load is a function like Load(concreteType, Map).
- Load interface{}
+// Sink is used for SaverLoader.StateSave.
+type Sink struct {
+ internal objectEncoder
}
-// Save executes the save function.
-func (fns *Fns) invokeSave(obj reflect.Value, m Map) {
- reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+// Save adds the given object to the sink.
+//
+// You should always pass pointers to the object you are saving. For example:
+//
+// type X struct {
+// A int
+// B *int
+// }
+//
+// func (x *X) StateTypeName() string {
+// return "pkg.X"
+// }
+//
+// func (x *X) StateFields() []string {
+// return []string{"A", "B"}
+// }
+//
+// func (x *X) StateSave(m Sink) {
+// m.Save(0, &x.A) // Field is A.
+// m.Save(1, &x.B) // Field is B.
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.Load(0, &x.A) // Field is A.
+// m.Load(1, &x.B) // Field is B.
+// }
+func (s Sink) Save(slot int, objPtr interface{}) {
+ s.internal.save(slot, reflect.ValueOf(objPtr).Elem())
}
-// Load executes the load function.
-func (fns *Fns) invokeLoad(obj reflect.Value, m Map) {
- reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+// SaveValue adds the given object value to the sink.
+//
+// This should be used for values where pointers are not available, or casts
+// are required during Save/Load.
+//
+// For example, if we want to cast the external package type P.Foo to int64:
+//
+// func (x *X) StateSave(m Sink) {
+// m.SaveValue(0, int64(x.A))
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.LoadValue(0, new(int64), func(v interface{}) {
+// x.A = P.Foo(v.(int64))
+// })
+// }
+func (s Sink) SaveValue(slot int, obj interface{}) {
+ s.internal.save(slot, reflect.ValueOf(obj))
}
-// validateStateFn ensures types are correct.
-func validateStateFn(fn interface{}, typ reflect.Type) bool {
- fnTyp := reflect.TypeOf(fn)
- if fnTyp.Kind() != reflect.Func {
- return false
- }
- if fnTyp.NumIn() != 2 {
- return false
- }
- if fnTyp.NumOut() != 0 {
- return false
- }
- if fnTyp.In(0) != typ {
- return false
- }
- if fnTyp.In(1) != reflect.TypeOf(Map{}) {
- return false
- }
- return true
+// Context returns the context object provided at save time.
+func (s Sink) Context() context.Context {
+ return s.internal.es.ctx
}
-// Validate validates all state functions.
-func (fns *Fns) Validate(typ reflect.Type) bool {
- return validateStateFn(fns.Save, typ) && validateStateFn(fns.Load, typ)
+// Type is an interface that must be implemented by Struct objects. This allows
+// these objects to be serialized while minimizing the runtime reflection required.
+//
+// All these methods can be automatically generated by the go_statify tool.
+type Type interface {
+ // StateTypeName returns the type's name.
+ //
+ // This is used for matching type information during encoding and
+ // decoding, as well as dynamic interface dispatch. This should be
+ // globally unique.
+ StateTypeName() string
+
+ // StateFields returns information about the type.
+ //
+ // Fields is the set of fields for the object. Calls to Sink.Save and
+ // Source.Load must be made in-order with respect to these fields.
+ //
+ // This will be called at most once per serialization.
+ StateFields() []string
}
-type typeDatabase struct {
- // nameToType is a forward lookup table.
- nameToType map[string]reflect.Type
-
- // typeToName is the reverse lookup table.
- typeToName map[reflect.Type]string
+// SaverLoader must be implemented by struct types.
+type SaverLoader interface {
+ // StateSave saves the state of the object to the given Sink.
+ StateSave(Sink)
- // typeToFns is the function lookup table.
- typeToFns map[reflect.Type]Fns
+ // StateLoad loads the state of the object from the given Source.
+ StateLoad(Source)
}
-// registeredTypes is a database used for SaveInterface and LoadInterface.
-var registeredTypes = typeDatabase{
- nameToType: make(map[string]reflect.Type),
- typeToName: make(map[reflect.Type]string),
- typeToFns: make(map[reflect.Type]Fns),
+// Source is used for SaverLoader.StateLoad.
+type Source struct {
+ internal objectDecoder
}
-// register registers a type under the given name. This will generally be
-// called via init() methods, and therefore uses panic to propagate errors.
-func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) {
- // We can't allow name collisions.
- if ot, ok := t.nameToType[name]; ok {
- panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.Name(), name, ot.Name()))
- }
-
- // Or multiple registrations.
- if on, ok := t.typeToName[typ]; ok {
- panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.Name(), name, on))
- }
-
- t.nameToType[name] = typ
- t.typeToName[typ] = name
- t.typeToFns[typ] = fns
+// Load loads the given object, which must be passed as a pointer.
+//
+// See Sink.Save for an example.
+func (s Source) Load(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), false, nil)
}
-// lookupType finds a type given a name.
-func (t *typeDatabase) lookupType(name string) (reflect.Type, bool) {
- typ, ok := t.nameToType[name]
- return typ, ok
+// LoadWait loads the given object from the source, and marks it as requiring all
+// AfterLoad executions to complete prior to running this object's AfterLoad.
+//
+// See Sink.Save for an example.
+func (s Source) LoadWait(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), true, nil)
}
-// lookupName finds a name given a type.
-func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) {
- name, ok := t.typeToName[typ]
- return name, ok
+// LoadValue loads the given object value from the source.
+//
+// See Sink.SaveValue for an example.
+func (s Source) LoadValue(slot int, objPtr interface{}, fn func(interface{})) {
+ o := reflect.ValueOf(objPtr)
+ s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) })
}
-// lookupFns finds functions given a type.
-func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) {
- fns, ok := t.typeToFns[typ]
- return fns, ok
+// AfterLoad schedules a function to be executed once all objects have been
+// allocated and their automated loading and customized load logic have been
+// executed. fn will not run until all of the current object's dependencies'
+// AfterLoad() logic, if any, has been executed.
+func (s Source) AfterLoad(fn func()) {
+ s.internal.afterLoad(fn)
}
-// Register must be called for any interface implementation types that
-// implements Loader.
-//
-// Register should be called either immediately after startup or via init()
-// methods. Double registration of either names or types will result in a panic.
-//
-// No synchronization is provided; this should only be called in init.
-//
-// Example usage:
-//
-// state.Register("Foo", (*Foo)(nil), state.Fns{
-// Save: (*Foo).Save,
-// Load: (*Foo).Load,
-// })
-//
-func Register(name string, instance interface{}, fns Fns) {
- registeredTypes.register(name, reflect.TypeOf(instance), fns)
+// Context returns the context object provided at load time.
+func (s Source) Context() context.Context {
+ return s.internal.ds.ctx
}
// IsZeroValue checks if the given value is the zero value.
@@ -244,72 +260,14 @@ func IsZeroValue(val interface{}) bool {
return val == nil || reflect.ValueOf(val).Elem().IsZero()
}
-// step captures one encoding / decoding step. On each step, there is up to one
-// choice made, which is captured by non-nil param. We intentionally do not
-// eagerly create the final path string, as that will only be needed upon panic.
-type step struct {
- // dereference indicate if the current object is obtained by
- // dereferencing a pointer.
- dereference bool
-
- // format is the formatting string that takes param below, if
- // non-nil. For example, in array indexing case, we have "[%d]".
- format string
-
- // param stores the choice made at the current encoding / decoding step.
- // For eaxmple, in array indexing case, param stores the index. When no
- // choice is made, e.g. dereference, param should be nil.
- param interface{}
-}
-
-// recoverable is the state encoding / decoding panic recovery facility. It is
-// also used to store encoding / decoding steps as well as the reference to the
-// original queued object from which the current object is dispatched. The
-// complete encoding / decoding path is synthesised from the steps in all queued
-// objects leading to the current object.
-type recoverable struct {
- from *recoverable
- steps []step
+// Failf is a wrapper around panic that should be used to generate errors that
+// can be caught during saving and loading.
+func Failf(fmtStr string, v ...interface{}) {
+ panic(fmt.Errorf(fmtStr, v...))
}
-// push enters a new context level.
-func (sr *recoverable) push(dereference bool, format string, param interface{}) {
- sr.steps = append(sr.steps, step{dereference, format, param})
-}
-
-// pop exits the current context level.
-func (sr *recoverable) pop() {
- if len(sr.steps) <= 1 {
- return
- }
- sr.steps = sr.steps[:len(sr.steps)-1]
-}
-
-// path returns the complete encoding / decoding path from root. This is only
-// called upon panic.
-func (sr *recoverable) path() string {
- if sr.from == nil {
- return "root"
- }
- p := sr.from.path()
- for _, s := range sr.steps {
- if s.dereference {
- p = fmt.Sprintf("*(%s)", p)
- }
- if s.param == nil {
- p += s.format
- } else {
- p += fmt.Sprintf(s.format, s.param)
- }
- }
- return p
-}
-
-func (sr *recoverable) copy() recoverable {
- return recoverable{from: sr.from, steps: append([]step(nil), sr.steps...)}
-}
-
-// safely executes the given function, catching a panic and unpacking as an error.
+// safely executes the given function, catching a panic and unpacking as an
+// error.
//
// The error flow through the state package uses panic and recover. There are
// two important reasons for this:
@@ -323,9 +281,15 @@ func (sr *recoverable) copy() recoverable {
// method doesn't add a lot of value. If there are specific error conditions
// that you'd like to handle, you should add appropriate functionality to
// objects themselves prior to calling Save() and Load().
-func (sr *recoverable) safely(fn func()) (err error) {
+func safely(fn func()) (err error) {
defer func() {
if r := recover(); r != nil {
+ if es, ok := r.(*ErrState); ok {
+ err = es // Propagate.
+ return
+ }
+
+ // Build a new state error.
es := new(ErrState)
if e, ok := r.(error); ok {
es.err = e
@@ -333,8 +297,6 @@ func (sr *recoverable) safely(fn func()) (err error) {
es.err = fmt.Errorf("%v", r)
}
- es.path = sr.path()
-
// Make a stack. We don't know how big it will be ahead
// of time, but want to make sure we get the whole
// thing. So we just do a stupid brute force approach.
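Putting the new slot-based API together: a minimal hand-written implementation of Type and SaverLoader, using only methods shown in this change. In practice these methods are generated by the go_statify tool; counter is a hypothetical type:

    package example

    import "gvisor.dev/gvisor/pkg/state"

    // counter is hypothetical; its method set follows the Type and
    // SaverLoader interfaces above.
    type counter struct {
    	hits  int64
    	ratio float64 // Derived, so recomputed after load rather than saved.
    }

    func (c *counter) StateTypeName() string { return "example.counter" }

    func (c *counter) StateFields() []string { return []string{"hits"} }

    func (c *counter) StateSave(s state.Sink) {
    	s.Save(0, &c.hits) // Slot 0 corresponds to StateFields()[0].
    }

    func (c *counter) StateLoad(s state.Source) {
    	s.Load(0, &c.hits)
    	s.AfterLoad(func() {
    		// Runs once this object's dependencies have finished loading.
    		c.ratio = 1 / float64(c.hits+1)
    	})
    }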
diff --git a/pkg/state/state_norace.go b/pkg/state/state_norace.go
new file mode 100644
index 000000000..4281aed6d
--- /dev/null
+++ b/pkg/state/state_norace.go
@@ -0,0 +1,19 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+package state
+
+var raceEnabled = false
diff --git a/pkg/state/state_race.go b/pkg/state/state_race.go
new file mode 100644
index 000000000..8232981ce
--- /dev/null
+++ b/pkg/state/state_race.go
@@ -0,0 +1,19 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package state
+
+var raceEnabled = true
diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go
deleted file mode 100644
index d7221e9e8..000000000
--- a/pkg/state/state_test.go
+++ /dev/null
@@ -1,721 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package state
-
-import (
- "bytes"
- "context"
- "io/ioutil"
- "math"
- "reflect"
- "testing"
-)
-
-// TestCase is used to define a single success/failure testcase of
-// serialization of a set of objects.
-type TestCase struct {
- // Name is the name of the test case.
- Name string
-
- // Objects is the list of values to serialize.
- Objects []interface{}
-
- // Fail is whether the test case is supposed to fail or not.
- Fail bool
-}
-
-// runTest runs all testcases.
-func runTest(t *testing.T, tests []TestCase) {
- for _, test := range tests {
- t.Logf("TEST %s:", test.Name)
- for i, root := range test.Objects {
- t.Logf(" case#%d: %#v", i, root)
-
- // Save the passed object.
- saveBuffer := &bytes.Buffer{}
- saveObjectPtr := reflect.New(reflect.TypeOf(root))
- saveObjectPtr.Elem().Set(reflect.ValueOf(root))
- if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail {
- t.Errorf(" FAIL: Save failed unexpectedly: %v", err)
- continue
- } else if err != nil {
- t.Logf(" PASS: Save failed as expected: %v", err)
- continue
- }
-
- // Load a new copy of the object.
- loadObjectPtr := reflect.New(reflect.TypeOf(root))
- if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail {
- t.Errorf(" FAIL: Load failed unexpectedly: %v", err)
- continue
- } else if err != nil {
- t.Logf(" PASS: Load failed as expected: %v", err)
- continue
- }
-
- // Compare the values.
- loadedValue := loadObjectPtr.Elem().Interface()
- if eq := reflect.DeepEqual(root, loadedValue); !eq && !test.Fail {
- t.Errorf(" FAIL: Objects differs; got %#v", loadedValue)
- continue
- } else if !eq {
- t.Logf(" PASS: Object different as expected.")
- continue
- }
-
- // Everything went okay. Is that good?
- if test.Fail {
- t.Errorf(" FAIL: Unexpected success.")
- } else {
- t.Logf(" PASS: Success.")
- }
- }
- }
-}
-
-// dumbStruct is a struct which does not implement the loader/saver interface.
-// We expect that serialization of this struct will fail.
-type dumbStruct struct {
- A int
- B int
-}
-
-// smartStruct is a struct which does implement the loader/saver interface.
-// We expect that serialization of this struct will succeed.
-type smartStruct struct {
- A int
- B int
-}
-
-func (s *smartStruct) save(m Map) {
- m.Save("A", &s.A)
- m.Save("B", &s.B)
-}
-
-func (s *smartStruct) load(m Map) {
- m.Load("A", &s.A)
- m.Load("B", &s.B)
-}
-
-// valueLoadStruct uses a value load.
-type valueLoadStruct struct {
- v int
-}
-
-func (v *valueLoadStruct) save(m Map) {
- m.SaveValue("v", v.v)
-}
-
-func (v *valueLoadStruct) load(m Map) {
- m.LoadValue("v", new(int), func(value interface{}) {
- v.v = value.(int)
- })
-}
-
-// afterLoadStruct has an AfterLoad function.
-type afterLoadStruct struct {
- v int
-}
-
-func (a *afterLoadStruct) save(m Map) {
-}
-
-func (a *afterLoadStruct) load(m Map) {
- m.AfterLoad(func() {
- a.v++
- })
-}
-
-// genericContainer is a generic dispatcher.
-type genericContainer struct {
- v interface{}
-}
-
-func (g *genericContainer) save(m Map) {
- m.Save("v", &g.v)
-}
-
-func (g *genericContainer) load(m Map) {
- m.Load("v", &g.v)
-}
-
-// sliceContainer is a generic slice.
-type sliceContainer struct {
- v []interface{}
-}
-
-func (s *sliceContainer) save(m Map) {
- m.Save("v", &s.v)
-}
-
-func (s *sliceContainer) load(m Map) {
- m.Load("v", &s.v)
-}
-
-// mapContainer is a generic map.
-type mapContainer struct {
- v map[int]interface{}
-}
-
-func (mc *mapContainer) save(m Map) {
- m.Save("v", &mc.v)
-}
-
-func (mc *mapContainer) load(m Map) {
- // Some of the test cases below assume legacy behavior wherein maps
- // will automatically inherit dependencies.
- m.LoadWait("v", &mc.v)
-}
-
-// dumbMap is a map which does not implement the loader/saver interface.
-// Serialization of this map will default to the standard encode/decode logic.
-type dumbMap map[string]int
-
-// pointerStruct contains various pointers, shared and non-shared, and pointers
-// to pointers. We expect that serialization will respect the structure.
-type pointerStruct struct {
- A *int
- B *int
- C *int
- D *int
-
- AA **int
- BB **int
-}
-
-func (p *pointerStruct) save(m Map) {
- m.Save("A", &p.A)
- m.Save("B", &p.B)
- m.Save("C", &p.C)
- m.Save("D", &p.D)
- m.Save("AA", &p.AA)
- m.Save("BB", &p.BB)
-}
-
-func (p *pointerStruct) load(m Map) {
- m.Load("A", &p.A)
- m.Load("B", &p.B)
- m.Load("C", &p.C)
- m.Load("D", &p.D)
- m.Load("AA", &p.AA)
- m.Load("BB", &p.BB)
-}
-
-// testInterface is a trivial interface example.
-type testInterface interface {
- Foo()
-}
-
-// testImpl is a trivial implementation of testInterface.
-type testImpl struct {
-}
-
-// Foo satisfies testInterface.
-func (t *testImpl) Foo() {
-}
-
-// testImpl is trivially serializable.
-func (t *testImpl) save(m Map) {
-}
-
-// testImpl is trivially serializable.
-func (t *testImpl) load(m Map) {
-}
-
-// testI demonstrates interface dispatching.
-type testI struct {
- I testInterface
-}
-
-func (t *testI) save(m Map) {
- m.Save("I", &t.I)
-}
-
-func (t *testI) load(m Map) {
- m.Load("I", &t.I)
-}
-
-// cycleStruct is used to implement basic cycles.
-type cycleStruct struct {
- c *cycleStruct
-}
-
-func (c *cycleStruct) save(m Map) {
- m.Save("c", &c.c)
-}
-
-func (c *cycleStruct) load(m Map) {
- m.Load("c", &c.c)
-}
-
-// badCycleStruct actually has deadlocking dependencies.
-//
-// This should pass if b.b = {nil|b} and fail otherwise.
-type badCycleStruct struct {
- b *badCycleStruct
-}
-
-func (b *badCycleStruct) save(m Map) {
- m.Save("b", &b.b)
-}
-
-func (b *badCycleStruct) load(m Map) {
- m.LoadWait("b", &b.b)
- m.AfterLoad(func() {
- // This is not executable, since AfterLoad requires that the
- // object and all dependencies are complete. This should cause
- // a deadlock error during load.
- })
-}
-
-// emptyStructPointer points to an empty struct.
-type emptyStructPointer struct {
- nothing *struct{}
-}
-
-func (e *emptyStructPointer) save(m Map) {
- m.Save("nothing", &e.nothing)
-}
-
-func (e *emptyStructPointer) load(m Map) {
- m.Load("nothing", &e.nothing)
-}
-
-// truncateInteger truncates an integer.
-type truncateInteger struct {
- v int64
- v2 int32
-}
-
-func (t *truncateInteger) save(m Map) {
- t.v2 = int32(t.v)
- m.Save("v", &t.v)
-}
-
-func (t *truncateInteger) load(m Map) {
- m.Load("v", &t.v2)
- t.v = int64(t.v2)
-}
-
-// truncateUnsignedInteger truncates an unsigned integer.
-type truncateUnsignedInteger struct {
- v uint64
- v2 uint32
-}
-
-func (t *truncateUnsignedInteger) save(m Map) {
- t.v2 = uint32(t.v)
- m.Save("v", &t.v)
-}
-
-func (t *truncateUnsignedInteger) load(m Map) {
- m.Load("v", &t.v2)
- t.v = uint64(t.v2)
-}
-
-// truncateFloat truncates a floating point number.
-type truncateFloat struct {
- v float64
- v2 float32
-}
-
-func (t *truncateFloat) save(m Map) {
- t.v2 = float32(t.v)
- m.Save("v", &t.v)
-}
-
-func (t *truncateFloat) load(m Map) {
- m.Load("v", &t.v2)
- t.v = float64(t.v2)
-}
-
-func TestTypes(t *testing.T) {
- // x and y are basic integers, while xp points to x.
- x := 1
- y := 2
- xp := &x
-
- // cs is a single object cycle.
- cs := cycleStruct{nil}
- cs.c = &cs
-
- // cs1 and cs2 are in a two object cycle.
- cs1 := cycleStruct{nil}
- cs2 := cycleStruct{nil}
- cs1.c = &cs2
- cs2.c = &cs1
-
- // bs is a single object cycle.
- bs := badCycleStruct{nil}
- bs.b = &bs
-
- // bs2 and bs2 are in a deadlocking cycle.
- bs1 := badCycleStruct{nil}
- bs2 := badCycleStruct{nil}
- bs1.b = &bs2
- bs2.b = &bs1
-
- // regular nils.
- var (
- nilmap dumbMap
- nilslice []byte
- )
-
- // embed points to embedded fields.
- embed1 := pointerStruct{}
- embed1.AA = &embed1.A
- embed2 := pointerStruct{}
- embed2.BB = &embed2.B
-
- // es1 contains two structs pointing to the same empty struct.
- es := emptyStructPointer{new(struct{})}
- es1 := []emptyStructPointer{es, es}
-
- tests := []TestCase{
- {
- Name: "bool",
- Objects: []interface{}{
- true,
- false,
- },
- },
- {
- Name: "integers",
- Objects: []interface{}{
- int(0),
- int(1),
- int(-1),
- int8(0),
- int8(1),
- int8(-1),
- int16(0),
- int16(1),
- int16(-1),
- int32(0),
- int32(1),
- int32(-1),
- int64(0),
- int64(1),
- int64(-1),
- },
- },
- {
- Name: "unsigned integers",
- Objects: []interface{}{
- uint(0),
- uint(1),
- uint8(0),
- uint8(1),
- uint16(0),
- uint16(1),
- uint32(1),
- uint64(0),
- uint64(1),
- },
- },
- {
- Name: "strings",
- Objects: []interface{}{
- "",
- "foo",
- "bar",
- "\xa0",
- },
- },
- {
- Name: "slices",
- Objects: []interface{}{
- []int{-1, 0, 1},
- []*int{&x, &x, &x},
- []int{1, 2, 3}[0:1],
- []int{1, 2, 3}[1:2],
- make([]byte, 32),
- make([]byte, 32)[:16],
- make([]byte, 32)[:16:20],
- nilslice,
- },
- },
- {
- Name: "arrays",
- Objects: []interface{}{
- &[1048576]bool{false, true, false, true},
- &[1048576]uint8{0, 1, 2, 3},
- &[1048576]byte{0, 1, 2, 3},
- &[1048576]uint16{0, 1, 2, 3},
- &[1048576]uint{0, 1, 2, 3},
- &[1048576]uint32{0, 1, 2, 3},
- &[1048576]uint64{0, 1, 2, 3},
- &[1048576]uintptr{0, 1, 2, 3},
- &[1048576]int8{0, -1, -2, -3},
- &[1048576]int16{0, -1, -2, -3},
- &[1048576]int32{0, -1, -2, -3},
- &[1048576]int64{0, -1, -2, -3},
- &[1048576]float32{0, 1.1, 2.2, 3.3},
- &[1048576]float64{0, 1.1, 2.2, 3.3},
- },
- },
- {
- Name: "pointers",
- Objects: []interface{}{
- &pointerStruct{A: &x, B: &x, C: &y, D: &y, AA: &xp, BB: &xp},
- &pointerStruct{},
- },
- },
- {
- Name: "empty struct",
- Objects: []interface{}{
- struct{}{},
- },
- },
- {
- Name: "unenlightened structs",
- Objects: []interface{}{
- &dumbStruct{A: 1, B: 2},
- },
- Fail: true,
- },
- {
- Name: "enlightened structs",
- Objects: []interface{}{
- &smartStruct{A: 1, B: 2},
- },
- },
- {
- Name: "load-hooks",
- Objects: []interface{}{
- &afterLoadStruct{v: 1},
- &valueLoadStruct{v: 1},
- &genericContainer{v: &afterLoadStruct{v: 1}},
- &genericContainer{v: &valueLoadStruct{v: 1}},
- &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
- &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}},
- },
- },
- {
- Name: "maps",
- Objects: []interface{}{
- dumbMap{"a": -1, "b": 0, "c": 1},
- map[smartStruct]int{{}: 0, {A: 1}: 1},
- nilmap,
- &mapContainer{v: map[int]interface{}{0: &smartStruct{A: 1}}},
- },
- },
- {
- Name: "interfaces",
- Objects: []interface{}{
- &testI{&testImpl{}},
- &testI{nil},
- &testI{(*testImpl)(nil)},
- },
- },
- {
- Name: "unregistered-interfaces",
- Objects: []interface{}{
- &genericContainer{v: afterLoadStruct{v: 1}},
- &genericContainer{v: valueLoadStruct{v: 1}},
- &sliceContainer{v: []interface{}{afterLoadStruct{v: 1}}},
- &sliceContainer{v: []interface{}{valueLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: afterLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: valueLoadStruct{v: 1}}},
- },
- Fail: true,
- },
- {
- Name: "cycles",
- Objects: []interface{}{
- &cs,
- &cs1,
- &cycleStruct{&cs1},
- &cycleStruct{&cs},
- &badCycleStruct{nil},
- &bs,
- },
- },
- {
- Name: "deadlock",
- Objects: []interface{}{
- &bs1,
- },
- Fail: true,
- },
- {
- Name: "embed",
- Objects: []interface{}{
- &embed1,
- &embed2,
- },
- Fail: true,
- },
- {
- Name: "empty structs",
- Objects: []interface{}{
- new(struct{}),
- es,
- es1,
- },
- },
- {
- Name: "truncated okay",
- Objects: []interface{}{
- &truncateInteger{v: 1},
- &truncateUnsignedInteger{v: 1},
- &truncateFloat{v: 1.0},
- },
- },
- {
- Name: "truncated bad",
- Objects: []interface{}{
- &truncateInteger{v: math.MaxInt32 + 1},
- &truncateUnsignedInteger{v: math.MaxUint32 + 1},
- &truncateFloat{v: math.MaxFloat32 * 2},
- },
- Fail: true,
- },
- }
-
- runTest(t, tests)
-}
-
-// benchStruct is used for benchmarking.
-type benchStruct struct {
- b *benchStruct
-
- // Dummy data is included to ensure that these objects are large.
- // This is to detect possible regression when registering objects.
- _ [4096]byte
-}
-
-func (b *benchStruct) save(m Map) {
- m.Save("b", &b.b)
-}
-
-func (b *benchStruct) load(m Map) {
- m.LoadWait("b", &b.b)
- m.AfterLoad(b.afterLoad)
-}
-
-func (b *benchStruct) afterLoad() {
- // Do nothing, just force scheduling.
-}
-
-// buildObject builds a benchmark object.
-func buildObject(n int) (b *benchStruct) {
- for i := 0; i < n; i++ {
- b = &benchStruct{b: b}
- }
- return
-}
-
-func BenchmarkEncoding(b *testing.B) {
- b.StopTimer()
- bs := buildObject(b.N)
- var stats Stats
- b.StartTimer()
- if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil {
- b.Errorf("save failed: %v", err)
- }
- b.StopTimer()
- if b.N > 1000 {
- b.Logf("breakdown (n=%d): %s", b.N, &stats)
- }
-}
-
-func BenchmarkDecoding(b *testing.B) {
- b.StopTimer()
- bs := buildObject(b.N)
- var newBS benchStruct
- buf := &bytes.Buffer{}
- if err := Save(context.Background(), buf, bs, nil); err != nil {
- b.Errorf("save failed: %v", err)
- }
- var stats Stats
- b.StartTimer()
- if err := Load(context.Background(), buf, &newBS, &stats); err != nil {
- b.Errorf("load failed: %v", err)
- }
- b.StopTimer()
- if b.N > 1000 {
- b.Logf("breakdown (n=%d): %s", b.N, &stats)
- }
-}
-
-func init() {
- Register("stateTest.smartStruct", (*smartStruct)(nil), Fns{
- Save: (*smartStruct).save,
- Load: (*smartStruct).load,
- })
- Register("stateTest.afterLoadStruct", (*afterLoadStruct)(nil), Fns{
- Save: (*afterLoadStruct).save,
- Load: (*afterLoadStruct).load,
- })
- Register("stateTest.valueLoadStruct", (*valueLoadStruct)(nil), Fns{
- Save: (*valueLoadStruct).save,
- Load: (*valueLoadStruct).load,
- })
- Register("stateTest.genericContainer", (*genericContainer)(nil), Fns{
- Save: (*genericContainer).save,
- Load: (*genericContainer).load,
- })
- Register("stateTest.sliceContainer", (*sliceContainer)(nil), Fns{
- Save: (*sliceContainer).save,
- Load: (*sliceContainer).load,
- })
- Register("stateTest.mapContainer", (*mapContainer)(nil), Fns{
- Save: (*mapContainer).save,
- Load: (*mapContainer).load,
- })
- Register("stateTest.pointerStruct", (*pointerStruct)(nil), Fns{
- Save: (*pointerStruct).save,
- Load: (*pointerStruct).load,
- })
- Register("stateTest.testImpl", (*testImpl)(nil), Fns{
- Save: (*testImpl).save,
- Load: (*testImpl).load,
- })
- Register("stateTest.testI", (*testI)(nil), Fns{
- Save: (*testI).save,
- Load: (*testI).load,
- })
- Register("stateTest.cycleStruct", (*cycleStruct)(nil), Fns{
- Save: (*cycleStruct).save,
- Load: (*cycleStruct).load,
- })
- Register("stateTest.badCycleStruct", (*badCycleStruct)(nil), Fns{
- Save: (*badCycleStruct).save,
- Load: (*badCycleStruct).load,
- })
- Register("stateTest.emptyStructPointer", (*emptyStructPointer)(nil), Fns{
- Save: (*emptyStructPointer).save,
- Load: (*emptyStructPointer).load,
- })
- Register("stateTest.truncateInteger", (*truncateInteger)(nil), Fns{
- Save: (*truncateInteger).save,
- Load: (*truncateInteger).load,
- })
- Register("stateTest.truncateUnsignedInteger", (*truncateUnsignedInteger)(nil), Fns{
- Save: (*truncateUnsignedInteger).save,
- Load: (*truncateUnsignedInteger).load,
- })
- Register("stateTest.truncateFloat", (*truncateFloat)(nil), Fns{
- Save: (*truncateFloat).save,
- Load: (*truncateFloat).load,
- })
- Register("stateTest.benchStruct", (*benchStruct)(nil), Fns{
- Save: (*benchStruct).save,
- Load: (*benchStruct).load,
- })
-}
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
index e7581c09b..d6c89c7e9 100644
--- a/pkg/state/statefile/BUILD
+++ b/pkg/state/statefile/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/binary",
"//pkg/compressio",
+ "//pkg/state/wire",
],
)
diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go
index c0f4c4954..bdfb800fb 100644
--- a/pkg/state/statefile/statefile.go
+++ b/pkg/state/statefile/statefile.go
@@ -57,6 +57,7 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/compressio"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
// keySize is the AES-256 key length.
@@ -83,10 +84,16 @@ var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size
// ErrMetadataInvalid is returned if passed metadata is invalid.
var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _")
+// WriteCloser is an io.Closer and wire.Writer.
+type WriteCloser interface {
+ wire.Writer
+ io.Closer
+}
+
// NewWriter returns a state data writer for a statefile.
//
// Note that the returned WriteCloser must be closed.
-func NewWriter(w io.Writer, key []byte, metadata map[string]string) (io.WriteCloser, error) {
+func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser, error) {
if metadata == nil {
metadata = make(map[string]string)
}
@@ -215,7 +222,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) {
}
// NewReader returns a reader for a statefile.
-func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) {
+func NewReader(r io.Reader, key []byte) (wire.Reader, map[string]string, error) {
// Read the metadata with the hash.
h := hmac.New(sha256.New, key)
metadata, err := metadata(r, h)
@@ -224,9 +231,9 @@ func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) {
}
// Wrap in compression.
- rc, err := compressio.NewReader(r, key)
+ cr, err := compressio.NewReader(r, key)
if err != nil {
return nil, nil, err
}
- return rc, metadata, nil
+ return cr, metadata, nil
}
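End to end, the updated statefile types compose with the new Save/Load signatures as sketched below, assuming the same key (possibly nil) is used on both sides:

    package example

    import (
    	"context"
    	"os"

    	"gvisor.dev/gvisor/pkg/state"
    	"gvisor.dev/gvisor/pkg/state/statefile"
    )

    func checkpoint(f *os.File, root interface{}, key []byte) error {
    	w, err := statefile.NewWriter(f, key, map[string]string{"app": "demo"})
    	if err != nil {
    		return err
    	}
    	if _, err := state.Save(context.Background(), w, root); err != nil {
    		w.Close()
    		return err
    	}
    	return w.Close() // The returned WriteCloser must be closed.
    }

    func restore(f *os.File, root interface{}, key []byte) error {
    	r, _, err := statefile.NewReader(f, key) // Second result is metadata.
    	if err != nil {
    		return err
    	}
    	_, err = state.Load(context.Background(), r, root)
    	return err
    }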
diff --git a/pkg/state/stats.go b/pkg/state/stats.go
index eb51cda47..eaec664a1 100644
--- a/pkg/state/stats.go
+++ b/pkg/state/stats.go
@@ -17,7 +17,6 @@ package state
import (
"bytes"
"fmt"
- "reflect"
"sort"
"time"
)
@@ -35,92 +34,81 @@ type statEntry struct {
// All exported receivers accept nil.
type Stats struct {
// byType contains a breakdown of time spent by type.
- byType map[reflect.Type]*statEntry
+ //
+ // This is indexed *directly* by typeID, including zero.
+ byType []statEntry
// stack contains objects in progress.
- stack []reflect.Type
+ stack []typeID
+
+ // names contains type names.
+ //
+ // This is also indexed *directly* by typeID, including zero, which we
+ // hard-code as "state.default". This is only resolved by calling fini
+ // on the stats object.
+ names []string
// last is the last start time.
last time.Time
}
-// sample adds the samples to the given object.
-func (s *Stats) sample(typ reflect.Type) {
- now := time.Now()
- s.byType[typ].total += now.Sub(s.last)
- s.last = now
+// init initializes statistics.
+func (s *Stats) init() {
+ s.last = time.Now()
+ s.stack = append(s.stack, 0)
}
-// Add adds a sample count.
-func (s *Stats) Add(obj reflect.Value) {
- if s == nil {
- return
- }
- if s.byType == nil {
- s.byType = make(map[reflect.Type]*statEntry)
- }
- typ := obj.Type()
- entry, ok := s.byType[typ]
- if !ok {
- entry = new(statEntry)
- s.byType[typ] = entry
+// fini finalizes statistics.
+func (s *Stats) fini(resolve func(id typeID) string) {
+ s.done()
+
+ // Resolve all type names.
+ s.names = make([]string, len(s.byType))
+ s.names[0] = "state.default" // See above.
+ for id := typeID(1); int(id) < len(s.names); id++ {
+ s.names[id] = resolve(id)
}
- entry.count++
}
-// Remove removes a sample count. It should only be called after a previous
-// Add().
-func (s *Stats) Remove(obj reflect.Value) {
- if s == nil {
- return
+// sample charges the time since the last sample to the given type's entry.
+func (s *Stats) sample(id typeID) {
+ now := time.Now()
+ if len(s.byType) <= int(id) {
+ // Allocate all the missing entries in one fell swoop.
+ s.byType = append(s.byType, make([]statEntry, 1+int(id)-len(s.byType))...)
}
- typ := obj.Type()
- entry := s.byType[typ]
- entry.count--
+ s.byType[id].total += now.Sub(s.last)
+ s.last = now
}
-// Start starts a sample.
-func (s *Stats) Start(obj reflect.Value) {
- if s == nil {
- return
- }
- if len(s.stack) > 0 {
- last := s.stack[len(s.stack)-1]
- s.sample(last)
- } else {
- // First time sample.
- s.last = time.Now()
- }
- s.stack = append(s.stack, obj.Type())
+// start starts a sample.
+func (s *Stats) start(id typeID) {
+ last := s.stack[len(s.stack)-1]
+ s.sample(last)
+ s.stack = append(s.stack, id)
}
-// Done finishes the current sample.
-func (s *Stats) Done() {
- if s == nil {
- return
- }
+// done finishes the current sample.
+func (s *Stats) done() {
last := s.stack[len(s.stack)-1]
s.sample(last)
+ s.byType[last].count++
s.stack = s.stack[:len(s.stack)-1]
}
type sliceEntry struct {
- typ reflect.Type
+ name string
entry *statEntry
}
// String returns a table representation of the stats.
func (s *Stats) String() string {
- if s == nil || len(s.byType) == 0 {
- return "(no data)"
- }
-
// Build a list of stat entries.
ss := make([]sliceEntry, 0, len(s.byType))
- for typ, entry := range s.byType {
+ for id := 0; id < len(s.names); id++ {
ss = append(ss, sliceEntry{
- typ: typ,
- entry: entry,
+ name: s.names[id],
+ entry: &s.byType[id],
})
}
@@ -136,17 +124,22 @@ func (s *Stats) String() string {
total time.Duration
)
buf.WriteString("\n")
- buf.WriteString(fmt.Sprintf("%12s | %8s | %8s | %s\n", "total", "count", "per", "type"))
- buf.WriteString("-------------+----------+----------+-------------\n")
+ buf.WriteString(fmt.Sprintf("% 16s | % 8s | % 16s | %s\n", "total", "count", "per", "type"))
+ buf.WriteString("-----------------+----------+------------------+----------------\n")
for _, se := range ss {
+ if se.entry.count == 0 {
+ // Since entries are stored linearly by typeID, some
+ // entries may never have been sampled; skip them.
+ continue
+ }
count += se.entry.count
total += se.entry.total
per := se.entry.total / time.Duration(se.entry.count)
- buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | %s\n",
- se.entry.total, se.entry.count, per, se.typ.String()))
+ buf.WriteString(fmt.Sprintf("% 16s | %8d | % 16s | %s\n",
+ se.entry.total, se.entry.count, per, se.name))
}
- buf.WriteString("-------------+----------+----------+-------------\n")
- buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | [all]",
+ buf.WriteString("-----------------+----------+------------------+----------------\n")
+ buf.WriteString(fmt.Sprintf("% 16s | % 8d | % 16s | [all]",
total, count, total/time.Duration(count)))
return string(buf.Bytes())
}
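Stats is now driven entirely by typeIDs rather than reflect.Types, and its receivers are unexported, so only the encoder and decoder exercise them. An illustrative sequence of the intended lifecycle (encodeObject and resolve are hypothetical stand-ins for encoder internals):

    var stats Stats
    stats.init()    // Start the clock; push the implicit "state.default" ID 0.
    stats.start(id) // Charge subsequent time to type id.
    encodeObject()  // ... the actual encoding work ...
    stats.done()    // Pop the sample and bump the per-type count.
    stats.fini(resolve)         // resolve maps each typeID to its name.
    fmt.Println(stats.String()) // Zero-count rows are skipped in the table.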
diff --git a/pkg/state/tests/BUILD b/pkg/state/tests/BUILD
new file mode 100644
index 000000000..9297cafbe
--- /dev/null
+++ b/pkg/state/tests/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tests",
+ srcs = [
+ "array.go",
+ "bench.go",
+ "integer.go",
+ "load.go",
+ "map.go",
+ "register.go",
+ "struct.go",
+ "tests.go",
+ ],
+ deps = [
+ "//pkg/state",
+ "//pkg/state/pretty",
+ ],
+)
+
+go_test(
+ name = "tests_test",
+ size = "small",
+ srcs = [
+ "array_test.go",
+ "bench_test.go",
+ "bool_test.go",
+ "float_test.go",
+ "integer_test.go",
+ "load_test.go",
+ "map_test.go",
+ "register_test.go",
+ "string_test.go",
+ "struct_test.go",
+ ],
+ library = ":tests",
+ deps = [
+ "//pkg/state",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/pkg/state/tests/array.go b/pkg/state/tests/array.go
new file mode 100644
index 000000000..0972a80e7
--- /dev/null
+++ b/pkg/state/tests/array.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type arrayContainer struct {
+ v [1]interface{}
+}
+
+// +stateify savable
+type arrayPtrContainer struct {
+ v *[1]interface{}
+}
+
+// +stateify savable
+type sliceContainer struct {
+ v []interface{}
+}
+
+// +stateify savable
+type slicePtrContainer struct {
+ v *[]interface{}
+}
diff --git a/pkg/state/tests/array_test.go b/pkg/state/tests/array_test.go
new file mode 100644
index 000000000..a347b2947
--- /dev/null
+++ b/pkg/state/tests/array_test.go
@@ -0,0 +1,134 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "reflect"
+ "testing"
+)
+
+var allArrayPrimitives = []interface{}{
+ [1]bool{},
+ [1]bool{true},
+ [2]bool{false, true},
+ [1]int{},
+ [1]int{1},
+ [2]int{0, 1},
+ [1]int8{},
+ [1]int8{1},
+ [2]int8{0, 1},
+ [1]int16{},
+ [1]int16{1},
+ [2]int16{0, 1},
+ [1]int32{},
+ [1]int32{1},
+ [2]int32{0, 1},
+ [1]int64{},
+ [1]int64{1},
+ [2]int64{0, 1},
+ [1]uint{},
+ [1]uint{1},
+ [2]uint{0, 1},
+ [1]uintptr{},
+ [1]uintptr{1},
+ [2]uintptr{0, 1},
+ [1]uint8{},
+ [1]uint8{1},
+ [2]uint8{0, 1},
+ [1]uint16{},
+ [1]uint16{1},
+ [2]uint16{0, 1},
+ [1]uint32{},
+ [1]uint32{1},
+ [2]uint32{0, 1},
+ [1]uint64{},
+ [1]uint64{1},
+ [2]uint64{0, 1},
+ [1]string{},
+ [1]string{""},
+ [1]string{nonEmptyString},
+ [2]string{"", nonEmptyString},
+}
+
+func TestArrayPrimitives(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allArrayPrimitives))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allArrayPrimitives)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allArrayPrimitives)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allArrayPrimitives))))
+}
+
+func TestSlices(t *testing.T) {
+ var allSlices = flatten(
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ return v.Slice(0, v.Len()).Interface(), true
+ }),
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ if v.Len() == 0 {
+ // Return the pure "nil" value for the slice.
+ return reflect.New(v.Slice(0, 0).Type()).Elem().Interface(), true
+ }
+ return v.Slice(1, v.Len()).Interface(), true
+ }),
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ if v.Len() == 0 {
+ // Return an empty, non-nil slice.
+ return reflect.MakeSlice(v.Slice(0, 0).Type(), 0, 0).Interface(), true
+ }
+ return v.Slice(0, v.Len()-1).Interface(), true
+ }),
+ )
+ runTestCases(t, false, "plain", allSlices)
+ runTestCases(t, false, "pointers", pointersTo(allSlices))
+ runTestCases(t, false, "interfaces", interfacesTo(allSlices))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(allSlices)))
+}
+
+func TestArrayContainers(t *testing.T) {
+ var (
+ emptyArray [1]interface{}
+ fullArray [1]interface{}
+ )
+ fullArray[0] = &emptyArray
+ runTestCases(t, false, "", []interface{}{
+ arrayContainer{v: emptyArray},
+ arrayContainer{v: fullArray},
+ arrayPtrContainer{v: nil},
+ arrayPtrContainer{v: &emptyArray},
+ arrayPtrContainer{v: &fullArray},
+ })
+}
+
+func TestSliceContainers(t *testing.T) {
+ var (
+ nilSlice []interface{}
+ emptySlice = make([]interface{}, 0)
+ fullSlice = []interface{}{nil}
+ )
+ runTestCases(t, false, "", []interface{}{
+ sliceContainer{v: nilSlice},
+ sliceContainer{v: emptySlice},
+ sliceContainer{v: fullSlice},
+ slicePtrContainer{v: nil},
+ slicePtrContainer{v: &nilSlice},
+ slicePtrContainer{v: &emptySlice},
+ slicePtrContainer{v: &fullSlice},
+ })
+}
diff --git a/pkg/state/tests/bench.go b/pkg/state/tests/bench.go
new file mode 100644
index 000000000..40869cdfb
--- /dev/null
+++ b/pkg/state/tests/bench.go
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type benchStruct struct {
+ B *benchStruct // Must be exported for gob.
+}
+
+func (b *benchStruct) afterLoad() {
+ // Do nothing, just force scheduling.
+}
diff --git a/pkg/state/tests/bench_test.go b/pkg/state/tests/bench_test.go
new file mode 100644
index 000000000..7e102c907
--- /dev/null
+++ b/pkg/state/tests/bench_test.go
@@ -0,0 +1,153 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "bytes"
+ "context"
+ "encoding/gob"
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// buildPtrObject builds a benchmark object.
+func buildPtrObject(n int) interface{} {
+ b := new(benchStruct)
+ for i := 0; i < n; i++ {
+ b = &benchStruct{B: b}
+ }
+ return b
+}
+
+// buildMapObject builds a benchmark object.
+func buildMapObject(n int) interface{} {
+ b := new(benchStruct)
+ m := make(map[int]*benchStruct)
+ for i := 0; i < n; i++ {
+ m[i] = b
+ }
+ return &m
+}
+
+// buildSliceObject builds a benchmark object.
+func buildSliceObject(n int) interface{} {
+ b := new(benchStruct)
+ s := make([]*benchStruct, 0, n)
+ for i := 0; i < n; i++ {
+ s = append(s, b)
+ }
+ return &s
+}
+
+var allObjects = map[string]struct {
+ New func(int) interface{}
+}{
+ "ptr": {
+ New: buildPtrObject,
+ },
+ "map": {
+ New: buildMapObject,
+ },
+ "slice": {
+ New: buildSliceObject,
+ },
+}
+
+func buildObjects(n int, fn func(int) interface{}) (iters int, v interface{}) {
+ // maxSize is the maximum size of an individual object below. For an N
+ // larger than this, we start to return multiple objects.
+ const maxSize = 1024
+ if n <= maxSize {
+ return 1, fn(n)
+ }
+ iters = (n + maxSize - 1) / maxSize
+ return iters, fn(maxSize)
+}
+
+// gobSave is a version of save using gob (no stats available).
+func gobSave(_ context.Context, w wire.Writer, v interface{}) (_ state.Stats, err error) {
+ enc := gob.NewEncoder(w)
+ err = enc.Encode(v)
+ return
+}
+
+// gobLoad is a version of load using gob (no stats available).
+func gobLoad(_ context.Context, r wire.Reader, v interface{}) (_ state.Stats, err error) {
+ dec := gob.NewDecoder(r)
+ err = dec.Decode(v)
+ return
+}
+
+var allAlgos = map[string]struct {
+ Save func(context.Context, wire.Writer, interface{}) (state.Stats, error)
+ Load func(context.Context, wire.Reader, interface{}) (state.Stats, error)
+ MaxPtr int
+}{
+ "state": {
+ Save: state.Save,
+ Load: state.Load,
+ },
+ "gob": {
+ Save: gobSave,
+ Load: gobLoad,
+ },
+}
+
+func BenchmarkEncoding(b *testing.B) {
+ for objName, objInfo := range allObjects {
+ for algoName, algoInfo := range allAlgos {
+ b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
+ b.StopTimer()
+ n, v := buildObjects(b.N, objInfo.New)
+ b.ReportAllocs()
+ b.StartTimer()
+ for i := 0; i < n; i++ {
+ if _, err := algoInfo.Save(context.Background(), discard{}, v); err != nil {
+ b.Errorf("save failed: %v", err)
+ }
+ }
+ b.StopTimer()
+ })
+ }
+ }
+}
+
+func BenchmarkDecoding(b *testing.B) {
+ for objName, objInfo := range allObjects {
+ for algoName, algoInfo := range allAlgos {
+ b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
+ b.StopTimer()
+ n, v := buildObjects(b.N, objInfo.New)
+ buf := new(bytes.Buffer)
+ if _, err := algoInfo.Save(context.Background(), buf, v); err != nil {
+ b.Errorf("save failed: %v", err)
+ }
+ b.ReportAllocs()
+ b.StartTimer()
+ var r bytes.Reader
+ for i := 0; i < n; i++ {
+ r.Reset(buf.Bytes())
+ if _, err := algoInfo.Load(context.Background(), &r, v); err != nil {
+ b.Errorf("load failed: %v", err)
+ }
+ }
+ b.StopTimer()
+ })
+ }
+ }
+}
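Note that buildObjects bounds the size of any single benchmark object: with b.N = 4096 and maxSize = 1024, iters = (4096 + 1023) / 1024 = 4, so the loop serializes one 1024-node object four times rather than a single 4096-node object once, keeping setup memory roughly constant as b.N grows.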
diff --git a/pkg/state/tests/bool_test.go b/pkg/state/tests/bool_test.go
new file mode 100644
index 000000000..e17cfacf9
--- /dev/null
+++ b/pkg/state/tests/bool_test.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+var allBools = []bool{
+ true,
+ false,
+}
+
+func TestBool(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allBools))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allBools)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allBools)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allBools))))
+}
diff --git a/pkg/state/tests/float_test.go b/pkg/state/tests/float_test.go
new file mode 100644
index 000000000..3e89edd9c
--- /dev/null
+++ b/pkg/state/tests/float_test.go
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "math"
+ "testing"
+)
+
+var safeFloat32s = []float32{
+ float32(0.0),
+ float32(1.0),
+ float32(-1.0),
+ float32(math.Inf(1)),
+ float32(math.Inf(-1)),
+}
+
+var allFloat32s = append(safeFloat32s, float32(math.NaN()))
+
+var safeFloat64s = []float64{
+ float64(0.0),
+ float64(1.0),
+ float64(-1.0),
+ math.Inf(1),
+ math.Inf(-1),
+}
+
+var allFloat64s = append(safeFloat64s, math.NaN())
+
+func TestFloat(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(
+ allFloat32s,
+ allFloat64s,
+ ))
+ // See checkEqual for why NaNs are missing.
+ runTestCases(t, false, "pointers", pointersTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ )))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ )))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ ))))
+}
+
+const onlyDouble float64 = 1.0000000000000002
+
+func TestFloatTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingFloat32{save: onlyDouble},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingFloat32{save: 1.0},
+ })
+}
+
+var safeComplex64s = combine(safeFloat32s, safeFloat32s, func(i, j interface{}) interface{} {
+ return complex(i.(float32), j.(float32))
+})
+
+var allComplex64s = combine(allFloat32s, allFloat32s, func(i, j interface{}) interface{} {
+ return complex(i.(float32), j.(float32))
+})
+
+var safeComplex128s = combine(safeFloat64s, safeFloat64s, func(i, j interface{}) interface{} {
+ return complex(i.(float64), j.(float64))
+})
+
+var allComplex128s = combine(allFloat64s, allFloat64s, func(i, j interface{}) interface{} {
+ return complex(i.(float64), j.(float64))
+})
+
+func TestComplex(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(
+ allComplex64s,
+ allComplex128s,
+ ))
+ // See TestFloat; same issue.
+ runTestCases(t, false, "pointers", pointersTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ )))
+ runTestCases(t, false, "interfacse", interfacesTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ )))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ ))))
+}
+
+func TestComplexTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingComplex64{save: complex(onlyDouble, onlyDouble)},
+ truncatingComplex64{save: complex(1.0, onlyDouble)},
+ truncatingComplex64{save: complex(onlyDouble, 1.0)},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingComplex64{save: complex(1.0, 1.0)},
+ })
+}
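The onlyDouble constant above is chosen carefully: 1.0000000000000002 is the smallest float64 strictly greater than 1 (math.Nextafter(1, 2)), and it has no exact float32 representation, so narrowing it rounds back to exactly 1.0. A quick check of that premise:

    fmt.Println(onlyDouble == math.Nextafter(1, 2))  // true
    fmt.Println(float64(float32(onlyDouble)) == 1.0) // true: precision lost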
diff --git a/pkg/state/tests/integer.go b/pkg/state/tests/integer.go
new file mode 100644
index 000000000..ca403eed1
--- /dev/null
+++ b/pkg/state/tests/integer.go
@@ -0,0 +1,163 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+// +stateify type
+type truncatingUint8 struct {
+ save uint64
+ load uint8 `state:"nosave"`
+}
+
+func (t *truncatingUint8) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint8) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint8)(nil)
+
+// +stateify type
+type truncatingUint16 struct {
+ save uint64
+ load uint16 `state:"nosave"`
+}
+
+func (t *truncatingUint16) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint16) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint16)(nil)
+
+// +stateify type
+type truncatingUint32 struct {
+ save uint64
+ load uint32 `state:"nosave"`
+}
+
+func (t *truncatingUint32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint32)(nil)
+
+// +stateify type
+type truncatingInt8 struct {
+ save int64
+ load int8 `state:"nosave"`
+}
+
+func (t *truncatingInt8) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt8) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt8)(nil)
+
+// +stateify type
+type truncatingInt16 struct {
+ save int64
+ load int16 `state:"nosave"`
+}
+
+func (t *truncatingInt16) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt16) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt16)(nil)
+
+// +stateify type
+type truncatingInt32 struct {
+ save int64
+ load int32 `state:"nosave"`
+}
+
+func (t *truncatingInt32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt32)(nil)
+
+// +stateify type
+type truncatingFloat32 struct {
+ save float64
+ load float32 `state:"nosave"`
+}
+
+func (t *truncatingFloat32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingFloat32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = float64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingFloat32)(nil)
+
+// +stateify type
+type truncatingComplex64 struct {
+ save complex128
+ load complex64 `state:"nosave"`
+}
+
+func (t *truncatingComplex64) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingComplex64) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = complex128(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingComplex64)(nil)
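All eight truncating types follow one pattern: StateSave writes the wide field, and StateLoad decodes into the narrow nosave field before widening it back. A saved value that fits the narrow type therefore round-trips unchanged, while one that does not fit is expected to fail during decode; the truncation tests in integer_test.go and float_test.go rely on exactly this split.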
diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go
new file mode 100644
index 000000000..d3931c952
--- /dev/null
+++ b/pkg/state/tests/integer_test.go
@@ -0,0 +1,94 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "math"
+ "testing"
+)
+
+var (
+ allIntTs = []int{-1, 0, 1}
+ allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
+ allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
+ allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
+ allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
+ allUintTs = []uint{0, 1}
+ allUintptrs = []uintptr{0, 1, ^uintptr(0)}
+ allUint8s = []uint8{0, 1, math.MaxUint8}
+ allUint16s = []uint16{0, 1, math.MaxUint16}
+ allUint32s = []uint32{0, 1, math.MaxUint32}
+ allUint64s = []uint64{0, 1, math.MaxUint64}
+)
+
+var allInts = flatten(
+ allIntTs,
+ allInt8s,
+ allInt16s,
+ allInt32s,
+ allInt64s,
+)
+
+var allUints = flatten(
+ allUintTs,
+ allUintptrs,
+ allUint8s,
+ allUint16s,
+ allUint32s,
+ allUint64s,
+)
+
+func TestInt(t *testing.T) {
+ runTestCases(t, false, "plain", allInts)
+ runTestCases(t, false, "pointers", pointersTo(allInts))
+ runTestCases(t, false, "interfaces", interfacesTo(allInts))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allInts)))
+}
+
+func TestIntTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingInt8{save: math.MinInt8 - 1},
+ truncatingInt16{save: math.MinInt16 - 1},
+ truncatingInt32{save: math.MinInt32 - 1},
+ truncatingInt8{save: math.MaxInt8 + 1},
+ truncatingInt16{save: math.MaxInt16 + 1},
+ truncatingInt32{save: math.MaxInt32 + 1},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingInt8{save: 1},
+ truncatingInt16{save: 1},
+ truncatingInt32{save: 1},
+ })
+}
+
+func TestUint(t *testing.T) {
+ runTestCases(t, false, "plain", allUints)
+ runTestCases(t, false, "pointers", pointersTo(allUints))
+ runTestCases(t, false, "interfaces", interfacesTo(allUints))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allUints)))
+}
+
+func TestUintTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingUint8{save: math.MaxUint8 + 1},
+ truncatingUint16{save: math.MaxUint16 + 1},
+ truncatingUint32{save: math.MaxUint32 + 1},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingUint8{save: 1},
+ truncatingUint16{save: 1},
+ truncatingUint32{save: 1},
+ })
+}
diff --git a/pkg/state/tests/load.go b/pkg/state/tests/load.go
new file mode 100644
index 000000000..a8350c0f3
--- /dev/null
+++ b/pkg/state/tests/load.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type genericContainer struct {
+ v interface{}
+}
+
+// +stateify savable
+type afterLoadStruct struct {
+ v int `state:"nosave"`
+}
+
+func (a *afterLoadStruct) afterLoad() {
+ a.v++
+}
+
+// +stateify savable
+type valueLoadStruct struct {
+ v int `state:".(int64)"`
+}
+
+func (v *valueLoadStruct) saveV() int64 {
+ return int64(v.v) // Save as int64.
+}
+
+func (v *valueLoadStruct) loadV(value int64) {
+ v.v = int(value) // Load as int.
+}
+
+// +stateify savable
+type cycleStruct struct {
+ c *cycleStruct
+}
+
+// +stateify savable
+type badCycleStruct struct {
+ b *badCycleStruct `state:"wait"`
+}
+
+func (b *badCycleStruct) afterLoad() {
+ if b.b != b {
+ // This branch should never execute, since afterLoad requires
+ // that the object and all of its dependencies are complete. The
+ // cycle should instead cause a deadlock error during load.
+ panic("badCycleStruct.afterLoad called")
+ }
+}
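For orientation, the state:".(int64)" tag on valueLoadStruct.v asks stateify to route the field through the saveV/loadV pair. Roughly what the generated methods look like (an illustrative paraphrase, not verbatim stateify output):

    func (v *valueLoadStruct) StateSave(m state.Sink) {
        m.SaveValue(0, v.saveV()) // Slot 0 ("v"), stored as int64 on the wire.
    }

    func (v *valueLoadStruct) StateLoad(m state.Source) {
        m.LoadValue(0, new(int64), func(y interface{}) {
            v.loadV(y.(int64)) // Narrowed back to int once the value arrives.
        })
    }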
diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go
new file mode 100644
index 000000000..1e9794296
--- /dev/null
+++ b/pkg/state/tests/load_test.go
@@ -0,0 +1,70 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+func TestLoadHooks(t *testing.T) {
+ runTestCases(t, false, "load-hooks", []interface{}{
+ &afterLoadStruct{v: 1},
+ &valueLoadStruct{v: 1},
+ &genericContainer{v: &afterLoadStruct{v: 1}},
+ &genericContainer{v: &valueLoadStruct{v: 1}},
+ &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
+ &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
+ &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}},
+ &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}},
+ })
+}
+
+func TestCycles(t *testing.T) {
+ // cs is a single object cycle.
+ cs := cycleStruct{nil}
+ cs.c = &cs
+
+ // cs1 and cs2 are in a two object cycle.
+ cs1 := cycleStruct{nil}
+ cs2 := cycleStruct{nil}
+ cs1.c = &cs2
+ cs2.c = &cs1
+
+ runTestCases(t, false, "cycles", []interface{}{
+ cs,
+ cs1,
+ })
+}
+
+func TestDeadlock(t *testing.T) {
+ // bs is a single object cycle. This does not cause deadlock because an
+ // object cannot wait for itself.
+ bs := badCycleStruct{nil}
+ bs.b = &bs
+
+ runTestCases(t, false, "self", []interface{}{
+ &bs,
+ })
+
+ // bs1 and bs2 are in a deadlocking cycle.
+ bs1 := badCycleStruct{nil}
+ bs2 := badCycleStruct{nil}
+ bs1.b = &bs2
+ bs2.b = &bs1
+
+ runTestCases(t, true, "deadlock", []interface{}{
+ &bs1,
+ })
+}
diff --git a/pkg/state/tests/map.go b/pkg/state/tests/map.go
new file mode 100644
index 000000000..db4e548f1
--- /dev/null
+++ b/pkg/state/tests/map.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type mapContainer struct {
+ v map[int]interface{}
+}
+
+// +stateify savable
+type mapPtrContainer struct {
+ v *map[int]interface{}
+}
+
+// +stateify savable
+type registeredMapStruct struct{}
diff --git a/pkg/state/tests/map_test.go b/pkg/state/tests/map_test.go
new file mode 100644
index 000000000..92bf0fc01
--- /dev/null
+++ b/pkg/state/tests/map_test.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "reflect"
+ "testing"
+)
+
+var allMapPrimitives = []interface{}{
+ bool(true),
+ int(1),
+ int8(1),
+ int16(1),
+ int32(1),
+ int64(1),
+ uint(1),
+ uintptr(1),
+ uint8(1),
+ uint16(1),
+ uint32(1),
+ uint64(1),
+ string(""),
+ registeredMapStruct{},
+}
+
+var allMapKeys = flatten(allMapPrimitives, pointersTo(allMapPrimitives))
+
+var allMapValues = flatten(allMapPrimitives, pointersTo(allMapPrimitives), interfacesTo(allMapPrimitives))
+
+var emptyMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
+ m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
+ return m.Interface()
+})
+
+var fullMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
+ m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
+ m.SetMapIndex(reflect.Zero(reflect.TypeOf(v1)), reflect.Zero(reflect.TypeOf(v2)))
+ return m.Interface()
+})
+
+func TestMapAliasing(t *testing.T) {
+ v := make(map[int]int)
+ ptrToV := &v
+ aliases := []map[int]int{v, v}
+ runTestCases(t, false, "", []interface{}{ptrToV, aliases})
+}
+
+func TestMapsEmpty(t *testing.T) {
+ runTestCases(t, false, "plain", emptyMaps)
+ runTestCases(t, false, "pointers", pointersTo(emptyMaps))
+ runTestCases(t, false, "interfaces", interfacesTo(emptyMaps))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(emptyMaps)))
+}
+
+func TestMapsFull(t *testing.T) {
+ runTestCases(t, false, "plain", fullMaps)
+ runTestCases(t, false, "pointers", pointersTo(fullMaps))
+ runTestCases(t, false, "interfaces", interfacesTo(fullMaps))
+ runTestCases(t, false, "interfacesToPointer", interfacesTo(pointersTo(fullMaps)))
+}
+
+func TestMapContainers(t *testing.T) {
+ var (
+ nilMap map[int]interface{}
+ emptyMap = make(map[int]interface{})
+ fullMap = map[int]interface{}{0: nil}
+ )
+ runTestCases(t, false, "", []interface{}{
+ mapContainer{v: nilMap},
+ mapContainer{v: emptyMap},
+ mapContainer{v: fullMap},
+ mapPtrContainer{v: nil},
+ mapPtrContainer{v: &nilMap},
+ mapPtrContainer{v: &emptyMap},
+ mapPtrContainer{v: &fullMap},
+ })
+}
diff --git a/pkg/state/tests/register.go b/pkg/state/tests/register.go
new file mode 100644
index 000000000..074d86315
--- /dev/null
+++ b/pkg/state/tests/register.go
@@ -0,0 +1,21 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type alreadyRegisteredStruct struct{}
+
+// +stateify savable
+type alreadyRegisteredOther int
diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go
new file mode 100644
index 000000000..c829753cc
--- /dev/null
+++ b/pkg/state/tests/register_test.go
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+// faker calls itself whatever is in the name field.
+type faker struct {
+ Name string
+ Fields []string
+}
+
+func (f *faker) StateTypeName() string {
+ return f.Name
+}
+
+func (f *faker) StateFields() []string {
+ return f.Fields
+}
+
+// fakerWithSaverLoader has all it needs.
+type fakerWithSaverLoader struct {
+ faker
+}
+
+func (f *fakerWithSaverLoader) StateSave(m state.Sink) {}
+
+func (f *fakerWithSaverLoader) StateLoad(m state.Source) {}
+
+// fakerOther calls itself .. uh, itself?
+type fakerOther string
+
+func (f *fakerOther) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOther) StateFields() []string {
+ return nil
+}
+
+func newFakerOther(name string) *fakerOther {
+ f := fakerOther(name)
+ return &f
+}
+
+// fakerOtherBadFields returns non-nil fields.
+type fakerOtherBadFields string
+
+func (f *fakerOtherBadFields) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOtherBadFields) StateFields() []string {
+ return []string{string(*f)}
+}
+
+func newFakerOtherBadFields(name string) *fakerOtherBadFields {
+ f := fakerOtherBadFields(name)
+ return &f
+}
+
+// fakerOtherSaverLoader implements SaverLoader methods.
+type fakerOtherSaverLoader string
+
+func (f *fakerOtherSaverLoader) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOtherSaverLoader) StateFields() []string {
+ return nil
+}
+
+func (f *fakerOtherSaverLoader) StateSave(m state.Sink) {}
+
+func (f *fakerOtherSaverLoader) StateLoad(m state.Source) {}
+
+func newFakerOtherSaverLoader(name string) *fakerOtherSaverLoader {
+ f := fakerOtherSaverLoader(name)
+ return &f
+}
+
+func TestRegisterPrimitives(t *testing.T) {
+ for _, typeName := range []string{
+ "int",
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "uint",
+ "uintptr",
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ "float32",
+ "float64",
+ "complex64",
+ "complex128",
+ "string",
+ } {
+ t.Run("struct/"+typeName, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering type %q did not panic", typeName)
+ }
+ }()
+ state.Register(&faker{
+ Name: typeName,
+ })
+ })
+ t.Run("other/"+typeName, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering type %q did not panic", typeName)
+ }
+ }()
+ state.Register(newFakerOther(typeName))
+ })
+ }
+}
+
+func TestRegisterBad(t *testing.T) {
+ const (
+ goodName = "foo"
+ firstField = "a"
+ secondField = "b"
+ )
+ for name, object := range map[string]state.Type{
+ "non-struct-with-fields": newFakerOtherBadFields(goodName),
+ "non-struct-with-saverloader": newFakerOtherSaverLoader(goodName),
+ "struct-without-saverloader": &faker{Name: goodName},
+ "non-struct-duplicate-with-struct": newFakerOther((new(alreadyRegisteredStruct)).StateTypeName()),
+ "non-struct-duplicate-with-non-struct": newFakerOther((new(alreadyRegisteredOther)).StateTypeName()),
+ "struct-duplicate-with-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredStruct)).StateTypeName()}},
+ "struct-duplicate-with-non-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredOther)).StateTypeName()}},
+ "struct-with-empty-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{""}}},
+ "struct-with-empty-field-and-non-empty": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, ""}}},
+ "struct-with-duplicate-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, firstField}}},
+ "struct-with-duplicate-field-and-non-dup": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, secondField, firstField}}},
+ } {
+ t.Run(name, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering object %#v did not panic", object)
+ }
+ }()
+ state.Register(object)
+ })
+
+ }
+}
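Taken together, the table above pins down the registration contract: a struct must implement SaverLoader to be registered; a non-struct must report nil StateFields and must not implement SaverLoader; type names must be unique across all registrations; and a field list may contain neither empty nor duplicate names. Each violation panics at Register time rather than surfacing later as a save error.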
diff --git a/pkg/state/tests/string_test.go b/pkg/state/tests/string_test.go
new file mode 100644
index 000000000..44f5a562c
--- /dev/null
+++ b/pkg/state/tests/string_test.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+const nonEmptyString = "hello world"
+
+var allStrings = []string{
+ "",
+ nonEmptyString,
+ "\\0",
+}
+
+func TestString(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allStrings))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allStrings)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allStrings)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allStrings))))
+}
diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go
new file mode 100644
index 000000000..bd2c2b399
--- /dev/null
+++ b/pkg/state/tests/struct.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+type unregisteredEmptyStruct struct{}
+
+// typeOnlyEmptyStruct just implements the state.Type interface.
+type typeOnlyEmptyStruct struct{}
+
+func (*typeOnlyEmptyStruct) StateTypeName() string { return "registeredEmptyStruct" }
+
+func (*typeOnlyEmptyStruct) StateFields() []string { return nil }
+
+// +stateify savable
+type savableEmptyStruct struct{}
+
+// +stateify savable
+type emptyStructPointer struct {
+ nothing *struct{}
+}
+
+// +stateify savable
+type outerSame struct {
+ inner inner
+}
+
+// +stateify savable
+type outerFieldFirst struct {
+ inner inner
+ v int64
+}
+
+// +stateify savable
+type outerFieldSecond struct {
+ v int64
+ inner inner
+}
+
+// +stateify savable
+type outerArray struct {
+ inner [2]inner
+}
+
+// +stateify savable
+type inner struct {
+ v int64
+}
+
+// +stateify savable
+type system struct {
+ v1 interface{}
+ v2 interface{}
+}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
new file mode 100644
index 000000000..de9d17aa7
--- /dev/null
+++ b/pkg/state/tests/struct_test.go
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func TestEmptyStruct(t *testing.T) {
+ runTestCases(t, false, "plain", []interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ savableEmptyStruct{},
+ })
+ runTestCases(t, false, "pointers", pointersTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ savableEmptyStruct{},
+ }))
+ runTestCases(t, false, "interfaces-pass", interfacesTo([]interface{}{
+ // Only registered types can be dispatched via interfaces. All
+ // other types should fail, even if it is the empty struct.
+ savableEmptyStruct{},
+ }))
+ runTestCases(t, true, "interfaces-fail", interfacesTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ }))
+ runTestCases(t, false, "interfacesToPointers-pass", interfacesTo(pointersTo([]interface{}{
+ savableEmptyStruct{},
+ })))
+ runTestCases(t, true, "interfacesToPointers-fail", interfacesTo(pointersTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ })))
+
+ // Ensuring empty struct aliasing works.
+ es := emptyStructPointer{new(struct{})}
+ runTestCases(t, false, "empty-struct-pointers", []interface{}{
+ emptyStructPointer{},
+ es,
+ []emptyStructPointer{es, es}, // Same pointer.
+ })
+}
+
+func TestRegisterTypeOnlyStruct(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Register did not panic")
+ }
+ }()
+ state.Register((*typeOnlyEmptyStruct)(nil))
+}
+
+func TestEmbeddedPointers(t *testing.T) {
+ var (
+ ofs outerSame
+ of1 outerFieldFirst
+ of2 outerFieldSecond
+ oa outerArray
+ )
+
+ runTestCases(t, false, "embedded-pointers", []interface{}{
+ system{&ofs, &ofs.inner},
+ system{&ofs.inner, &ofs},
+ system{&of1, &of1.inner},
+ system{&of1.inner, &of1},
+ system{&of2, &of2.inner},
+ system{&of2.inner, &of2},
+ system{&oa, &oa.inner[0]},
+ system{&oa, &oa.inner[1]},
+ system{&oa.inner[0], &oa},
+ system{&oa.inner[1], &oa},
+ })
+}
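TestEmbeddedPointers is subtler than it looks: &ofs and &ofs.inner are distinct pointers to the same address (inner sits at offset zero), so the encoder must distinguish them by type and extent rather than by address alone, and must do so regardless of which pointer it encounters first; hence each system pair appears in both orders.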
diff --git a/pkg/state/tests/tests.go b/pkg/state/tests/tests.go
new file mode 100644
index 000000000..435a0e9db
--- /dev/null
+++ b/pkg/state/tests/tests.go
@@ -0,0 +1,215 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tests tests the state packages.
+package tests
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "math"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/pretty"
+)
+
+// discard is an implementation of wire.Writer.
+type discard struct{}
+
+// Write implements wire.Writer.Write.
+func (discard) Write(p []byte) (int, error) { return len(p), nil }
+
+// WriteByte implements wire.Writer.WriteByte.
+func (discard) WriteByte(byte) error { return nil }
+
+// checkEqual checks if two objects are equal.
+//
+// N.B. This only handles one level of dereferences for NaN. Otherwise we
+// would need to fork the entire implementation of reflect.DeepEqual.
+func checkEqual(root, loadedValue interface{}) bool {
+ if reflect.DeepEqual(root, loadedValue) {
+ return true
+ }
+
+ // NaN is not equal to itself. We handle the case of raw floating point
+ // primitives here, but don't handle this case nested.
+ rf32, ok1 := root.(float32)
+ lf32, ok2 := loadedValue.(float32)
+ if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) {
+ return true
+ }
+ rf64, ok1 := root.(float64)
+ lf64, ok2 := loadedValue.(float64)
+ if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) {
+ return true
+ }
+
+ // Same idea for complex numbers: compare the parts, either of which
+ // may be NaN.
+ rc64, ok1 := root.(complex64)
+ lc64, ok2 := loadedValue.(complex64)
+ if ok1 && ok2 {
+ return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64))
+ }
+ rc128, ok1 := root.(complex128)
+ lc128, ok2 := loadedValue.(complex128)
+ if ok1 && ok2 {
+ return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128))
+ }
+
+ return false
+}
+
+// runTestCases runs a test for each object in objects.
+func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) {
+ t.Helper()
+ for i, root := range objects {
+ t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) {
+ t.Logf("Original object:\n%#v", root)
+
+ // Save the passed object.
+ saveBuffer := &bytes.Buffer{}
+ saveObjectPtr := reflect.New(reflect.TypeOf(root))
+ saveObjectPtr.Elem().Set(reflect.ValueOf(root))
+ saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface())
+ if err != nil {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Save failed unexpectedly: %v", err)
+ }
+
+ // Dump the serialized state to aid with debugging.
+ var ppBuf bytes.Buffer
+ t.Logf("Raw state:\n%v", saveBuffer.Bytes())
+ if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+ // We don't count this as a test failure if we
+ // have shouldFail set, but we do count it as a
+ // failure if we were not expecting to fail.
+ if !shouldFail {
+ t.Errorf("PrettyPrint(html=false) failed unexpected: %v", err)
+ }
+ }
+ if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+ // See above.
+ if !shouldFail {
+ t.Errorf("PrettyPrint(html=true) failed unexpected: %v", err)
+ }
+ }
+ t.Logf("Encoded state:\n%s", ppBuf.String())
+ t.Logf("Save stats:\n%s", saveStats.String())
+
+ // Load a new copy of the object.
+ loadObjectPtr := reflect.New(reflect.TypeOf(root))
+ loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface())
+ if err != nil {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Load failed unexpectedly: %v", err)
+ }
+
+ // Compare the values.
+ loadedValue := loadObjectPtr.Elem().Interface()
+ if !checkEqual(root, loadedValue) {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded: %#v\n", root, loadedValue)
+ }
+
+ // Everything went okay. Is that good?
+ if shouldFail {
+ t.Fatalf("This test was expected to fail, but didn't.")
+ }
+ t.Logf("Load stats:\n%s", loadStats.String())
+
+ // Truncate half the bytes in the byte stream,
+ // and ensure that we can't restore. Then
+ // truncate only the final byte and ensure that
+ // we can't restore.
+ l := saveBuffer.Len()
+ halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2])
+ if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil {
+ t.Errorf("Load with half bytes succeeded unexpectedly.")
+ }
+ missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1])
+ if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil {
+ t.Errorf("Load with missing byte succeeded unexpectedly.")
+ }
+ })
+ }
+}
+
+// convert converts the slice to an []interface{}.
+func convert(v interface{}) (r []interface{}) {
+ s := reflect.ValueOf(v) // Must be slice.
+ for i := 0; i < s.Len(); i++ {
+ r = append(r, s.Index(i).Interface())
+ }
+ return r
+}
+
+// flatten flattens multiple slices.
+func flatten(vs ...interface{}) (r []interface{}) {
+ for _, v := range vs {
+ r = append(r, convert(v)...)
+ }
+ return r
+}
+
+// filter maps from one slice to another.
+func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) {
+ s := reflect.ValueOf(vs)
+ for i := 0; i < s.Len(); i++ {
+ v, ok := fn(s.Index(i).Interface())
+ if ok {
+ r = append(r, v)
+ }
+ }
+ return r
+}
+
+// combine combines objects in two slices as specified.
+func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) {
+ s1 := reflect.ValueOf(v1)
+ s2 := reflect.ValueOf(v2)
+ for i := 0; i < s1.Len(); i++ {
+ for j := 0; j < s2.Len(); j++ {
+ // Combine using the given function.
+ r = append(r, fn(s1.Index(i).Interface(), s2.Index(j).Interface()))
+ }
+ }
+ return r
+}
+
+// pointersTo is a filter function that returns pointers.
+func pointersTo(vs interface{}) []interface{} {
+ return filter(vs, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o))
+ v.Elem().Set(reflect.ValueOf(o))
+ return v.Interface(), true
+ })
+}
+
+// interfacesTo is a filter function that returns interface objects.
+func interfacesTo(vs interface{}) []interface{} {
+ return filter(vs, func(o interface{}) (interface{}, bool) {
+ var v [1]interface{}
+ v[0] = o
+ return v, true
+ })
+}
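The helpers at the bottom of this file compose into the pipelines used throughout the suite; a short trace of one such pipeline, with illustrative values:

    vals := flatten([]bool{true, false}) // []interface{}{true, false}
    ptrs := pointersTo(vals)             // two *bool values, boxed
    ifcs := interfacesTo(ptrs)           // each boxed again as a [1]interface{}
    runTestCases(t, false, "demo", ifcs) // one subtest per wrapped value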
diff --git a/pkg/state/types.go b/pkg/state/types.go
new file mode 100644
index 000000000..215ef80f8
--- /dev/null
+++ b/pkg/state/types.go
@@ -0,0 +1,361 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "sort"
+
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// assertValidType asserts that the type is valid.
+func assertValidType(name string, fields []string) {
+ if name == "" {
+ Failf("type has empty name")
+ }
+ fieldsCopy := make([]string, len(fields))
+ for i := 0; i < len(fields); i++ {
+ if fields[i] == "" {
+ Failf("field has empty name for type %q", name)
+ }
+ fieldsCopy[i] = fields[i]
+ }
+ sort.Slice(fieldsCopy, func(i, j int) bool {
+ return fieldsCopy[i] < fieldsCopy[j]
+ })
+ for i := range fieldsCopy {
+ if i > 0 && fieldsCopy[i-1] == fieldsCopy[i] {
+ Failf("duplicate field %q for type %s", fieldsCopy[i], name)
+ }
+ }
+}
+
+// typeEntry is an entry in the typeDatabase.
+type typeEntry struct {
+ ID typeID
+ wire.Type
+}
+
+// reconciledTypeEntry is a reconciled entry in the typeDatabase.
+type reconciledTypeEntry struct {
+ wire.Type
+ LocalType reflect.Type
+ FieldOrder []int
+}
+
+// typeEncodeDatabase is an internal TypeInfo database for encoding.
+type typeEncodeDatabase struct {
+ // byType maps by type to the typeEntry.
+ byType map[reflect.Type]*typeEntry
+
+ // lastID is the last used ID.
+ lastID typeID
+}
+
+// makeTypeEncodeDatabase makes a typeDatabase.
+func makeTypeEncodeDatabase() typeEncodeDatabase {
+ return typeEncodeDatabase{
+ byType: make(map[reflect.Type]*typeEntry),
+ }
+}
+
+// typeDecodeDatabase is an internal TypeInfo database for decoding.
+type typeDecodeDatabase struct {
+ // byID maps by ID to type.
+ byID []*reconciledTypeEntry
+
+ // pending are entries that are pending validation by Lookup. These
+ // will be reconciled with actual objects. Note that these will also be
+ // used to look up types by name, since they may not be reconciled and
+ // there's little value in deleting from this list.
+ pending []*wire.Type
+}
+
+// makeTypeDecodeDatabase makes a typeDatabase.
+func makeTypeDecodeDatabase() typeDecodeDatabase {
+ return typeDecodeDatabase{}
+}
+
+// lookupNameFields extracts the name and fields from an object.
+func lookupNameFields(typ reflect.Type) (string, []string, bool) {
+ v := reflect.Zero(reflect.PtrTo(typ)).Interface()
+ t, ok := v.(Type)
+ if !ok {
+ // Is this a primitive?
+ if typ.Kind() == reflect.Interface {
+ return interfaceType, nil, true
+ }
+ name := typ.Name()
+ if _, ok := primitiveTypeDatabase[name]; !ok {
+ // This is not a known type, and not a primitive. The
+ // encoder may proceed for anonymous empty structs, or
+ // it may dereference the type pointer and try again.
+ return "", nil, false
+ }
+ return name, nil, true
+ }
+ // Extract the name from the object.
+ name := t.StateTypeName()
+ fields := t.StateFields()
+ assertValidType(name, fields)
+ return name, fields, true
+}
+
+// Lookup looks up or registers the given object.
+//
+// The bool indicates whether this is an existing entry: false means the entry
+// did not exist, and true means the entry did exist. If this bool is false and
+// the returned typeEntry is nil, then the type does not implement the Type
+// interface.
+func (tdb *typeEncodeDatabase) Lookup(typ reflect.Type) (*typeEntry, bool) {
+ te, ok := tdb.byType[typ]
+ if !ok {
+ // Lookup the type information.
+ name, fields, ok := lookupNameFields(typ)
+ if !ok {
+ // Empty structs may still be encoded, so let the
+ // caller decide what to do from here.
+ return nil, false
+ }
+
+ // Register the new type.
+ tdb.lastID++
+ te = &typeEntry{
+ ID: tdb.lastID,
+ Type: wire.Type{
+ Name: name,
+ Fields: fields,
+ },
+ }
+
+ // All done.
+ tdb.byType[typ] = te
+ return te, false
+ }
+ return te, true
+}
+
+// Register adds a typeID entry.
+func (tbd *typeDecodeDatabase) Register(typ *wire.Type) {
+ assertValidType(typ.Name, typ.Fields)
+ tbd.pending = append(tbd.pending, typ)
+}
+
+// LookupName looks up the type name by ID.
+func (tbd *typeDecodeDatabase) LookupName(id typeID) string {
+ if len(tbd.pending) < int(id) {
+ // This is likely an encoder error?
+ Failf("type ID %d not available", id)
+ }
+ return tbd.pending[id-1].Name
+}
+
+// LookupType looks up the type by ID.
+func (tbd *typeDecodeDatabase) LookupType(id typeID) reflect.Type {
+ name := tbd.LookupName(id)
+ typ, ok := globalTypeDatabase[name]
+ if !ok {
+ // If not available, see if it's primitive.
+ typ, ok = primitiveTypeDatabase[name]
+ if !ok && name == interfaceType {
+ // Matches the built-in interface type.
+ var i interface{}
+ return reflect.TypeOf(&i).Elem()
+ }
+ if !ok {
+ // The type is perhaps not registered?
+ Failf("type name %q is not available", name)
+ }
+ return typ // Primitive type.
+ }
+ return typ // Registered type.
+}
+
+// singleFieldOrder defines the field order for a single field.
+var singleFieldOrder = []int{0}
+
+// Lookup looks up or registers the given object.
+//
+// First, the typeID is searched to see if this has already been appropriately
+// reconciled. If not, then a reconciliation will take place that may result
+// in a field ordering. If the type cannot be reconciled (e.g. it does not
+// support the Type interface, or its fields mismatch), the lookup fails
+// fatally via Failf.
+//
+// This method never returns nil.
+func (tbd *typeDecodeDatabase) Lookup(id typeID, typ reflect.Type) *reconciledTypeEntry {
+ if len(tbd.byID) > int(id) && tbd.byID[id-1] != nil {
+ // Already reconciled.
+ return tbd.byID[id-1]
+ }
+ // The ID has not been reconciled yet. That's fine. We need to make
+ // sure it aligns with the current provided object.
+ if len(tbd.pending) < int(id) {
+ // This id was never registered. Probably an encoder error?
+ Failf("typeDatabase does not contain id %d", id)
+ }
+ // Extract the pending info.
+ pending := tbd.pending[id-1]
+ // Grow the byID list.
+ if len(tbd.byID) < int(id) {
+ tbd.byID = append(tbd.byID, make([]*reconciledTypeEntry, int(id)-len(tbd.byID))...)
+ }
+ // Reconcile the type.
+ name, fields, ok := lookupNameFields(typ)
+ if !ok {
+ // Empty structs are decoded only when the type is nil. Since
+ // this isn't the case, we fail here.
+ Failf("unsupported type %q during decode; can't reconcile", pending.Name)
+ }
+ if name != pending.Name {
+ // Are these the same type? Print a helpful message as this may
+ // actually happen in practice if types change.
+ Failf("typeDatabase contains conflicting definitions for id %d: %s->%v (current) and %s->%v (existing)",
+ id, name, fields, pending.Name, pending.Fields)
+ }
+ rte := &reconciledTypeEntry{
+ Type: wire.Type{
+ Name: name,
+ Fields: fields,
+ },
+ LocalType: typ,
+ }
+ // If there are zero or one fields, then we skip allocating the field
+ // slice. There is special handling for decoding in this case. If the
+ // field name does not match, it will be caught in the general purpose
+ // code below.
+ if len(fields) != len(pending.Fields) {
+ Failf("type %q contains different fields: %v (decode) and %v (encode)",
+ name, fields, pending.Fields)
+ }
+ if len(fields) == 0 {
+ tbd.byID[id-1] = rte // Save.
+ return rte
+ }
+ if len(fields) == 1 && fields[0] == pending.Fields[0] {
+ tbd.byID[id-1] = rte // Save.
+ rte.FieldOrder = singleFieldOrder
+ return rte
+ }
+ // For each field in the current object's information, match it to a
+ // field in the destination object. We know from the assertion above
+ // and the assertion performed on insertion to pending that neither
+ // field list contains any duplicates.
+ fieldOrder := make([]int, len(fields))
+ for i, name := range fields {
+ fieldOrder[i] = -1 // Sentinel.
+ // Is it an exact match?
+ if pending.Fields[i] == name {
+ fieldOrder[i] = i
+ continue
+ }
+ // Find the matching field.
+ for j, otherName := range pending.Fields {
+ if name == otherName {
+ fieldOrder[i] = j
+ break
+ }
+ }
+ if fieldOrder[i] == -1 {
+ // The type name matches but we are lacking some common fields.
+ Failf("type %q has mismatched fields: %v (decode) and %v (encode)",
+ name, fields, pending.Fields)
+ }
+ }
+ // The type has been reconciled.
+ rte.FieldOrder = fieldOrder
+ tbd.byID[id-1] = rte
+ return rte
+}
+
+// interfaceType is the name used to represent all interface types.
+const interfaceType = "interface"
+
+// primitiveTypeDatabase is a set of fixed types.
+var primitiveTypeDatabase = func() map[string]reflect.Type {
+ r := make(map[string]reflect.Type)
+ for _, t := range []reflect.Type{
+ reflect.TypeOf(false),
+ reflect.TypeOf(int(0)),
+ reflect.TypeOf(int8(0)),
+ reflect.TypeOf(int16(0)),
+ reflect.TypeOf(int32(0)),
+ reflect.TypeOf(int64(0)),
+ reflect.TypeOf(uint(0)),
+ reflect.TypeOf(uintptr(0)),
+ reflect.TypeOf(uint8(0)),
+ reflect.TypeOf(uint16(0)),
+ reflect.TypeOf(uint32(0)),
+ reflect.TypeOf(uint64(0)),
+ reflect.TypeOf(""),
+ reflect.TypeOf(float32(0.0)),
+ reflect.TypeOf(float64(0.0)),
+ reflect.TypeOf(complex64(0.0)),
+ reflect.TypeOf(complex128(0.0)),
+ } {
+ r[t.Name()] = t
+ }
+ return r
+}()
+
+// globalTypeDatabase is used for dispatching interfaces on decode.
+var globalTypeDatabase = map[string]reflect.Type{}
+
+// Register registers a type.
+//
+// This must be called at init time, and only once per type.
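+//
+// For example (illustrative; the StateSave/StateLoad methods required of
+// structs are omitted here for brevity):
+//
+//	type counter struct{ n int }
+//
+//	func (c *counter) StateTypeName() string { return "example.counter" }
+//	func (c *counter) StateFields() []string { return []string{"n"} }
+//
+//	func init() { Register(&counter{}) }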
+func Register(t Type) {
+ name := t.StateTypeName()
+ fields := t.StateFields()
+ assertValidType(name, fields)
+ // Register must always be called on pointers.
+ typ := reflect.TypeOf(t)
+ if typ.Kind() != reflect.Ptr {
+ Failf("Register must be called on pointers")
+ }
+ typ = typ.Elem()
+ if typ.Kind() == reflect.Struct {
+ // All registered structs must implement SaverLoader. We allow
+ // the registration of non-struct types with just the Type
+ // interface, but we need to call StateSave/StateLoad methods
+ // on aggregate types.
+ if _, ok := t.(SaverLoader); !ok {
+ Failf("struct %T does not implement SaverLoader", t)
+ }
+ } else {
+ // Non-structs must not have any fields. We don't support
+ // calling StateSave/StateLoad methods on any non-struct types.
+ // If custom behavior is required, these types should be
+ // wrapped in a structure of some kind.
+ if len(fields) != 0 {
+ Failf("non-struct %T has non-zero fields %v", t, fields)
+ }
+ // We don't allow non-structs to implement StateSave/StateLoad
+ // methods, because they won't be called and it's confusing.
+ if _, ok := t.(SaverLoader); ok {
+ Failf("non-struct %T implements SaverLoader", t)
+ }
+ }
+ if _, ok := primitiveTypeDatabase[name]; ok {
+ Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t)
+ }
+ if _, ok := globalTypeDatabase[name]; ok {
+ Failf("conflicting globalTypeDatabase entries for %T: name conflict", t)
+ }
+ if name == interfaceType {
+ Failf("conflicting name for %T: matches interfaceType", t)
+ }
+ globalTypeDatabase[name] = typ
+}
diff --git a/pkg/state/wire/BUILD b/pkg/state/wire/BUILD
new file mode 100644
index 000000000..311b93dcb
--- /dev/null
+++ b/pkg/state/wire/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "wire",
+ srcs = ["wire.go"],
+ marshal = False,
+ stateify = False,
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/gohacks"],
+)
diff --git a/pkg/state/wire/wire.go b/pkg/state/wire/wire.go
new file mode 100644
index 000000000..93dee6740
--- /dev/null
+++ b/pkg/state/wire/wire.go
@@ -0,0 +1,970 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package wire contains a few basic types that can be composed to serialize
+// graph information for the state package. This package defines the wire
+// protocol.
+//
+// Note that these types are careful about how they implement the relevant
+// interfaces (either value receiver or pointer receiver), so that native-sized
+// types, such as integers and simple pointers, can fit inside the interface
+// object.
+//
+// This package also uses panic as control flow, so callers should be careful
+// to wrap calls in appropriate recover handlers.
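+//
+// A minimal round-trip (illustrative), with the panic-on-error behavior left
+// unhandled:
+//
+//	var buf bytes.Buffer
+//	Save(&buf, Uint(42))
+//	obj := Load(&buf) // obj.(Uint) == 42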
+//
+// Testing for this package is driven by the state test package.
+package wire
+
+import (
+ "fmt"
+ "io"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
+)
+
+// Reader is the required reader interface.
+type Reader interface {
+ io.Reader
+ ReadByte() (byte, error)
+}
+
+// Writer is the required writer interface.
+type Writer interface {
+ io.Writer
+ WriteByte(byte) error
+}
+
+// readFull is a utility that reads until len(p) bytes have been read. No
+// equivalent is needed for Write, since the io.Writer contract dictates that
+// Write must consume all bytes given or return an error.
+func readFull(r io.Reader, p []byte) {
+ for done := 0; done < len(p); {
+ n, err := r.Read(p[done:])
+ done += n
+ if n == 0 && err != nil {
+ panic(err)
+ }
+ }
+}
+
+// Object is a generic object.
+type Object interface {
+ // save saves the given object.
+ //
+ // Panic is used for error control flow.
+ save(Writer)
+
+ // load loads a new object of the given type.
+ //
+ // Panic is used for error control flow.
+ load(Reader) Object
+}
+
+// Bool is a boolean.
+type Bool bool
+
+// loadBool loads an object of type Bool.
+func loadBool(r Reader) Bool {
+ b := loadUint(r)
+ return Bool(b == 1)
+}
+
+// save implements Object.save.
+func (b Bool) save(w Writer) {
+ var v Uint
+ if b {
+ v = 1
+ } else {
+ v = 0
+ }
+ v.save(w)
+}
+
+// load implements Object.load.
+func (Bool) load(r Reader) Object { return loadBool(r) }
+
+// Int is a signed integer.
+//
+// This uses varint encoding.
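+//
+// Signed values are zig-zag transformed before the varint so that small
+// magnitudes stay small on the wire regardless of sign. For example
+// (illustrative): 0 encodes as 0, -1 as 1, 1 as 2, -2 as 3, and so on.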
+type Int int64
+
+// loadInt loads an object of type Int.
+func loadInt(r Reader) Int {
+ u := loadUint(r)
+ x := Int(u >> 1)
+ if u&1 != 0 {
+ x = ^x
+ }
+ return x
+}
+
+// save implements Object.save.
+func (i Int) save(w Writer) {
+ u := Uint(i) << 1
+ if i < 0 {
+ u = ^u
+ }
+ u.save(w)
+}
+
+// load implements Object.load.
+func (Int) load(r Reader) Object { return loadInt(r) }
+
+// Uint is an unsigned integer.
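+//
+// This is encoded as a base-128 varint: seven value bits per byte, with the
+// high bit set on every byte except the last. For example (illustrative),
+// 300 encodes as the two bytes 0xac, 0x02.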
+type Uint uint64
+
+// loadUint loads an object of type Uint.
+func loadUint(r Reader) Uint {
+ var (
+ u Uint
+ s uint
+ )
+ for i := 0; i <= 9; i++ {
+ b, err := r.ReadByte()
+ if err != nil {
+ panic(err)
+ }
+ if b < 0x80 {
+ if i == 9 && b > 1 {
+ panic("overflow")
+ }
+ u |= Uint(b) << s
+ return u
+ }
+ u |= Uint(b&0x7f) << s
+ s += 7
+ }
+ panic("unreachable")
+}
+
+// save implements Object.save.
+func (u Uint) save(w Writer) {
+ for u >= 0x80 {
+ if err := w.WriteByte(byte(u) | 0x80); err != nil {
+ panic(err)
+ }
+ u >>= 7
+ }
+ if err := w.WriteByte(byte(u)); err != nil {
+ panic(err)
+ }
+}
+
+// load implements Object.load.
+func (Uint) load(r Reader) Object { return loadUint(r) }
+
+// Float32 is a 32-bit floating point number.
+type Float32 float32
+
+// loadFloat32 loads an object of type Float32.
+func loadFloat32(r Reader) Float32 {
+ n := loadUint(r)
+ return Float32(math.Float32frombits(uint32(n)))
+}
+
+// save implements Object.save.
+func (f Float32) save(w Writer) {
+ n := Uint(math.Float32bits(float32(f)))
+ n.save(w)
+}
+
+// load implements Object.load.
+func (Float32) load(r Reader) Object { return loadFloat32(r) }
+
+// Float64 is a 64-bit floating point number.
+type Float64 float64
+
+// loadFloat64 loads an object of type Float64.
+func loadFloat64(r Reader) Float64 {
+ n := loadUint(r)
+ return Float64(math.Float64frombits(uint64(n)))
+}
+
+// save implements Object.save.
+func (f Float64) save(w Writer) {
+ n := Uint(math.Float64bits(float64(f)))
+ n.save(w)
+}
+
+// load implements Object.load.
+func (Float64) load(r Reader) Object { return loadFloat64(r) }
+
+// Complex64 is a 64-bit complex number.
+type Complex64 complex128
+
+// loadComplex64 loads an object of type Complex64.
+func loadComplex64(r Reader) Complex64 {
+ re := loadFloat32(r)
+ im := loadFloat32(r)
+ return Complex64(complex(float32(re), float32(im)))
+}
+
+// save implements Object.save.
+func (c *Complex64) save(w Writer) {
+ re := Float32(real(*c))
+ im := Float32(imag(*c))
+ re.save(w)
+ im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex64) load(r Reader) Object {
+ c := loadComplex64(r)
+ return &c
+}
+
+// Complex128 is a 128-bit complex number.
+type Complex128 complex128
+
+// loadComplex128 loads an object of type Complex128.
+func loadComplex128(r Reader) Complex128 {
+ re := loadFloat64(r)
+ im := loadFloat64(r)
+ return Complex128(complex(float64(re), float64(im)))
+}
+
+// save implements Object.save.
+func (c *Complex128) save(w Writer) {
+ re := Float64(real(*c))
+ im := Float64(imag(*c))
+ re.save(w)
+ im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex128) load(r Reader) Object {
+ c := loadComplex128(r)
+ return &c
+}
+
+// String is a string.
+type String string
+
+// loadString loads an object of type String.
+func loadString(r Reader) String {
+ l := loadUint(r)
+ p := make([]byte, l)
+ readFull(r, p)
+ return String(gohacks.StringFromImmutableBytes(p))
+}
+
+// save implements Object.save.
+func (s *String) save(w Writer) {
+ l := Uint(len(*s))
+ l.save(w)
+ p := gohacks.ImmutableBytesFromString(string(*s))
+ _, err := w.Write(p) // Must write all bytes.
+ if err != nil {
+ panic(err)
+ }
+}
+
+// load implements Object.load.
+func (*String) load(r Reader) Object {
+ s := loadString(r)
+ return &s
+}
+
+// Dot is a kind of reference: one of Index or FieldName.
+type Dot interface {
+ isDot()
+}
+
+// Index is a reference resolution.
+type Index uint32
+
+func (Index) isDot() {}
+
+// FieldName is a reference resolution.
+type FieldName string
+
+func (*FieldName) isDot() {}
+
+// Ref is a reference to an object.
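+//
+// For example (illustrative), a Dot of Index(2) is encoded as the signed
+// varint 2, while a Dot of FieldName("id") is encoded as the signed varint
+// -2 followed by the two raw name bytes; the sign disambiguates the two
+// cases on the wire.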
+type Ref struct {
+ // Root is the root object.
+ Root Uint
+
+ // Dots is the set of traversals required from the Root object above.
+ // Note that this will be stored in reverse order for efficiency.
+ Dots []Dot
+
+ // Type is the base type for the root object. This is non-nil iff Dots
+ // is non-zero length (that is, this is a complex reference). This is
+ // not *strictly* necessary, but can be used to simplify decoding.
+ Type TypeSpec
+}
+
+// loadRef loads an object of type Ref (abstract).
+func loadRef(r Reader) Ref {
+ ref := Ref{
+ Root: loadUint(r),
+ }
+ l := loadUint(r)
+ ref.Dots = make([]Dot, l)
+ for i := 0; i < int(l); i++ {
+ // Disambiguate between an Index (non-negative) and a field
+ // name (negative). This saves some space and avoids a
+ // dedicated loadDot function. See Ref.save for the other side.
+ d := loadInt(r)
+ if d >= 0 {
+ ref.Dots[i] = Index(d)
+ continue
+ }
+ p := make([]byte, -d)
+ readFull(r, p)
+ fieldName := FieldName(gohacks.StringFromImmutableBytes(p))
+ ref.Dots[i] = &fieldName
+ }
+ if l != 0 {
+ // Only if dots is non-zero.
+ ref.Type = loadTypeSpec(r)
+ }
+ return ref
+}
+
+// save implements Object.save.
+func (r *Ref) save(w Writer) {
+ r.Root.save(w)
+ l := Uint(len(r.Dots))
+ l.save(w)
+ for _, d := range r.Dots {
+ // See loadRef. We use non-negative numbers to encode Index
+ // objects and negative numbers to encode field lengths.
+ switch x := d.(type) {
+ case Index:
+ i := Int(x)
+ i.save(w)
+ case *FieldName:
+ d := Int(-len(*x))
+ d.save(w)
+ p := gohacks.ImmutableBytesFromString(string(*x))
+ if _, err := w.Write(p); err != nil {
+ panic(err)
+ }
+ default:
+ panic("unknown dot implementation")
+ }
+ }
+ if l != 0 {
+ // See above.
+ saveTypeSpec(w, r.Type)
+ }
+}
+
+// load implements Object.load.
+func (*Ref) load(r Reader) Object {
+ ref := loadRef(r)
+ return &ref
+}
+
+// Nil is a primitive zero value of any type.
+type Nil struct{}
+
+// loadNil loads an object of type Nil.
+func loadNil(r Reader) Nil {
+ return Nil{}
+}
+
+// save implements Object.save.
+func (Nil) save(w Writer) {}
+
+// load implements Object.load.
+func (Nil) load(r Reader) Object { return loadNil(r) }
+
+// Slice is a slice value.
+type Slice struct {
+ Length Uint
+ Capacity Uint
+ Ref Ref
+}
+
+// loadSlice loads an object of type Slice.
+func loadSlice(r Reader) Slice {
+ return Slice{
+ Length: loadUint(r),
+ Capacity: loadUint(r),
+ Ref: loadRef(r),
+ }
+}
+
+// save implements Object.save.
+func (s *Slice) save(w Writer) {
+ s.Length.save(w)
+ s.Capacity.save(w)
+ s.Ref.save(w)
+}
+
+// load implements Object.load.
+func (*Slice) load(r Reader) Object {
+ s := loadSlice(r)
+ return &s
+}
+
+// Array is an array value.
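+//
+// On the wire (illustrative), an Array of three Uints is encoded as the
+// length 3, a fully tagged first element (type header plus value), and the
+// remaining two values without type headers, since all elements share the
+// first element's type.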
+type Array struct {
+ Contents []Object
+}
+
+// loadArray loads an object of type Array.
+func loadArray(r Reader) Array {
+ l := loadUint(r)
+ if l == 0 {
+ // Note that there is no first object available whose type
+ // could be encoded, so we need this additional branch.
+ return Array{}
+ }
+ // All the objects here have the same type, so use dynamic dispatch
+ // only once. All other objects will automatically take the same type
+ // as the first object.
+ contents := make([]Object, l)
+ v := Load(r)
+ contents[0] = v
+ for i := 1; i < int(l); i++ {
+ contents[i] = v.load(r)
+ }
+ return Array{
+ Contents: contents,
+ }
+}
+
+// save implements Object.save.
+func (a *Array) save(w Writer) {
+ l := Uint(len(a.Contents))
+ l.save(w)
+ if l == 0 {
+ // See loadArray.
+ return
+ }
+ // See above.
+ Save(w, a.Contents[0])
+ for i := 1; i < int(l); i++ {
+ a.Contents[i].save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Array) load(r Reader) Object {
+ a := loadArray(r)
+ return &a
+}
+
+// Map is a map value.
+type Map struct {
+ Keys []Object
+ Values []Object
+}
+
+// loadMap loads an object of type Map.
+func loadMap(r Reader) Map {
+ l := loadUint(r)
+ if l == 0 {
+ // See loadArray.
+ return Map{}
+ }
+ // See type dispatch notes in Array.
+ keys := make([]Object, l)
+ values := make([]Object, l)
+ k := Load(r)
+ v := Load(r)
+ keys[0] = k
+ values[0] = v
+ for i := 1; i < int(l); i++ {
+ keys[i] = k.load(r)
+ values[i] = v.load(r)
+ }
+ return Map{
+ Keys: keys,
+ Values: values,
+ }
+}
+
+// save implements Object.save.
+func (m *Map) save(w Writer) {
+ l := Uint(len(m.Keys))
+ if int(l) != len(m.Values) {
+ panic(fmt.Sprintf("mismatched keys (%d) Aand values (%d)", len(m.Keys), len(m.Values)))
+ }
+ l.save(w)
+ if l == 0 {
+ // See loadArray.
+ return
+ }
+ // See above.
+ Save(w, m.Keys[0])
+ Save(w, m.Values[0])
+ for i := 1; i < int(l); i++ {
+ m.Keys[i].save(w)
+ m.Values[i].save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Map) load(r Reader) Object {
+ m := loadMap(r)
+ return &m
+}
+
+// TypeSpec is a type dereference.
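+//
+// TypeSpecs compose to describe compound types. For example (illustrative),
+// the Go type []*T, where T was assigned type ID 5, is described as
+// &TypeSpecSlice{Type: &TypeSpecPointer{Type: TypeID(5)}}.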
+type TypeSpec interface {
+ isTypeSpec()
+}
+
+// TypeID is a concrete type ID.
+type TypeID Uint
+
+func (TypeID) isTypeSpec() {}
+
+// TypeSpecPointer is a pointer type.
+type TypeSpecPointer struct {
+ Type TypeSpec
+}
+
+func (*TypeSpecPointer) isTypeSpec() {}
+
+// TypeSpecArray is an array type.
+type TypeSpecArray struct {
+ Count Uint
+ Type TypeSpec
+}
+
+func (*TypeSpecArray) isTypeSpec() {}
+
+// TypeSpecSlice is a slice type.
+type TypeSpecSlice struct {
+ Type TypeSpec
+}
+
+func (*TypeSpecSlice) isTypeSpec() {}
+
+// TypeSpecMap is a map type.
+type TypeSpecMap struct {
+ Key TypeSpec
+ Value TypeSpec
+}
+
+func (*TypeSpecMap) isTypeSpec() {}
+
+// TypeSpecNil is an empty type.
+type TypeSpecNil struct{}
+
+func (TypeSpecNil) isTypeSpec() {}
+
+// TypeSpec types.
+//
+// These use a distinct encoding on the wire, as they are used only in the
+// interface object. They are encoded and decoded through the dedicated
+// saveTypeSpec and loadTypeSpec functions.
+const (
+ typeSpecTypeID Uint = iota
+ typeSpecPointer
+ typeSpecArray
+ typeSpecSlice
+ typeSpecMap
+ typeSpecNil
+)
+
+// loadTypeSpec loads TypeSpec values.
+func loadTypeSpec(r Reader) TypeSpec {
+ switch hdr := loadUint(r); hdr {
+ case typeSpecTypeID:
+ return TypeID(loadUint(r))
+ case typeSpecPointer:
+ return &TypeSpecPointer{
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecArray:
+ return &TypeSpecArray{
+ Count: loadUint(r),
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecSlice:
+ return &TypeSpecSlice{
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecMap:
+ return &TypeSpecMap{
+ Key: loadTypeSpec(r),
+ Value: loadTypeSpec(r),
+ }
+ case typeSpecNil:
+ return TypeSpecNil{}
+ default:
+ // This is not a valid stream?
+ panic(fmt.Errorf("unknown header: %d", hdr))
+ }
+}
+
+// saveTypeSpec saves TypeSpec values.
+func saveTypeSpec(w Writer, t TypeSpec) {
+ switch x := t.(type) {
+ case TypeID:
+ typeSpecTypeID.save(w)
+ Uint(x).save(w)
+ case *TypeSpecPointer:
+ typeSpecPointer.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecArray:
+ typeSpecArray.save(w)
+ x.Count.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecSlice:
+ typeSpecSlice.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecMap:
+ typeSpecMap.save(w)
+ saveTypeSpec(w, x.Key)
+ saveTypeSpec(w, x.Value)
+ case TypeSpecNil:
+ typeSpecNil.save(w)
+ default:
+ // This should not happen?
+ panic(fmt.Errorf("unknown type %T", t))
+ }
+}
+
+// Interface is an interface value.
+type Interface struct {
+ Type TypeSpec
+ Value Object
+}
+
+// loadInterface loads an object of type Interface.
+func loadInterface(r Reader) Interface {
+ return Interface{
+ Type: loadTypeSpec(r),
+ Value: Load(r),
+ }
+}
+
+// save implements Object.save.
+func (i *Interface) save(w Writer) {
+ saveTypeSpec(w, i.Type)
+ Save(w, i.Value)
+}
+
+// load implements Object.load.
+func (*Interface) load(r Reader) Object {
+ i := loadInterface(r)
+ return &i
+}
+
+// Type is type information.
+type Type struct {
+ Name string
+ Fields []string
+}
+
+// loadType loads an object of type Type.
+func loadType(r Reader) Type {
+ name := string(loadString(r))
+ l := loadUint(r)
+ fields := make([]string, l)
+ for i := 0; i < int(l); i++ {
+ fields[i] = string(loadString(r))
+ }
+ return Type{
+ Name: name,
+ Fields: fields,
+ }
+}
+
+// save implements Object.save.
+func (t *Type) save(w Writer) {
+ s := String(t.Name)
+ s.save(w)
+ l := Uint(len(t.Fields))
+ l.save(w)
+ for i := 0; i < int(l); i++ {
+ s := String(t.Fields[i])
+ s.save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Type) load(r Reader) Object {
+ t := loadType(r)
+ return &t
+}
+
+// multipleObjects is a special type for serializing multiple objects.
+type multipleObjects []Object
+
+// loadMultipleObjects loads a series of objects.
+func loadMultipleObjects(r Reader) multipleObjects {
+ l := loadUint(r)
+ m := make(multipleObjects, l)
+ for i := 0; i < int(l); i++ {
+ m[i] = Load(r)
+ }
+ return m
+}
+
+// save implements Object.save.
+func (m *multipleObjects) save(w Writer) {
+ l := Uint(len(*m))
+ l.save(w)
+ for i := 0; i < int(l); i++ {
+ Save(w, (*m)[i])
+ }
+}
+
+// load implements Object.load.
+func (*multipleObjects) load(r Reader) Object {
+ m := loadMultipleObjects(r)
+ return &m
+}
+
+// noObjects represents no objects.
+type noObjects struct{}
+
+// loadNoObjects loads a sentinel.
+func loadNoObjects(r Reader) noObjects { return noObjects{} }
+
+// save implements Object.save.
+func (noObjects) save(w Writer) {}
+
+// load implements Object.load.
+func (noObjects) load(r Reader) Object { return loadNoObjects(r) }
+
+// Struct is a basic composite value.
+type Struct struct {
+ TypeID TypeID
+ fields Object // Optionally noObjects or *multipleObjects.
+}
+
+// Field returns a pointer to the given field slot.
+//
+// This must be called after Alloc.
+func (s *Struct) Field(i int) *Object {
+ if fields, ok := s.fields.(*multipleObjects); ok {
+ return &((*fields)[i])
+ }
+ if _, ok := s.fields.(noObjects); ok {
+ // Alloc(0) was called, so there are no field slots; Field
+ // must not be called in this case.
+ panic("Field called inappropriately, wrong Alloc?")
+ }
+ return &s.fields
+}
+
+// Alloc allocates the given number of fields.
+//
+// This must be called before Field and Save.
+//
+// Precondition: slots must be non-negative.
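+//
+// Typical usage (illustrative):
+//
+//	var s Struct
+//	s.TypeID = 1
+//	s.Alloc(2)
+//	*s.Field(0) = Uint(42)
+//	*s.Field(1) = Bool(true)
+//	Save(w, &s)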
+func (s *Struct) Alloc(slots int) {
+ switch {
+ case slots == 0:
+ s.fields = noObjects{}
+ case slots == 1:
+ // Leave it alone.
+ case slots > 1:
+ fields := make(multipleObjects, slots)
+ s.fields = &fields
+ default:
+ // Violates precondition.
+ panic(fmt.Sprintf("Alloc called with negative slots %d?", slots))
+ }
+}
+
+// Fields returns the number of fields.
+func (s *Struct) Fields() int {
+ switch x := s.fields.(type) {
+ case *multipleObjects:
+ return len(*x)
+ case noObjects:
+ return 0
+ default:
+ return 1
+ }
+}
+
+// loadStruct loads an object of type Struct.
+func loadStruct(r Reader) Struct {
+ return Struct{
+ TypeID: TypeID(loadUint(r)),
+ fields: Load(r),
+ }
+}
+
+// save implements Object.save.
+//
+// Precondition: Alloc must have been called, and the fields all filled in
+// appropriately. See Alloc and Field for more details.
+func (s *Struct) save(w Writer) {
+ Uint(s.TypeID).save(w)
+ Save(w, s.fields)
+}
+
+// load implements Object.load.
+func (*Struct) load(r Reader) Object {
+ s := loadStruct(r)
+ return &s
+}
+
+// Object types.
+//
+// N.B. Be careful about changing the order or introducing new elements in the
+// middle here. This is part of the wire format and shouldn't change.
+const (
+ typeBool Uint = iota
+ typeInt
+ typeUint
+ typeFloat32
+ typeFloat64
+ typeNil
+ typeRef
+ typeString
+ typeSlice
+ typeArray
+ typeMap
+ typeStruct
+ typeNoObjects
+ typeMultipleObjects
+ typeInterface
+ typeComplex64
+ typeComplex128
+ typeType
+)
+
+// Save saves the given object.
+//
+// +checkescape all
+//
+// N.B. This function will panic on error.
+func Save(w Writer, obj Object) {
+ switch x := obj.(type) {
+ case Bool:
+ typeBool.save(w)
+ x.save(w)
+ case Int:
+ typeInt.save(w)
+ x.save(w)
+ case Uint:
+ typeUint.save(w)
+ x.save(w)
+ case Float32:
+ typeFloat32.save(w)
+ x.save(w)
+ case Float64:
+ typeFloat64.save(w)
+ x.save(w)
+ case Nil:
+ typeNil.save(w)
+ x.save(w)
+ case *Ref:
+ typeRef.save(w)
+ x.save(w)
+ case *String:
+ typeString.save(w)
+ x.save(w)
+ case *Slice:
+ typeSlice.save(w)
+ x.save(w)
+ case *Array:
+ typeArray.save(w)
+ x.save(w)
+ case *Map:
+ typeMap.save(w)
+ x.save(w)
+ case *Struct:
+ typeStruct.save(w)
+ x.save(w)
+ case noObjects:
+ typeNoObjects.save(w)
+ x.save(w)
+ case *multipleObjects:
+ typeMultipleObjects.save(w)
+ x.save(w)
+ case *Interface:
+ typeInterface.save(w)
+ x.save(w)
+ case *Type:
+ typeType.save(w)
+ x.save(w)
+ case *Complex64:
+ typeComplex64.save(w)
+ x.save(w)
+ case *Complex128:
+ typeComplex128.save(w)
+ x.save(w)
+ default:
+ panic(fmt.Errorf("unknown type: %#v", obj))
+ }
+}
+
+// Load loads a new object.
+//
+// +checkescape all
+//
+// N.B. This function will panic on error.
+func Load(r Reader) Object {
+ switch hdr := loadUint(r); hdr {
+ case typeBool:
+ return loadBool(r)
+ case typeInt:
+ return loadInt(r)
+ case typeUint:
+ return loadUint(r)
+ case typeFloat32:
+ return loadFloat32(r)
+ case typeFloat64:
+ return loadFloat64(r)
+ case typeNil:
+ return loadNil(r)
+ case typeRef:
+ return ((*Ref)(nil)).load(r) // Escapes.
+ case typeString:
+ return ((*String)(nil)).load(r) // Escapes.
+ case typeSlice:
+ return ((*Slice)(nil)).load(r) // Escapes.
+ case typeArray:
+ return ((*Array)(nil)).load(r) // Escapes.
+ case typeMap:
+ return ((*Map)(nil)).load(r) // Escapes.
+ case typeStruct:
+ return ((*Struct)(nil)).load(r) // Escapes.
+ case typeNoObjects: // Special for struct.
+ return loadNoObjects(r)
+ case typeMultipleObjects: // Special for struct.
+ return ((*multipleObjects)(nil)).load(r) // Escapes.
+ case typeInterface:
+ return ((*Interface)(nil)).load(r) // Escapes.
+ case typeComplex64:
+ return ((*Complex64)(nil)).load(r) // Escapes.
+ case typeComplex128:
+ return ((*Complex128)(nil)).load(r) // Escapes.
+ case typeType:
+ return ((*Type)(nil)).load(r) // Escapes.
+ default:
+ // This is not a valid stream?
+ panic(fmt.Errorf("unknown header: %d", hdr))
+ }
+}
+
+// LoadUint loads a single unsigned integer.
+//
+// N.B. This function will panic on error.
+func LoadUint(r Reader) uint64 {
+ return uint64(loadUint(r))
+}
+
+// SaveUint saves a single unsigned integer.
+//
+// N.B. This function will panic on error.
+func SaveUint(w Writer, v uint64) {
+ Uint(v).save(w)
+}
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index d0d77e19c..4d47207f7 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -33,6 +33,7 @@ go_library(
"aliases.go",
"memmove_unsafe.go",
"mutex_unsafe.go",
+ "nocopy.go",
"norace_unsafe.go",
"race_unsafe.go",
"rwmutex_unsafe.go",
diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go
index ad4a3a37e..1d7780695 100644
--- a/pkg/sync/memmove_unsafe.go
+++ b/pkg/sync/memmove_unsafe.go
@@ -4,7 +4,7 @@
// license that can be found in the LICENSE file.
// +build go1.12
-// +build !go1.15
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go
index 3dd15578b..dc034d561 100644
--- a/pkg/sync/mutex_unsafe.go
+++ b/pkg/sync/mutex_unsafe.go
@@ -4,7 +4,7 @@
// license that can be found in the LICENSE file.
// +build go1.13
-// +build !go1.15
+// +build !go1.16
// When updating the build constraint (above), check that syncMutex matches the
// standard library sync.Mutex definition.
diff --git a/pkg/sync/nocopy.go b/pkg/sync/nocopy.go
new file mode 100644
index 000000000..722b29501
--- /dev/null
+++ b/pkg/sync/nocopy.go
@@ -0,0 +1,28 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sync
+
+// NoCopy may be embedded into structs which must not be copied
+// after the first use.
+//
+// See https://golang.org/issues/8005#issuecomment-190753527
+// for details.
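+//
+// Typical usage (illustrative):
+//
+//	type Registry struct {
+//		noCopy NoCopy
+//		m      map[string]int
+//	}
+//
+// `go vet -copylocks` will then report any copy of Registry by value.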
+type NoCopy struct{}
+
+// Lock is a no-op used by -copylocks checker from `go vet`.
+func (*NoCopy) Lock() {}
+
+// Unlock is a no-op used by -copylocks checker from `go vet`.
+func (*NoCopy) Unlock() {}
diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go
index ea6cdc447..995c0346e 100644
--- a/pkg/sync/rwmutex_unsafe.go
+++ b/pkg/sync/rwmutex_unsafe.go
@@ -4,7 +4,7 @@
// license that can be found in the LICENSE file.
// +build go1.13
-// +build !go1.15
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go
index 112e0e604..ad271e1a0 100644
--- a/pkg/syncevent/waiter_unsafe.go
+++ b/pkg/syncevent/waiter_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.11
-// +build !go1.15
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 8ff922c69..5ae10939d 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -22,7 +22,7 @@ import (
// Mapping for tcpip.Error types.
var (
ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL)
- ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.EINVAL)
+ ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV)
ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV)
ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT)
ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST)
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index f86db0999..798e07b01 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -61,6 +61,7 @@ var (
ENOMEM = error(syscall.ENOMEM)
ENOSPC = error(syscall.ENOSPC)
ENOSYS = error(syscall.ENOSYS)
+ ENOTCONN = error(syscall.ENOTCONN)
ENOTDIR = error(syscall.ENOTDIR)
ENOTEMPTY = error(syscall.ENOTEMPTY)
ENOTSOCK = error(syscall.ENOTSOCK)
@@ -72,6 +73,7 @@ var (
EPERM = error(syscall.EPERM)
EPIPE = error(syscall.EPIPE)
ERANGE = error(syscall.ERANGE)
+ EREMOTE = error(syscall.EREMOTE)
EROFS = error(syscall.EROFS)
ESPIPE = error(syscall.ESPIPE)
ESRCH = error(syscall.ESRCH)
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index e57d45f2a..a984f1712 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -22,7 +22,6 @@ go_test(
size = "small",
srcs = ["gonet_test.go"],
library = ":gonet",
- tags = ["flaky"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 6e0db2741..d82ed5205 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -335,6 +335,11 @@ func (c *TCPConn) Read(b []byte) (int, error) {
deadline := c.readCancel()
numRead := 0
+ defer func() {
+ if numRead != 0 {
+ c.ep.ModerateRecvBuf(numRead)
+ }
+ }()
for numRead != len(b) {
if len(c.read) == 0 {
var err error
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index c1745ba6a..ee264b726 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -320,6 +320,22 @@ func DstPort(port uint16) TransportChecker {
}
}
+// NoChecksum creates a checker that checks if the checksum is zero.
+func NoChecksum(noChecksum bool) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ udp, ok := h.(header.UDP)
+ if !ok {
+ return
+ }
+
+ if b := udp.Checksum() == 0; b != noChecksum {
+ t.Errorf("bad checksum state, got %t, want %t", b, noChecksum)
+ }
+ }
+}
+
// SeqNum creates a checker that checks the sequence number.
func SeqNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 0cde694dc..d87797617 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -48,7 +48,7 @@ go_test(
"//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
@@ -64,6 +64,6 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go
index 718a4720a..83189676e 100644
--- a/pkg/tcpip/header/arp.go
+++ b/pkg/tcpip/header/arp.go
@@ -14,14 +14,33 @@
package header
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
const (
// ARPProtocolNumber is the ARP network protocol number.
ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806
// ARPSize is the size of an IPv4-over-Ethernet ARP packet.
- ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4
+ ARPSize = 28
+)
+
+// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header.
+type ARPHardwareType uint16
+
+// Typical ARP HardwareType values. Some of the constants have to be specific
+// values as they are egressed on the wire in the HTYPE field of an ARP header.
+const (
+ ARPHardwareNone ARPHardwareType = 0
+ // ARPHardwareEther specifically is the HTYPE for Ethernet as specified
+ // in the IANA list here:
+ //
+ // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2
+ ARPHardwareEther ARPHardwareType = 1
+ ARPHardwareLoopback ARPHardwareType = 2
)
// ARPOp is an ARP opcode.
@@ -36,54 +55,64 @@ const (
// ARP is an ARP packet stored in a byte array as described in RFC 826.
type ARP []byte
-func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) }
-func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) }
-func (a ARP) hardwareAddressSize() int { return int(a[4]) }
-func (a ARP) protocolAddressSize() int { return int(a[5]) }
+const (
+ hTypeOffset = 0
+ protocolOffset = 2
+ haAddressSizeOffset = 4
+ protoAddressSizeOffset = 5
+ opCodeOffset = 6
+ senderHAAddressOffset = 8
+ senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize
+ targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize
+ targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize
+)
+
+func (a ARP) hardwareAddressType() ARPHardwareType {
+ return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:]))
+}
+
+func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) }
+func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) }
+func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) }
// Op is the ARP opcode.
-func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) }
+func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) }
// SetOp sets the ARP opcode.
func (a ARP) SetOp(op ARPOp) {
- a[6] = uint8(op >> 8)
- a[7] = uint8(op)
+ binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op))
}
// SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet.
func (a ARP) SetIPv4OverEthernet() {
- a[0], a[1] = 0, 1 // htypeEthernet
- a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber
- a[4] = 6 // macSize
- a[5] = uint8(IPv4AddressSize)
+ binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther))
+ binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber))
+ a[haAddressSizeOffset] = EthernetAddressSize
+ a[protoAddressSizeOffset] = uint8(IPv4AddressSize)
}
// HardwareAddressSender is the link address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressSender() []byte {
- const s = 8
- return a[s : s+6]
+ return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize]
}
// ProtocolAddressSender is the protocol address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressSender() []byte {
- const s = 8 + 6
- return a[s : s+4]
+ return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize]
}
// HardwareAddressTarget is the link address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressTarget() []byte {
- const s = 8 + 6 + 4
- return a[s : s+6]
+ return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize]
}
// ProtocolAddressTarget is the protocol address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressTarget() []byte {
- const s = 8 + 6 + 4 + 6
- return a[s : s+4]
+ return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize]
}
// IsValid reports whether this is an ARP packet for IPv4 over Ethernet.
@@ -91,10 +120,8 @@ func (a ARP) IsValid() bool {
if len(a) < ARPSize {
return false
}
- const htypeEthernet = 1
- const macSize = 6
- return a.hardwareAddressSpace() == htypeEthernet &&
+ return a.hardwareAddressType() == ARPHardwareEther &&
a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) &&
- a.hardwareAddressSize() == macSize &&
+ a.hardwareAddressSize() == EthernetAddressSize &&
a.protocolAddressSize() == IPv4AddressSize
}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index b1e92d2d7..eaface8cb 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -53,6 +53,10 @@ const (
// (all bits set to 0).
unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+ // EthernetBroadcastAddress is an ethernet address that addresses every node
+ // on a local link.
+ EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff")
+
// unicastMulticastFlagMask is the mask of the least significant bit in
// the first octet (in network byte order) of an ethernet address that
// determines whether the ethernet address is a unicast or multicast. If
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 7908c5744..1a631b31a 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -72,6 +72,7 @@ const (
// Values for ICMP code as defined in RFC 792.
const (
ICMPv4TTLExceeded = 0
+ ICMPv4HostUnreachable = 1
ICMPv4PortUnreachable = 3
ICMPv4FragmentationNeeded = 4
)
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index c7ee2de57..a13b4b809 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -110,9 +110,16 @@ const (
ICMPv6RedirectMsg ICMPv6Type = 137
)
-// Values for ICMP code as defined in RFC 4443.
+// Values for ICMP destination unreachable code as defined in RFC 4443 section
+// 3.1.
const (
- ICMPv6PortUnreachable = 4
+ ICMPv6NetworkUnreachable = 0
+ ICMPv6Prohibited = 1
+ ICMPv6BeyondScope = 2
+ ICMPv6AddressUnreachable = 3
+ ICMPv6PortUnreachable = 4
+ ICMPv6Policy = 5
+ ICMPv6RejectRoute = 6
)
// Type is the ICMP type field.
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 76839eb92..62ac932bb 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -159,6 +159,11 @@ func (b IPv4) Flags() uint8 {
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
}
+// More returns whether the more fragments flag is set.
+func (b IPv4) More() bool {
+ return b.Flags()&IPv4FlagMoreFragments != 0
+}
+
// TTL returns the "TTL" field of the ipv4 header.
func (b IPv4) TTL() uint8 {
return b[ttl]
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 2c4591409..3499d8399 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -354,6 +354,13 @@ func (b IPv6FragmentExtHdr) ID() uint32 {
return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:])
}
+// IsAtomic returns whether the fragment header indicates an atomic fragment. An
+// atomic fragment is a fragment that contains all the data required to
+// reassemble a full packet.
+func (b IPv6FragmentExtHdr) IsAtomic() bool {
+ return !b.More() && b.FragmentOffset() == 0
+}
+
// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload.
//
// The IPv6 payload may contain IPv6 extension headers before any upper layer
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index b8b93e78e..39ca774ef 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 9bf67686d..e12a5929b 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -181,13 +182,13 @@ func (e *Endpoint) NumQueued() int {
}
// InjectInbound injects an inbound packet.
-func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) {
- e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt)
+func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt)
}
// Attach saves the stack network-layer dispatcher for use later when packets
@@ -229,13 +230,13 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
// Clone r then release its resource so we only get the relevant fields from
// stack.Route without holding a reference to a NIC's endpoint.
route := r.Clone()
route.Release()
p := PacketInfo{
- Pkt: &pkt,
+ Pkt: pkt,
Proto: protocol,
GSO: gso,
Route: route,
@@ -296,3 +297,12 @@ func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle {
func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
e.q.RemoveNotify(handle)
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareNone
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index aa6db9aea..507b44abc 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -15,6 +15,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/binary",
+ "//pkg/iovec",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index affa1bbdf..c18bb91fb 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -45,6 +45,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -385,26 +386,35 @@ const (
_VIRTIO_NET_HDR_GSO_TCPV6 = 4
)
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if e.hdrSize > 0 {
// Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
// Preserve the src address if it's set in the route.
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
}
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if e.hdrSize > 0 {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
+
+ var builder iovec.Builder
fd := e.fds[pkt.Hash%uint32(len(e.fds))]
if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
@@ -430,47 +440,28 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
- return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView())
+ builder.Add(vnetHdrBuf)
}
- if pkt.Data.Size() == 0 {
- return rawfile.NonBlockingWrite(fd, pkt.Header.View())
- }
- if pkt.Header.UsedLength() == 0 {
- return rawfile.NonBlockingWrite(fd, pkt.Data.ToView())
+ builder.Add(pkt.Header.View())
+ for _, v := range pkt.Data.Views() {
+ builder.Add(v)
}
- return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil)
+ return rawfile.NonBlockingWriteIovec(fd, builder.Build())
}
func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) {
// Send a batch of packets through batchFD.
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
- var ethHdrBuf []byte
- iovLen := 0
if e.hdrSize > 0 {
- // Add ethernet header if needed.
- ethHdrBuf = make([]byte, header.EthernetMinimumSize)
- eth := header.Ethernet(ethHdrBuf)
- ethHdr := &header.EthernetFields{
- DstAddr: pkt.EgressRoute.RemoteLinkAddress,
- Type: pkt.NetworkProtocolNumber,
- }
-
- // Preserve the src address if it's set in the route.
- if pkt.EgressRoute.LocalLinkAddress != "" {
- ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress
- } else {
- ethHdr.SrcAddr = e.addr
- }
- eth.Encode(ethHdr)
- iovLen++
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
- vnetHdr := virtioNetHdr{}
var vnetHdrBuf []byte
if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ vnetHdr := virtioNetHdr{}
if pkt.GSOOptions != nil {
vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
if pkt.GSOOptions.NeedsCsum {
@@ -491,45 +482,19 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
}
}
vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
- iovLen++
}
- iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views()))
+ var builder iovec.Builder
+ builder.Add(vnetHdrBuf)
+ builder.Add(pkt.Header.View())
+ for _, v := range pkt.Data.Views() {
+ builder.Add(v)
+ }
+ iovecs := builder.Build()
+
var mmsgHdr rawfile.MMsgHdr
mmsgHdr.Msg.Iov = &iovecs[0]
- iovecIdx := 0
- if vnetHdrBuf != nil {
- v := &iovecs[iovecIdx]
- v.Base = &vnetHdrBuf[0]
- v.Len = uint64(len(vnetHdrBuf))
- iovecIdx++
- }
- if ethHdrBuf != nil {
- v := &iovecs[iovecIdx]
- v.Base = &ethHdrBuf[0]
- v.Len = uint64(len(ethHdrBuf))
- iovecIdx++
- }
- pktSize := uint64(0)
- // Encode L3 Header
- v := &iovecs[iovecIdx]
- hdr := &pkt.Header
- hdrView := hdr.View()
- v.Base = &hdrView[0]
- v.Len = uint64(len(hdrView))
- pktSize += v.Len
- iovecIdx++
-
- // Now encode the Transport Payload.
- pktViews := pkt.Data.Views()
- for i := range pktViews {
- vec := &iovecs[iovecIdx]
- iovecIdx++
- vec.Base = &pktViews[i][0]
- vec.Len = uint64(len(pktViews[i]))
- pktSize += vec.Len
- }
- mmsgHdr.Msg.Iovlen = uint64(iovecIdx)
+ mmsgHdr.Msg.Iovlen = uint64(len(iovecs))
mmsgHdrs = append(mmsgHdrs, mmsgHdr)
}
@@ -626,6 +591,14 @@ func (e *endpoint) GSOMaxSize() uint32 {
return e.gsoMaxSize
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.hdrSize > 0 {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
+
// InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes
// to the FD, but does not read from it. All reads come from injected packets.
type InjectableEndpoint struct {
@@ -641,8 +614,8 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
}
// InjectInbound injects an inbound packet.
-func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt)
+func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt)
}
// NewInjectable creates a new fd-based InjectableEndpoint.
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 3bfb15a8e..7b995b85a 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -45,7 +45,7 @@ const (
type packetInfo struct {
raddr tcpip.LinkAddress
proto tcpip.NetworkProtocolNumber
- contents stack.PacketBuffer
+ contents *stack.PacketBuffer
}
type context struct {
@@ -103,10 +103,14 @@ func (c *context) cleanup() {
}
}
-func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
c.ch <- packetInfo{remote, protocol, pkt}
}
+func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNoEthernetProperties(t *testing.T) {
c := newContext(t, &Options{MTU: mtu})
defer c.cleanup()
@@ -179,7 +183,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{
+ if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
Hash: hash,
@@ -295,7 +299,7 @@ func TestPreserveSrcAddress(t *testing.T) {
// WritePacket panics given a prependable with anything less than
// the minimum size of the ethernet header.
hdr := buffer.NewPrependable(header.EthernetMinimumSize)
- if err := c.ep.WritePacket(r, nil /* gso */, proto, stack.PacketBuffer{
+ if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{
Header: hdr,
Data: buffer.VectorisedView{},
}); err != nil {
@@ -358,7 +362,7 @@ func TestDeliverPacket(t *testing.T) {
want := packetInfo{
raddr: raddr,
proto: proto,
- contents: stack.PacketBuffer{
+ contents: &stack.PacketBuffer{
Data: buffer.View(b).ToVectorisedView(),
LinkHeader: buffer.View(hdr),
},
@@ -500,3 +504,80 @@ func TestRecvMMsgDispatcherCapLength(t *testing.T) {
}
}
+
+// fakeNetworkDispatcher delivers packets to pkts.
+type fakeNetworkDispatcher struct {
+ pkts []*stack.PacketBuffer
+}
+
+func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ d.pkts = append(d.pkts, pkt)
+}
+
+func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
+func TestDispatchPacketFormat(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ newDispatcher func(fd int, e *endpoint) (linkDispatcher, error)
+ }{
+ {
+ name: "readVDispatcher",
+ newDispatcher: newReadVDispatcher,
+ },
+ {
+ name: "recvMMsgDispatcher",
+ newDispatcher: newRecvMMsgDispatcher,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ // Create a socket pair to send/recv.
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer syscall.Close(fds[0])
+ defer syscall.Close(fds[1])
+
+ data := []byte{
+ // Ethernet header.
+ 1, 2, 3, 4, 5, 60,
+ 1, 2, 3, 4, 5, 61,
+ 8, 0,
+ // Mock network header.
+ 40, 41, 42, 43,
+ }
+ err = syscall.Sendmsg(fds[1], data, nil, nil, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create and run dispatcher once.
+ sink := &fakeNetworkDispatcher{}
+ d, err := test.newDispatcher(fds[0], &endpoint{
+ hdrSize: header.EthernetMinimumSize,
+ dispatcher: sink,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if ok, err := d.dispatch(); !ok || err != nil {
+ t.Fatalf("d.dispatch() = %v, %v", ok, err)
+ }
+
+ // Verify packet.
+ if got, want := len(sink.pkts), 1; got != want {
+ t.Fatalf("len(sink.pkts) = %d, want %d", got, want)
+ }
+ pkt := sink.pkts[0]
+ if got, want := len(pkt.LinkHeader), header.EthernetMinimumSize; got != want {
+ t.Errorf("len(pkt.LinkHeader) = %d, want %d", got, want)
+ }
+ if got, want := pkt.Data.Size(), 4; got != want {
+ t.Errorf("pkt.Data.Size() = %d, want %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index fe2bf3b0b..2dfd29aa9 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -191,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
}
pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, stack.PacketBuffer{
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{
Data: buffer.View(pkt).ToVectorisedView(),
LinkHeader: buffer.View(eth),
})
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index cb4cbea69..d8f2504b3 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -139,13 +139,13 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(n, BufConfig)
- pkt := stack.PacketBuffer{
+ pkt := &stack.PacketBuffer{
Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
LinkHeader: buffer.View(eth),
}
pkt.Data.TrimFront(d.e.hdrSize)
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt)
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
@@ -169,7 +169,7 @@ type recvMMsgDispatcher struct {
// iovecs is an array of array of iovec records where each iovec base
// pointer and length are initialzed to the corresponding view above,
- // except when GSO is neabled then the first iovec in each array of
+ // except when GSO is enabled then the first iovec in each array of
// iovecs points to a buffer for the vnet header which is stripped
// before the views are passed up the stack for further processing.
iovecs [][]syscall.Iovec
@@ -278,7 +278,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth = header.Ethernet(d.views[k][0])
+ eth = header.Ethernet(d.views[k][0][:header.EthernetMinimumSize])
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -296,12 +296,12 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(k, int(n), BufConfig)
- pkt := stack.PacketBuffer{
+ pkt := &stack.PacketBuffer{
Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
LinkHeader: buffer.View(eth),
}
pkt.Data.TrimFront(d.e.hdrSize)
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt)
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
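Two interface changes recur throughout this commit: DeliverNetworkPacket no longer receives the delivering LinkEndpoint as its first argument, and stack.PacketBuffer is passed by pointer so that later layers share rather than copy buffer state. A dispatcher implementation under the new signatures, as a minimal illustrative sketch (the type and its counters are hypothetical):

// countingDispatcher satisfies stack.NetworkDispatcher under the new
// signatures; a real dispatcher would hand packets to the network layer.
type countingDispatcher struct {
	inbound, outbound int
}

func (d *countingDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
	d.inbound++
}

func (d *countingDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
	d.outbound++
}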
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 073c84ef9..781cdd317 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,7 +76,7 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
@@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw
// Because we're immediately turning around and writing the packet back
// to the rx path, we intentionally don't preserve the remote and local
// link addresses from the stack.Route we're passed.
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, stack.PacketBuffer{
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
@@ -106,10 +106,18 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
}
linkHeader := header.Ethernet(hdr)
vv.TrimFront(len(linkHeader))
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{
Data: vv,
LinkHeader: buffer.View(linkHeader),
})
return nil
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareLoopback
+}
+
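+// AddHeader implements stack.LinkEndpoint.AddHeader.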
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
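ARPHardwareType is one of the two methods this commit adds to stack.LinkEndpoint. Loopback reports header.ARPHardwareLoopback; ethernet-backed endpoints elsewhere in the commit report header.ARPHardwareEther, and raw L3 endpoints report header.ARPHardwareNone. A hypothetical helper shows the kind of branching this enables:

// needsLinkResolution sketches one plausible use of the new method: only
// ethernet-like endpoints take part in ARP. The helper itself is assumed,
// not part of this commit.
func needsLinkResolution(ep stack.LinkEndpoint) bool {
	return ep.ARPHardwareType() == header.ARPHardwareEther
}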
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index 82b441b79..e7493e5c5 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index a5478ce17..56a611825 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -18,6 +18,7 @@ package muxed
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -80,8 +81,8 @@ func (m *InjectableEndpoint) IsAttached() bool {
}
// InjectInbound implements stack.InjectableLinkEndpoint.
-func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
- m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt)
+func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ m.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt)
}
// WritePackets writes outbound packets to the appropriate
@@ -98,7 +99,7 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts s
// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint
// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
// route registered in this endpoint.
-func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if endpoint, ok := m.routes[r.RemoteAddress]; ok {
return endpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -129,6 +130,15 @@ func (m *InjectableEndpoint) Wait() {
}
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("unsupported operation")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 87c734c1f..0744f66d6 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) {
hdr.Prepend(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
})
@@ -70,7 +70,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
hdr := buffer.NewPrependable(1)
hdr.Prepend(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buffer.NewView(0).ToVectorisedView(),
})
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
new file mode 100644
index 000000000..2cdb23475
--- /dev/null
+++ b/pkg/tcpip/link/nested/BUILD
@@ -0,0 +1,32 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "nested",
+ srcs = [
+ "nested.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "nested_test",
+ size = "small",
+ srcs = [
+ "nested_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
new file mode 100644
index 000000000..d40de54df
--- /dev/null
+++ b/pkg/tcpip/link/nested/nested.go
@@ -0,0 +1,152 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package nested provides helpers to implement the pattern of nested
+// stack.LinkEndpoints.
+package nested
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// Endpoint is a wrapper around stack.LinkEndpoint and stack.NetworkDispatcher
+// that can be used to implement nesting safely by providing lifecycle
+// concurrency guards.
+//
+// See the tests in this package for example usage.
+type Endpoint struct {
+ child stack.LinkEndpoint
+ embedder stack.NetworkDispatcher
+
+ // mu protects dispatcher.
+ mu sync.RWMutex
+ dispatcher stack.NetworkDispatcher
+}
+
+var _ stack.GSOEndpoint = (*Endpoint)(nil)
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+var _ stack.NetworkDispatcher = (*Endpoint)(nil)
+
+// Init initializes a nested.Endpoint that uses embedder as the dispatcher for
+// child on Attach.
+//
+// See the tests in this package for example usage.
+func (e *Endpoint) Init(child stack.LinkEndpoint, embedder stack.NetworkDispatcher) {
+ e.child = child
+ e.embedder = embedder
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverNetworkPacket(remote, local, protocol, pkt)
+ }
+}
+
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverOutboundPacket(remote, local, protocol, pkt)
+ }
+}
+
+// Attach implements stack.LinkEndpoint.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.mu.Lock()
+ e.dispatcher = dispatcher
+ e.mu.Unlock()
+ // If we're attaching to a valid dispatcher, pass embedder as the dispatcher
+ // to our child; otherwise, detach the child by giving it a nil dispatcher.
+ var pass stack.NetworkDispatcher
+ if dispatcher != nil {
+ pass = e.embedder
+ }
+ e.child.Attach(pass)
+}
+
+// IsAttached implements stack.LinkEndpoint.
+func (e *Endpoint) IsAttached() bool {
+ e.mu.RLock()
+ isAttached := e.dispatcher != nil
+ e.mu.RUnlock()
+ return isAttached
+}
+
+// MTU implements stack.LinkEndpoint.
+func (e *Endpoint) MTU() uint32 {
+ return e.child.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.child.Capabilities()
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+ return e.child.MaxHeaderLength()
+}
+
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.child.LinkAddress()
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ return e.child.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ return e.child.WritePackets(r, gso, pkts, protocol)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ return e.child.WriteRawPacket(vv)
+}
+
+// Wait implements stack.LinkEndpoint.
+func (e *Endpoint) Wait() {
+ e.child.Wait()
+}
+
+// GSOMaxSize implements stack.GSOEndpoint.
+func (e *Endpoint) GSOMaxSize() uint32 {
+ if e, ok := e.child.(stack.GSOEndpoint); ok {
+ return e.GSOMaxSize()
+ }
+ return 0
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.child.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.child.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go
new file mode 100644
index 000000000..7d9249c1c
--- /dev/null
+++ b/pkg/tcpip/link/nested/nested_test.go
@@ -0,0 +1,109 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nested_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type parentEndpoint struct {
+ nested.Endpoint
+}
+
+var _ stack.LinkEndpoint = (*parentEndpoint)(nil)
+var _ stack.NetworkDispatcher = (*parentEndpoint)(nil)
+
+type childEndpoint struct {
+ stack.LinkEndpoint
+ dispatcher stack.NetworkDispatcher
+}
+
+var _ stack.LinkEndpoint = (*childEndpoint)(nil)
+
+func (c *childEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ c.dispatcher = dispatcher
+}
+
+func (c *childEndpoint) IsAttached() bool {
+ return c.dispatcher != nil
+}
+
+type counterDispatcher struct {
+ count int
+}
+
+var _ stack.NetworkDispatcher = (*counterDispatcher)(nil)
+
+func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ d.count++
+}
+
+func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
+func TestNestedLinkEndpoint(t *testing.T) {
+ const emptyAddress = tcpip.LinkAddress("")
+
+ var (
+ childEP childEndpoint
+ nestedEP parentEndpoint
+ disp counterDispatcher
+ )
+ nestedEP.Endpoint.Init(&childEP, &nestedEP)
+
+ if childEP.IsAttached() {
+ t.Error("On init, childEP.IsAttached() = true, want = false")
+ }
+ if nestedEP.IsAttached() {
+ t.Error("On init, nestedEP.IsAttached() = true, want = false")
+ }
+
+ nestedEP.Attach(&disp)
+ if disp.count != 0 {
+ t.Fatalf("After attach, got disp.count = %d, want = 0", disp.count)
+ }
+ if !childEP.IsAttached() {
+ t.Error("After attach, childEP.IsAttached() = false, want = true")
+ }
+ if !nestedEP.IsAttached() {
+ t.Error("After attach, nestedEP.IsAttached() = false, want = true")
+ }
+
+ nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{})
+ if disp.count != 1 {
+ t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count)
+ }
+
+ nestedEP.Attach(nil)
+ if childEP.IsAttached() {
+ t.Error("After detach, childEP.IsAttached() = true, want = false")
+ }
+ if nestedEP.IsAttached() {
+ t.Error("After detach, nestedEP.IsAttached() = true, want = false")
+ }
+
+ disp.count = 0
+ nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{})
+ if disp.count != 0 {
+ t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count)
+ }
+
+}
diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD
new file mode 100644
index 000000000..6fff160ce
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "packetsocket",
+ srcs = ["endpoint.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
new file mode 100644
index 000000000..3922c2a04
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packetsocket provides a link layer endpoint that loops outbound
+// packets back to any AF_PACKET sockets that may be interested in the
+// outgoing packet.
+package packetsocket
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type endpoint struct {
+ nested.Endpoint
+}
+
+// New creates a new packetsocket LinkEndpoint.
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ e := &endpoint{}
+ e.Endpoint.Init(lower, e)
+ return e
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
+
+ return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
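The wrapper is meant to sit between the stack and the endpoint that owns the wire so that DeliverOutboundPacket observes every egress packet. A wiring sketch, assuming an existing *stack.Stack and lower endpoint (the helper and NIC ID are illustrative):

// attachWithPacketSockets wraps lower so outbound packets are also looped
// to AF_PACKET sockets, then registers the result as a NIC. Assumes imports
// of pkg/tcpip, pkg/tcpip/stack, and pkg/tcpip/link/packetsocket.
func attachWithPacketSockets(s *stack.Stack, nicID tcpip.NICID, lower stack.LinkEndpoint) *tcpip.Error {
	return s.CreateNIC(nicID, packetsocket.New(lower))
}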
diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD
index 054c213bc..1d0079bd6 100644
--- a/pkg/tcpip/link/qdisc/fifo/BUILD
+++ b/pkg/tcpip/link/qdisc/fifo/BUILD
@@ -14,6 +14,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 54432194d..467083239 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -102,8 +103,13 @@ func (q *queueDispatcher) dispatchLoop() {
}
// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
-func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt)
+func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
+}
+
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
}
// Attach implements stack.LinkEndpoint.Attach.
@@ -146,7 +152,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
}
// WritePacket implements stack.LinkEndpoint.WritePacket.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
// WritePacket callers do not set the following fields in PacketBuffer
// so we populate them here.
newRoute := r.Clone()
@@ -154,7 +160,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
- if !d.q.enqueue(&pkt) {
+ if !d.q.enqueue(pkt) {
return tcpip.ErrNoBufferSpace
}
d.newPacketWaker.Assert()
@@ -193,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3267/): Queue these packets as well once
+ // WriteRawPacket takes PacketBuffer instead of VectorisedView.
return e.lower.WriteRawPacket(vv)
}
@@ -207,3 +215,13 @@ func (e *endpoint) Wait() {
e.wg.Wait()
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.lower.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index 0b5a6cf49..99313ee25 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -14,7 +14,7 @@
// +build linux,amd64 linux,arm64
// +build go1.12
-// +build !go1.15
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 44e25d475..f4c32c2da 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -66,39 +66,14 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
return nil
}
-// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a
-// single syscall. It fails if partial data is written.
-func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
- // If the is no second buffer, issue a regular write.
- if len(b2) == 0 {
- return NonBlockingWrite(fd, b1)
- }
-
- // We have two buffers. Build the iovec that represents them and issue
- // a writev syscall.
- iovec := [3]syscall.Iovec{
- {
- Base: &b1[0],
- Len: uint64(len(b1)),
- },
- {
- Base: &b2[0],
- Len: uint64(len(b2)),
- },
- }
- iovecLen := uintptr(2)
-
- if len(b3) > 0 {
- iovecLen++
- iovec[2].Base = &b3[0]
- iovec[2].Len = uint64(len(b3))
- }
-
+// NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall.
+// It fails if partial data is written.
+func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error {
+ iovecLen := uintptr(len(iovec))
_, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
if e != 0 {
return TranslateErrno(e)
}
-
return nil
}
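Callers of the removed NonBlockingWrite3 are now expected to assemble the []syscall.Iovec themselves before calling NonBlockingWriteIovec. A sketch of an equivalent three-buffer write under that contract (the helper is hypothetical; empty buffers must be skipped because taking &b[0] of an empty slice would panic):

// writeThree rebuilds the removed three-buffer write on top of
// NonBlockingWriteIovec, emitting one iovec per non-empty buffer.
func writeThree(fd int, b1, b2, b3 []byte) *tcpip.Error {
	var iov []syscall.Iovec
	for _, b := range [][]byte{b1, b2, b3} {
		if len(b) > 0 {
			iov = append(iov, syscall.Iovec{Base: &b[0], Len: uint64(len(b))})
		}
	}
	if len(iov) == 0 {
		return nil // nothing to write
	}
	return rawfile.NonBlockingWriteIovec(fd, iov)
}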
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 0796d717e..507c76b76 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -183,22 +183,29 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
- // Add the ethernet header here.
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
v := pkt.Data.ToView()
// Transmit the packet.
@@ -275,7 +282,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
// Send packet up the stack.
eth := header.Ethernet(b[:header.EthernetMinimumSize])
- d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{
+ d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{
Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(),
LinkHeader: buffer.View(eth),
})
@@ -287,3 +294,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
e.completed.Done()
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 33f640b85..8f3cd9449 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
return c
}
-func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
c.mu.Lock()
c.packets = append(c.packets, packetInfo{
addr: remoteLinkAddr,
@@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr,
c.packetCh <- struct{}{}
}
+func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (c *testContext) cleanup() {
c.ep.Close()
closeFDs(&c.txCfg)
@@ -273,7 +277,7 @@ func TestSimpleSend(t *testing.T) {
randomFill(buf)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -345,7 +349,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
hdr := buffer.NewPrependable(header.EthernetMinimumSize)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{
Header: hdr,
}); err != nil {
t.Fatalf("WritePacket failed: %v", err)
@@ -401,7 +405,7 @@ func TestFillTxQueue(t *testing.T) {
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -419,7 +423,7 @@ func TestFillTxQueue(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != want {
@@ -447,7 +451,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Send two packets so that the id slice has at least two slots.
for i := 2; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -470,7 +474,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
ids := make(map[uint64]struct{})
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -488,7 +492,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != want {
@@ -514,7 +518,7 @@ func TestFillTxMemory(t *testing.T) {
ids := make(map[uint64]struct{})
for i := queueDataSize / bufferSize; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -533,7 +537,7 @@ func TestFillTxMemory(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
})
@@ -561,7 +565,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// until there is only one buffer left.
for i := queueDataSize/bufferSize - 1; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -577,7 +581,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
{
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
uu := buffer.NewView(bufferSize).ToVectorisedView()
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: uu,
}); err != want {
@@ -588,7 +592,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// Attempt to write the one-buffer packet again. It must succeed.
{
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
index 230a8d53a..7cbc305e7 100644
--- a/pkg/tcpip/link/sniffer/BUILD
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -14,6 +14,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/nested",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index da1c520ae..509076643 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -47,26 +48,22 @@ var LogPackets uint32 = 1
// LogPacketsToPCAP must be accessed atomically.
var LogPacketsToPCAP uint32 = 1
-var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.TransportProtocolNumber]int{
- header.ICMPv4ProtocolNumber: header.IPv4MinimumSize,
- header.ICMPv6ProtocolNumber: header.IPv6MinimumSize,
- header.UDPProtocolNumber: header.UDPMinimumSize,
- header.TCPProtocolNumber: header.TCPMinimumSize,
-}
-
type endpoint struct {
- dispatcher stack.NetworkDispatcher
- lower stack.LinkEndpoint
+ nested.Endpoint
writer io.Writer
maxPCAPLen uint32
}
+var _ stack.GSOEndpoint = (*endpoint)(nil)
+var _ stack.LinkEndpoint = (*endpoint)(nil)
+var _ stack.NetworkDispatcher = (*endpoint)(nil)
+
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets as they traverse the endpoint.
func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
- return &endpoint{
- lower: lower,
- }
+ sniffer := &endpoint{}
+ sniffer.Endpoint.Init(lower, sniffer)
+ return sniffer
}
func zoneOffset() (int32, error) {
@@ -110,62 +107,25 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (
if err := writePCAPHeader(writer, snapLen); err != nil {
return nil, err
}
- return &endpoint{
- lower: lower,
+ sniffer := &endpoint{
writer: writer,
maxPCAPLen: snapLen,
- }, nil
+ }
+ sniffer.Endpoint.Init(lower, sniffer)
+ return sniffer, nil
}
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
-func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
- e.dumpPacket("recv", nil, protocol, &pkt)
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt)
-}
-
-// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher
-// and registers with the lower endpoint as its dispatcher so that "e" is called
-// for inbound packets.
-func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.dispatcher = dispatcher
- e.lower.Attach(e)
-}
-
-// IsAttached implements stack.LinkEndpoint.IsAttached.
-func (e *endpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
-// lower endpoint.
-func (e *endpoint) MTU() uint32 {
- return e.lower.MTU()
+func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dumpPacket("recv", nil, protocol, pkt)
+ e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
-// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
-// request to the lower endpoint.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.lower.Capabilities()
-}
-
-// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards
-// the request to the lower endpoint.
-func (e *endpoint) MaxHeaderLength() uint16 {
- return e.lower.MaxHeaderLength()
-}
-
-func (e *endpoint) LinkAddress() tcpip.LinkAddress {
- return e.lower.LinkAddress()
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.lower.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
}
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
@@ -208,9 +168,9 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw
// WritePacket implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
- e.dumpPacket("send", gso, protocol, &pkt)
- return e.lower.WritePacket(r, gso, protocol, pkt)
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.dumpPacket("send", gso, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements the stack.LinkEndpoint interface. It is called by
@@ -220,7 +180,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.dumpPacket("send", gso, protocol, pkt)
}
- return e.lower.WritePackets(r, gso, pkts, protocol)
+ return e.Endpoint.WritePackets(r, gso, pkts, protocol)
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
@@ -228,12 +188,9 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
e.dumpPacket("send", nil, 0, &stack.PacketBuffer{
Data: vv,
})
- return e.lower.WriteRawPacket(vv)
+ return e.Endpoint.WriteRawPacket(vv)
}
-// Wait implements stack.LinkEndpoint.Wait.
-func (e *endpoint) Wait() { e.lower.Wait() }
-
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
@@ -287,7 +244,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
vv.TrimFront(header.ARPSize)
arp := header.ARP(hdr)
log.Infof(
- "%s arp %v (%v) -> %v (%v) valid:%v",
+ "%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
@@ -299,13 +256,6 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
return
}
- // We aren't guaranteed to have a transport header - it's possible for
- // writes via raw endpoints to contain only network headers.
- if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize {
- log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto)
- return
- }
-
// Figure out the transport layer info.
transName := "unknown"
srcPort := uint16(0)
@@ -346,7 +296,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
icmpType = "info reply"
}
}
- log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.ICMPv6ProtocolNumber:
@@ -381,7 +331,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6RedirectMsg:
icmpType = "redirect message"
}
- log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.UDPProtocolNumber:
@@ -428,7 +378,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
flagsStr[i] = ' '
}
}
- details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
+ details = fmt.Sprintf("flags:0x%02x (%s) seqnum: %d ack: %d win: %d xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
if flags&header.TCPFlagSyn != 0 {
details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
} else {
@@ -437,7 +387,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
}
default:
- log.Infof("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto)
+ log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto)
return
}
@@ -445,5 +395,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
details += fmt.Sprintf(" gso: %+v", gso)
}
- log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+ log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
}
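With the forwarding boilerplate delegated to nested.Endpoint, construction is unchanged from the caller's perspective. A usage sketch, assuming a lower endpoint and an io.Writer already opened for PCAP output (the helper and snap length are illustrative):

// newLoggingEndpoint wraps lower with a sniffer: plain log output when no
// PCAP writer is supplied, PCAP capture otherwise. NewWithWriter writes the
// PCAP file header up front, which is why it can fail.
func newLoggingEndpoint(lower stack.LinkEndpoint, pcap io.Writer) (stack.LinkEndpoint, error) {
	if pcap == nil {
		return sniffer.New(lower), nil
	}
	return sniffer.NewWithWriter(lower, pcap, 1500 /* snapLen */)
}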
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 617446ea2..04ae58e59 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -139,6 +139,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
stack: s,
nicID: id,
name: name,
+ isTap: prefix == "tap",
}
endpoint.Endpoint.LinkEPCapabilities = linkCaps
if endpoint.name == "" {
@@ -213,7 +214,7 @@ func (d *Device) Write(data []byte) (int64, error) {
remote = tcpip.LinkAddress(zeroMAC[:])
}
- pkt := stack.PacketBuffer{
+ pkt := &stack.PacketBuffer{
Data: buffer.View(data).ToVectorisedView(),
}
if ethHdr != nil {
@@ -271,21 +272,9 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
if d.hasFlags(linux.IFF_TAP) {
// Add ethernet header if not provided.
if info.Pkt.LinkHeader == nil {
- hdr := &header.EthernetFields{
- SrcAddr: info.Route.LocalLinkAddress,
- DstAddr: info.Route.RemoteLinkAddress,
- Type: info.Proto,
- }
- if hdr.SrcAddr == "" {
- hdr.SrcAddr = d.endpoint.LinkAddress()
- }
-
- eth := make(header.Ethernet, header.EthernetMinimumSize)
- eth.Encode(hdr)
- vv.AppendView(buffer.View(eth))
- } else {
- vv.AppendView(info.Pkt.LinkHeader)
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
}
+ vv.AppendView(info.Pkt.LinkHeader)
}
// Append upper headers.
@@ -348,6 +337,7 @@ type tunEndpoint struct {
stack *stack.Stack
nicID tcpip.NICID
name string
+ isTap bool
}
// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
@@ -356,3 +346,38 @@ func (e *tunEndpoint) DecRef() {
e.stack.RemoveNIC(e.nicID)
})
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.isTap {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.isTap {
+ return
+ }
+ eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
+ pkt.LinkHeader = buffer.View(eth)
+ hdr := &header.EthernetFields{
+ SrcAddr: local,
+ DstAddr: remote,
+ Type: protocol,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = e.LinkAddress()
+ }
+
+ eth.Encode(hdr)
+}
+
+// MaxHeaderLength returns the maximum size of the link layer header.
+func (e *tunEndpoint) MaxHeaderLength() uint16 {
+ if e.isTap {
+ return header.EthernetMinimumSize
+ }
+ return 0
+}
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 0956d2c65..ee84c3d96 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -12,6 +12,7 @@ go_library(
"//pkg/gate",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
@@ -25,6 +26,7 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 2b3741276..b152a0f26 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/gate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -50,12 +51,21 @@ func New(lower stack.LinkEndpoint) *Endpoint {
// It is called by the link-layer endpoint being wrapped when a packet arrives,
// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
// been called.
-func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if !e.dispatchGate.Enter() {
return
}
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt)
+ e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
+ e.dispatchGate.Leave()
+}
+
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
e.dispatchGate.Leave()
}
@@ -99,7 +109,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
@@ -147,3 +157,13 @@ func (e *Endpoint) WaitDispatch() {
// Wait implements stack.LinkEndpoint.Wait.
func (e *Endpoint) Wait() {}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.lower.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
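DeliverOutboundPacket follows the same gate discipline as the existing paths: each delivery enters the gate and leaves when done, and the Wait* methods close the gate and block until in-flight calls drain. The pattern in isolation, assuming pkg/gate's Enter/Leave/Close semantics:

// guardedDelivery sketches the gate discipline used above: Enter fails once
// the gate is closed, and Close blocks until every caller that entered has
// left. The type is illustrative, not part of this commit.
type guardedDelivery struct {
	g gate.Gate
}

func (x *guardedDelivery) deliver(work func()) {
	if !x.g.Enter() {
		return // gate closed; drop the delivery
	}
	defer x.g.Leave()
	work()
}

func (x *guardedDelivery) wait() {
	x.g.Close() // blocks until in-flight deliveries have left
}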
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 54eb5322b..c448a888f 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -35,10 +36,14 @@ type countedEndpoint struct {
dispatcher stack.NetworkDispatcher
}
-func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.dispatchCount++
}
+func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.attachCount++
e.dispatcher = dispatcher
@@ -65,7 +70,7 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
e.writeCount++
return nil
}
@@ -81,29 +86,39 @@ func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
return nil
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("unimplemented")
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
wep := New(ep)
// Write and check that it goes through.
- wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
if want := 1; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on dispatches, then try to write. It must go through.
wep.WaitDispatch()
- wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on writes, then try to write. It must not go through.
wep.WaitWrite()
- wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
@@ -120,21 +135,21 @@ func TestWaitDispatch(t *testing.T) {
}
// Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
if want := 1; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on writes, then try to dispatch. It must go through.
wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on dispatches, then try to dispatch. It must not go through.
wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 9d0797af7..31a242482 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -80,7 +80,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -94,16 +94,12 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- v, ok := pkt.Data.PullUp(header.ARPSize)
- if !ok {
- return
- }
- h := header.ARP(v)
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.ARP(pkt.NetworkHeader)
if !h.IsValid() {
return
}
@@ -122,7 +118,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
+ e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
})
fallthrough // also fill the cache from requests
@@ -164,9 +160,12 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
- RemoteLinkAddress: broadcastMAC,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
@@ -177,7 +176,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
copy(h.ProtocolAddressSender(), localAddr)
copy(h.ProtocolAddressTarget(), addr)
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
})
}
@@ -185,7 +184,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
- return broadcastMAC, true
+ return header.EthernetBroadcastAddress, true
}
if header.IsV4MulticastAddress(addr) {
return header.EthernetAddressFromMulticastIPv4Address(addr), true
@@ -209,7 +208,16 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+// Parse implements stack.NetworkProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.ARPSize)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(header.ARPSize)
+ return 0, false, true
+}
// NewProtocol returns an ARP network protocol.
func NewProtocol() stack.NetworkProtocol {
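HandlePacket can read the header straight from pkt.NetworkHeader because the new Parse hook runs first: Parse pulls up ARPSize bytes, records them as the network header, and trims them from Data. The ordering from the caller's side, as a hypothetical sketch (in the real stack the NIC drives parsing):

// handleARP illustrates the contract between Parse and HandlePacket: Parse
// populates pkt.NetworkHeader before HandlePacket runs. The driver function
// is assumed, not part of this commit.
func handleARP(proto stack.NetworkProtocol, ep stack.NetworkEndpoint, r *stack.Route, pkt *stack.PacketBuffer) {
	if _, _, ok := proto.Parse(pkt); !ok {
		return // shorter than header.ARPSize; drop
	}
	ep.HandlePacket(r, pkt) // reads header.ARP(pkt.NetworkHeader)
}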
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 1646d9cde..a35a64a0f 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -32,10 +32,14 @@ import (
)
const (
- stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
- stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
- stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+ stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
+ stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d")
+ stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
+ stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
+ stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
)
type testContext struct {
@@ -50,8 +54,7 @@ func newTestContext(t *testing.T) *testContext {
TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
})
- const defaultMTU = 65536
- ep := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
@@ -103,7 +106,7 @@ func TestDirectRequest(t *testing.T) {
inject := func(addr tcpip.Address) {
copy(h.ProtocolAddressTarget(), addr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
Data: v.ToVectorisedView(),
})
}
@@ -119,7 +122,7 @@ func TestDirectRequest(t *testing.T) {
if !rep.IsValid() {
t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength())
}
- if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
}
if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
@@ -144,3 +147,44 @@ func TestDirectRequest(t *testing.T) {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
}
}
+
+func TestLinkAddressRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: stackLinkAddr2,
+ expectLinkAddr: stackLinkAddr2,
+ },
+ {
+			name:           "Broadcast",
+ remoteLinkAddr: "",
+ expectLinkAddr: header.EthernetBroadcastAddress,
+ },
+ }
+
+ for _, test := range tests {
+ p := arp.NewProtocol()
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
+ if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index f42abc4bb..1827666c5 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -17,28 +17,58 @@
package fragmentation
import (
+ "errors"
"fmt"
"log"
"time"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
-// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
-const DefaultReassembleTimeout = 30 * time.Second
+const (
+	// DefaultReassembleTimeout is based on the Linux stack: net.ipv4.ipfrag_time.
+ DefaultReassembleTimeout = 30 * time.Second
-// HighFragThreshold is the threshold at which we start trimming old
-// fragmented packets. Linux uses a default value of 4 MB. See
-// net.ipv4.ipfrag_high_thresh for more information.
-const HighFragThreshold = 4 << 20 // 4MB
+ // HighFragThreshold is the threshold at which we start trimming old
+ // fragmented packets. Linux uses a default value of 4 MB. See
+ // net.ipv4.ipfrag_high_thresh for more information.
+ HighFragThreshold = 4 << 20 // 4MB
-// LowFragThreshold is the threshold we reach to when we start dropping
-// older fragmented packets. It's important that we keep enough room for newer
-// packets to be re-assembled. Hence, this needs to be lower than
-// HighFragThreshold enough. Linux uses a default value of 3 MB. See
-// net.ipv4.ipfrag_low_thresh for more information.
-const LowFragThreshold = 3 << 20 // 3MB
+	// LowFragThreshold is the threshold we drop down to when we start
+	// dropping older fragmented packets. It's important that we keep enough
+	// room for newer packets to be re-assembled. Hence, this needs to be
+	// sufficiently lower than HighFragThreshold. Linux uses a default value
+	// of 3 MB. See net.ipv4.ipfrag_low_thresh for more information.
+ LowFragThreshold = 3 << 20 // 3MB
+
+ // minBlockSize is the minimum block size for fragments.
+ minBlockSize = 1
+)
+
+var (
+	// ErrInvalidArgs indicates to the caller that an invalid argument was
+	// provided.
+ ErrInvalidArgs = errors.New("invalid args")
+)
+
+// FragmentID is the identifier for a fragment.
+type FragmentID struct {
+ // Source is the source address of the fragment.
+ Source tcpip.Address
+
+ // Destination is the destination address of the fragment.
+ Destination tcpip.Address
+
+ // ID is the identification value of the fragment.
+ //
+ // This is a uint32 because IPv6 uses a 32-bit identification value.
+ ID uint32
+
+ // The protocol for the packet.
+ Protocol uint8
+}
// Fragmentation is the main structure that other modules
// of the stack should use to implement IP Fragmentation.
@@ -46,14 +76,17 @@ type Fragmentation struct {
mu sync.Mutex
highLimit int
lowLimit int
- reassemblers map[uint32]*reassembler
+ reassemblers map[FragmentID]*reassembler
rList reassemblerList
size int
timeout time.Duration
+ blockSize uint16
}
// NewFragmentation creates a new Fragmentation.
//
+// blockSize specifies the fragment block size, in bytes.
+//
// highMemoryLimit specifies the limit on the memory consumed
// by the fragments stored by Fragmentation (overhead of internal data-structures
// is not accounted). Fragments are dropped when the limit is reached.
@@ -64,7 +97,7 @@ type Fragmentation struct {
// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new packet with an
// already existing fragmentation-id arrives after the timeout.
-func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
if lowMemoryLimit >= highMemoryLimit {
lowMemoryLimit = highMemoryLimit
}
@@ -73,17 +106,46 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t
lowMemoryLimit = 0
}
+ if blockSize < minBlockSize {
+ blockSize = minBlockSize
+ }
+
return &Fragmentation{
- reassemblers: make(map[uint32]*reassembler),
+ reassemblers: make(map[FragmentID]*reassembler),
highLimit: highMemoryLimit,
lowLimit: lowMemoryLimit,
timeout: reassemblingTimeout,
+ blockSize: blockSize,
}
}
-// Process processes an incoming fragment belonging to an ID
-// and returns a complete packet when all the packets belonging to that ID have been received.
-func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
+// Process processes an incoming fragment belonging to an ID and returns a
+// complete packet when all the packets belonging to that ID have been received.
+//
+// [first, last] is the range of the fragment bytes.
+//
+// first must be a multiple of the block size f is configured with. The size
+// of the fragment data must be a multiple of the block size, unless there are
+// no fragments following this fragment (more set to false).
+func (f *Fragmentation) Process(id FragmentID, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
+ if first > last {
+ return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
+ }
+
+ if first%f.blockSize != 0 {
+ return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
+ }
+
+ fragmentSize := last - first + 1
+ if more && fragmentSize%f.blockSize != 0 {
+ return buffer.VectorisedView{}, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
+ }
+
+ if l := vv.Size(); l < int(fragmentSize) {
+ return buffer.VectorisedView{}, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
+ }
+ vv.CapLength(int(fragmentSize))
+
f.mu.Lock()
r, ok := f.reassemblers[id]
if ok && r.tooOld(f.timeout) {
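
Process now rejects fragments that violate the block-size contract spelled out in its doc comment. The following self-contained sketch reproduces just those argument checks so the contract can be exercised in isolation; checkFragment and errInvalidArgs are illustrative stand-ins, not part of the fragmentation package:

package main

import (
	"errors"
	"fmt"
)

// errInvalidArgs mirrors fragmentation.ErrInvalidArgs in this sketch.
var errInvalidArgs = errors.New("invalid args")

// checkFragment reproduces the argument checks Process performs before
// accepting a fragment: [first, last] is the byte range of the fragment,
// more reports whether further fragments follow, and dataLen is the size
// of the fragment payload actually received.
func checkFragment(blockSize, first, last uint16, more bool, dataLen int) error {
	if first > last {
		return fmt.Errorf("first=%d is greater than last=%d: %w", first, last, errInvalidArgs)
	}
	if first%blockSize != 0 {
		return fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, blockSize, errInvalidArgs)
	}
	fragmentSize := last - first + 1
	if more && fragmentSize%blockSize != 0 {
		return fmt.Errorf("non-final fragment size=%d is not a multiple of block size=%d: %w", fragmentSize, blockSize, errInvalidArgs)
	}
	if dataLen < int(fragmentSize) {
		return fmt.Errorf("got %d bytes, expected at least %d: %w", dataLen, fragmentSize, errInvalidArgs)
	}
	return nil
}

func main() {
	// A non-final IPv4 fragment must cover whole 8-byte blocks.
	fmt.Println(checkFragment(8, 0, 63, true, 64)) // <nil>
	fmt.Println(checkFragment(8, 0, 62, true, 63)) // ... invalid args
}
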
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 72c0f53be..9eedd33c4 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -15,6 +15,7 @@
package fragmentation
import (
+ "errors"
"reflect"
"testing"
"time"
@@ -33,7 +34,7 @@ func vv(size int, pieces ...string) buffer.VectorisedView {
}
type processInput struct {
- id uint32
+ id FragmentID
first uint16
last uint16
more bool
@@ -53,8 +54,8 @@ var processTestCases = []struct {
{
comment: "One ID",
in: []processInput{
- {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -64,10 +65,10 @@ var processTestCases = []struct {
{
comment: "Two IDs",
in: []processInput{
- {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
- {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
- {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")},
+ {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -81,7 +82,7 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout)
for i, in := range c.in {
vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
if err != nil {
@@ -110,14 +111,14 @@ func TestFragmentationProcess(t *testing.T) {
func TestReassemblingTimeout(t *testing.T) {
timeout := time.Millisecond
- f := NewFragmentation(1024, 512, timeout)
+ f := NewFragmentation(minBlockSize, 1024, 512, timeout)
// Send first fragment with id = 0, first = 0, last = 0, and more = true.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, vv(1, "0"))
// Sleep more than the timeout.
time.Sleep(2 * timeout)
// Send another fragment that completes a packet.
// However, no packet should be reassembled because the fragment arrived after the timeout.
- _, done, err := f.Process(0, 1, 1, false, vv(1, "1"))
+ _, done, err := f.Process(FragmentID{}, 1, 1, false, vv(1, "1"))
if err != nil {
t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
}
@@ -127,35 +128,35 @@ func TestReassemblingTimeout(t *testing.T) {
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(3, 1, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout)
// Send first fragment with id = 0.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{ID: 0}, 0, 0, true, vv(1, "0"))
// Send first fragment with id = 1.
- f.Process(1, 0, 0, true, vv(1, "1"))
+ f.Process(FragmentID{ID: 1}, 0, 0, true, vv(1, "1"))
// Send first fragment with id = 2.
- f.Process(2, 0, 0, true, vv(1, "2"))
+ f.Process(FragmentID{ID: 2}, 0, 0, true, vv(1, "2"))
// Send first fragment with id = 3. This should cause id = 0 and id = 1 to be
// evicted.
- f.Process(3, 0, 0, true, vv(1, "3"))
+ f.Process(FragmentID{ID: 3}, 0, 0, true, vv(1, "3"))
- if _, ok := f.reassemblers[0]; ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
}
- if _, ok := f.reassemblers[1]; ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok {
t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
}
- if _, ok := f.reassemblers[3]; !ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok {
t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
}
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(1, 0, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout)
// Send first fragment with id = 0.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, vv(1, "0"))
// Send the same packet again.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, vv(1, "0"))
got := f.size
want := 1
@@ -163,3 +164,97 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
}
}
+
+func TestErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ blockSize uint16
+ first uint16
+ last uint16
+ more bool
+ data string
+ err error
+ }{
+ {
+ name: "exact block size without more",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: false,
+ data: "01",
+ },
+ {
+ name: "exact block size with more",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "01",
+ },
+ {
+ name: "exact block size with more and extra data",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "012",
+ },
+ {
+ name: "exact block size with more and too little data",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "0",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "not exact block size with more",
+ blockSize: 2,
+ first: 2,
+ last: 2,
+ more: true,
+ data: "0",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "not exact block size without more",
+ blockSize: 2,
+ first: 2,
+ last: 2,
+ more: false,
+ data: "0",
+ },
+ {
+ name: "first not a multiple of block size",
+ blockSize: 2,
+ first: 3,
+ last: 4,
+ more: true,
+ data: "01",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "first more than last",
+ blockSize: 2,
+ first: 4,
+ last: 3,
+ more: true,
+ data: "01",
+ err: ErrInvalidArgs,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout)
+ _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, vv(len(test.data), test.data))
+ if !errors.Is(err, test.err) {
+			t.Errorf("got Process(_, %d, %d, %t, %q) = (_, _, %v), want = (_, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
+ }
+ if done {
+			t.Errorf("got Process(_, %d, %d, %t, %q) = (_, true, _), want = (_, false, _)", test.first, test.last, test.more, test.data)
+ }
+ })
+ }
+}
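
TestErrors above leans on Go 1.13 error wrapping: Process builds its errors with fmt.Errorf and the %w verb, so errors.Is can match the ErrInvalidArgs sentinel through the added context. A minimal illustration of the idiom:

package main

import (
	"errors"
	"fmt"
)

var errInvalidArgs = errors.New("invalid args")

func validate(first, last int) error {
	if first > last {
		// %w wraps the sentinel so callers can still match it.
		return fmt.Errorf("first=%d is greater than last=%d: %w", first, last, errInvalidArgs)
	}
	return nil
}

func main() {
	err := validate(4, 3)
	fmt.Println(errors.Is(err, errInvalidArgs)) // true, despite the extra context
}
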
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 0a83d81f2..50d30bbf0 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -32,7 +32,7 @@ type hole struct {
type reassembler struct {
reassemblerEntry
- id uint32
+ id FragmentID
size int
mu sync.Mutex
holes []hole
@@ -42,7 +42,7 @@ type reassembler struct {
creationTime time.Time
}
-func newReassembler(id uint32) *reassembler {
+func newReassembler(id FragmentID) *reassembler {
r := &reassembler{
id: id,
holes: make([]hole, 0, 16),
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index 7eee0710d..dff7c9dcb 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -94,7 +94,7 @@ var holesTestCases = []struct {
func TestUpdateHoles(t *testing.T) {
for _, c := range holesTestCases {
- r := newReassembler(0)
+ r := newReassembler(FragmentID{})
for _, i := range c.in {
r.updateHoles(i.first, i.last, i.more)
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4c20301c6..615bae648 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -96,7 +96,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) {
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) {
t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
}
@@ -104,7 +104,7 @@ func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.Trans
// DeliverTransportControlPacket is called by network endpoints after parsing
// incoming control (ICMP) packets. This is used by the test object to verify
// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
t.checkValues(trans, pkt.Data, remote, local)
if typ != t.typ {
t.t.Errorf("typ = %v, want %v", typ, t.typ)
@@ -150,7 +150,7 @@ func (*testObject) Wait() {}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
@@ -172,14 +172,24 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
-func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testObject) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("not implemented")
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
@@ -246,7 +256,11 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: 123,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
@@ -289,9 +303,9 @@ func TestIPv4Receive(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: view.ToVectorisedView(),
- })
+ pkt := stack.PacketBuffer{Data: view.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -378,10 +392,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
o.typ = c.expectedTyp
o.extra = c.expectedExtra
- vv := view[:len(view)-c.trunc].ToVectorisedView()
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: vv,
- })
+ ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
@@ -444,17 +455,17 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Send first segment.
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: frag1.ToVectorisedView(),
- })
+ pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
if o.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
}
// Send second segment.
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: frag2.ToVectorisedView(),
- })
+ pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -487,7 +498,11 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: 123,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
@@ -530,9 +545,9 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: view.ToVectorisedView(),
- })
+ pkt := stack.PacketBuffer{Data: view.ToVectorisedView()}
+ proto.Parse(&pkt)
+ ep.HandlePacket(&r, &pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -644,12 +659,25 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: view[:len(view)-c.trunc].ToVectorisedView(),
- })
+ ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
})
}
}
+
+// truncatedPacket returns a PacketBuffer based on a truncated view. If view,
+// after truncation, is large enough to hold a network header, it makes part of
+// view the packet's NetworkHeader and the rest its Data. Otherwise all of view
+// becomes Data.
+func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer {
+ v := view[:len(view)-trunc]
+ if len(v) < netHdrLen {
+ return &stack.PacketBuffer{Data: v.ToVectorisedView()}
+ }
+ return &stack.PacketBuffer{
+ NetworkHeader: v[:netHdrLen],
+ Data: v[netHdrLen:].ToVectorisedView(),
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 78420d6e6..d142b4ffa 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -34,6 +34,6 @@ go_test(
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 4cbefe5ab..83e71cb8c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -24,7 +24,7 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
return
@@ -56,9 +56,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
+func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
received.Invalid.Increment()
@@ -88,7 +91,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
// It's possible that a raw socket expects to receive this.
h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{
Data: pkt.Data.Clone(nil),
NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...),
})
@@ -102,7 +105,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv4ProtocolNumber,
+ TTL: r.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: vv,
TransportHeader: buffer.View(pkt),
@@ -122,6 +129,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
+ case header.ICMPv4HostUnreachable:
+ e.handleControl(stack.ControlNoRoute, 0, pkt)
+
case header.ICMPv4PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
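
The new case extends destination-unreachable handling so host-unreachable errors surface as a no-route control event alongside the existing port-unreachable mapping. A standalone sketch of that dispatch; the numeric codes are the standard ICMPv4 values from RFC 792, while controlType and its constants are illustrative stand-ins for stack.ControlType:

package main

import "fmt"

// ICMPv4 destination-unreachable codes (RFC 792).
const (
	icmpV4HostUnreachable = 1
	icmpV4PortUnreachable = 3
)

// controlType stands in for stack.ControlType in this sketch.
type controlType int

const (
	controlNone controlType = iota
	controlNoRoute
	controlPortUnreachable
)

// controlForCode mirrors the switch in handleICMP: only codes the stack
// knows how to act on are turned into transport control events.
func controlForCode(code uint8) (controlType, bool) {
	switch code {
	case icmpV4HostUnreachable:
		return controlNoRoute, true
	case icmpV4PortUnreachable:
		return controlPortUnreachable, true
	}
	return controlNone, false
}

func main() {
	fmt.Println(controlForCode(1)) // 1 true (controlNoRoute)
	fmt.Println(controlForCode(2)) // 0 false (protocol unreachable is not handled)
}
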
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 64046cbbf..d5f5d38f7 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -21,6 +21,7 @@
package ipv4
import (
+ "fmt"
"sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -44,6 +45,10 @@ const (
// buckets is the number of identifier buckets.
buckets = 2048
+
+ // The size of a fragment block, in bytes, as per RFC 791 section 3.1,
+ // page 14.
+ fragmentblockSize = 8
)
type endpoint struct {
@@ -65,7 +70,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
prefixLen: addrWithPrefix.PrefixLen,
linkEP: linkEP,
dispatcher: dispatcher,
- fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
protocol: p,
stack: st,
}
@@ -129,7 +134,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
// packet's stated length matches the length of the header+payload. mtu
// includes the IP header and options. This does not support the DontFragment
// IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error {
// This packet is too big, it needs to be fragmented.
ip := header.IPv4(pkt.Header.View())
flags := ip.Flags()
@@ -169,7 +174,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
if i > 0 {
newPayload := pkt.Data.Clone(nil)
newPayload.CapLength(innerMTU)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
Header: pkt.Header,
Data: newPayload,
NetworkHeader: buffer.View(h),
@@ -188,7 +193,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
newPayload := pkt.Data.Clone(nil)
newPayloadLength := outerMTU - pkt.Header.UsedLength()
newPayload.CapLength(newPayloadLength)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
Header: pkt.Header,
Data: newPayload,
NetworkHeader: buffer.View(h),
@@ -202,7 +207,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
startOfHdr := pkt.Header
startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
Header: startOfHdr,
Data: emptyVV,
NetworkHeader: buffer.View(h),
@@ -224,12 +229,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
length := uint16(hdr.UsedLength() + payloadSize)
- id := uint32(0)
- if length > header.IPv4MaximumHeaderSize+8 {
- // Packets of 68 bytes or less are required by RFC 791 to not be
- // fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
- }
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
+	// datagrams. Since the DF bit is never set here, all datagrams
+ // are non-atomic and need an ID.
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: length,
@@ -245,7 +248,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
@@ -253,43 +256,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
ipt := e.stack.IPTables()
- if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok {
+ if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
return nil
}
+	// If the packet is manipulated as per NAT Output rules, handle the packet
+	// based on its destination address and do not send it to the link layer.
+ // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than
+ // only NATted packets, but removing this check short circuits broadcasts
+ // before they are sent out to other hosts.
if pkt.NatDone {
- // If the packet is manipulated as per NAT Ouput rules, handle packet
- // based on destination address and do not send the packet to link layer.
netHeader := header.IPv4(pkt.NetworkHeader)
ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
- src := netHeader.SourceAddress()
- dst := netHeader.DestinationAddress()
- route := r.ReverseRoute(src, dst)
-
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
- packet := stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)}
- ep.HandlePacket(&route, packet)
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ ep.HandlePacket(&route, pkt)
return nil
}
}
if r.Loop&stack.PacketLoop != 0 {
- // The inbound path expects the network header to still be in
- // the PacketBuffer's Data field.
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
-
- e.HandlePacket(&loopedR, stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
- })
-
+ e.HandlePacket(&loopedR, pkt)
loopedR.Release()
}
if r.Loop&stack.PacketOut == 0 {
@@ -342,23 +331,16 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader)
- ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
- if err == nil {
+ if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
-
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
- packet := stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)}
- ep.HandlePacket(&route, packet)
+ ep.HandlePacket(&route, pkt)
n++
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil {
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
return n, err
}
@@ -370,7 +352,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
@@ -396,13 +378,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf
// Set the packet ID when zero.
if ip.ID() == 0 {
- id := uint32(0)
- if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 {
- // Packets of 68 bytes or less are required by RFC 791 to not be
- // fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for
+ // non-atomic datagrams, so assign an ID to all such datagrams
+ // according to the definition given in RFC 6864 section 4.
+ if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
- ip.SetID(uint16(id))
}
// Always set the checksum.
@@ -426,35 +407,23 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.IPv4(pkt.NetworkHeader)
+ if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
- h := header.IPv4(headerView)
- if !h.IsValid(pkt.Data.Size()) {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- return
- }
- pkt.NetworkHeader = headerView[:h.HeaderLength()]
-
- hlen := int(h.HeaderLength())
- tlen := int(h.TotalLength())
- pkt.Data.TrimFront(hlen)
- pkt.Data.CapLength(tlen - hlen)
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
ipt := e.stack.IPTables()
- if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok {
+ if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
return
}
- more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
- if more || h.FragmentOffset() != 0 {
- if pkt.Data.Size() == 0 {
+ if h.More() || h.FragmentOffset() != 0 {
+ if pkt.Data.Size()+len(pkt.TransportHeader) == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -473,7 +442,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
}
var ready bool
var err error
- pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data)
+ pkt.Data, ready, err = e.fragmentation.Process(
+ fragmentation.FragmentID{
+ Source: h.SourceAddress(),
+ Destination: h.DestinationAddress(),
+ ID: uint32(h.ID()),
+ Protocol: h.Protocol(),
+ },
+ h.FragmentOffset(),
+ last,
+ h.More(),
+ pkt.Data,
+ )
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
@@ -485,7 +465,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
}
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
- headerView.CapLength(hlen)
+ pkt.NetworkHeader.CapLength(int(h.HeaderLength()))
e.handleICMP(r, pkt)
return
}
@@ -565,6 +545,35 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.NetworkProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return 0, false, false
+ }
+ ipHdr := header.IPv4(hdr)
+
+ // If there are options, pull those into hdr as well.
+ if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() {
+ hdr, ok = pkt.Data.PullUp(headerLen)
+ if !ok {
+ panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen))
+ }
+ ipHdr = header.IPv4(hdr)
+ }
+
+ // If this is a fragment, don't bother parsing the transport header.
+ parseTransportHeader := true
+ if ipHdr.More() || ipHdr.FragmentOffset() != 0 {
+ parseTransportHeader = false
+ }
+
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(len(hdr))
+ pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
+ return ipHdr.TransportProtocol(), parseTransportHeader, true
+}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
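
Both ID-assignment changes in this file rest on RFC 6864's notion of an atomic datagram: DF set, MF clear, and a zero fragment offset. Only non-atomic datagrams need a unique IP ID. A self-contained sketch of that predicate, with flag constants mirroring header.IPv4FlagMoreFragments and header.IPv4FlagDontFragment:

package main

import "fmt"

// IPv4 flag bits as exposed by the header package.
const (
	flagMoreFragments = 1 << 0 // header.IPv4FlagMoreFragments
	flagDontFragment  = 1 << 1 // header.IPv4FlagDontFragment
)

// nonAtomic reports whether a datagram may be fragmented (RFC 6864
// section 4): DF clear, or already a fragment (MF set or offset > 0).
// Such datagrams must carry a unique IP ID.
func nonAtomic(flags uint8, fragmentOffset uint16) bool {
	return flags&flagDontFragment == 0 ||
		flags&flagMoreFragments != 0 ||
		fragmentOffset > 0
}

func main() {
	fmt.Println(nonAtomic(flagDontFragment, 0))                   // false: atomic, ID may stay 0
	fmt.Println(nonAtomic(0, 0))                                  // true: could be fragmented later
	fmt.Println(nonAtomic(flagDontFragment|flagMoreFragments, 0)) // true: already a fragment
}
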
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 36035c820..ded97ac64 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -114,7 +114,7 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.
// comparePayloads compares the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) {
+func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
t.Helper()
// Make a complete array of the sourcePacketInfo packet.
source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
@@ -174,7 +174,7 @@ func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketIn
type errorChannel struct {
*channel.Endpoint
- Ch chan stack.PacketBuffer
+ Ch chan *stack.PacketBuffer
packetCollectorErrors []*tcpip.Error
}
@@ -184,7 +184,7 @@ type errorChannel struct {
func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
return &errorChannel{
Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan stack.PacketBuffer, size),
+ Ch: make(chan *stack.PacketBuffer, size),
packetCollectorErrors: packetCollectorErrors,
}
}
@@ -203,7 +203,7 @@ func (e *errorChannel) Drain() int {
}
// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
select {
case e.Ch <- pkt:
default:
@@ -282,13 +282,17 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
- source := stack.PacketBuffer{
+ source := &stack.PacketBuffer{
Header: hdr,
// Save the source payload because WritePacket will modify it.
Data: payload.Clone(nil),
}
c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: 42,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: payload,
})
@@ -296,7 +300,7 @@ func TestFragmentation(t *testing.T) {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []stack.PacketBuffer
+ var results []*stack.PacketBuffer
L:
for {
select {
@@ -338,7 +342,11 @@ func TestFragmentationErrors(t *testing.T) {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: 42,
+ TOS: stack.DefaultTOS,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: payload,
})
@@ -460,7 +468,7 @@ func TestInvalidFragments(t *testing.T) {
s.CreateNIC(nicID, sniffer.New(ep))
for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{
+ ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
})
}
@@ -478,12 +486,16 @@ func TestInvalidFragments(t *testing.T) {
// TestReceiveFragments feeds fragments in through the incoming packet path to
// test reassembly
func TestReceiveFragments(t *testing.T) {
- const addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
- const addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
- const nicID = 1
+ const (
+ nicID = 1
+
+ addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
+ addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
+ addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3
+ )
// Build and return a UDP header containing payload.
- udpGen := func(payloadLen int, multiplier uint8) buffer.View {
+ udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View {
payload := buffer.NewView(payloadLen)
for i := 0; i < len(payload); i++ {
payload[i] = uint8(i) * multiplier
@@ -499,20 +511,29 @@ func TestReceiveFragments(t *testing.T) {
Length: uint16(udpLength),
})
copy(u.Payload(), payload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
sum = header.Checksum(payload, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
return hdr.View()
}
// UDP header plus a payload of 0..256
- ipv4Payload1 := udpGen(256, 1)
- udpPayload1 := ipv4Payload1[header.UDPMinimumSize:]
+ ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2)
+ udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:]
+ ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2)
+ udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:]
// UDP header plus a payload of 0..256 in increments of 2.
- ipv4Payload2 := udpGen(128, 2)
- udpPayload2 := ipv4Payload2[header.UDPMinimumSize:]
+ ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2)
+ udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:]
+ // UDP header plus a payload of 0..256 in increments of 3.
+	// Used to test cases where fragment sizes are not a multiple of
+	// the fragment block size of 8 (RFC 791 section 3.1, page 14).
+ ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2)
+ udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:]
type fragmentData struct {
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
id uint16
flags uint8
fragmentOffset uint16
@@ -528,22 +549,40 @@ func TestReceiveFragments(t *testing.T) {
name: "No fragmentation",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 0,
- payload: ipv4Payload1,
+ payload: ipv4Payload1Addr1ToAddr2,
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "No fragmentation with size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 0,
+ payload: ipv4Payload3Addr1ToAddr2,
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
},
{
name: "More fragments without payload",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1,
+ payload: ipv4Payload1Addr1ToAddr2,
},
},
expectedPayloads: nil,
@@ -552,10 +591,12 @@ func TestReceiveFragments(t *testing.T) {
name: "Non-zero fragment offset without payload",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 8,
- payload: ipv4Payload1,
+ payload: ipv4Payload1Addr1ToAddr2,
},
},
expectedPayloads: nil,
@@ -564,34 +605,86 @@ func TestReceiveFragments(t *testing.T) {
name: "Two fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload1Addr1ToAddr2[64:],
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with last fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload3Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload3Addr1ToAddr2[64:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with first fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload3Addr1ToAddr2[:63],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 63,
+ payload: ipv4Payload3Addr1ToAddr2[63:],
+ },
+ },
+ expectedPayloads: nil,
},
{
name: "Second fragment has MoreFlags set",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload1Addr1ToAddr2[64:],
},
},
expectedPayloads: nil,
@@ -600,16 +693,20 @@ func TestReceiveFragments(t *testing.T) {
name: "Two fragments with different IDs",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 2,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload1Addr1ToAddr2[64:],
},
},
expectedPayloads: nil,
@@ -618,31 +715,91 @@ func TestReceiveFragments(t *testing.T) {
name: "Two interleaved fragmented packets",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 2,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload2[:64],
+ payload: ipv4Payload2Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload1Addr1ToAddr2[64:],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 2,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload2[64:],
+ payload: ipv4Payload2Addr1ToAddr2[64:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
+ },
+ {
+ name: "Two interleaved fragmented packets from different sources but with same ID",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr3ToAddr2[:32],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 32,
+ payload: ipv4Payload1Addr3ToAddr2[32:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
+ },
+ {
+ name: "Fragment without followup",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
},
- expectedPayloads: [][]byte{udpPayload1, udpPayload2},
+ expectedPayloads: nil,
},
}
@@ -691,14 +848,14 @@ func TestReceiveFragments(t *testing.T) {
FragmentOffset: frag.fragmentOffset,
TTL: 64,
Protocol: uint8(header.UDPProtocolNumber),
- SrcAddr: addr1,
- DstAddr: addr2,
+ SrcAddr: frag.srcAddr,
+ DstAddr: frag.dstAddr,
})
vv := hdr.View().ToVectorisedView()
vv.AppendView(frag.payload)
- e.InjectInbound(header.IPv4ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{
Data: vv,
})
}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 3f71fc520..bcc64994e 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -14,7 +14,6 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/network/fragmentation",
- "//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
],
)
@@ -39,6 +38,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index bdf3a0d25..24600d877 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -27,7 +27,7 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
if !ok {
return
@@ -70,17 +70,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) {
+func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
if !ok {
received.Invalid.Increment()
return
}
h := header.ICMPv6(v)
- iph := header.IPv6(netHeader)
+ iph := header.IPv6(pkt.NetworkHeader)
// Validate ICMPv6 checksum before processing the packet.
//
@@ -125,6 +128,8 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch header.ICMPv6(hdr).Code() {
+ case header.ICMPv6NetworkUnreachable:
+ e.handleControl(stack.ControlNetworkUnreachable, 0, pkt)
case header.ICMPv6PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
@@ -288,7 +293,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
//
// The IP Hop Limit field has a value of 255, i.e., the packet
// could not possibly have been forwarded by a router.
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
}); err != nil {
sent.Dropped.Increment()
@@ -390,7 +395,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
copy(packet, icmpHdr)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: pkt.Data,
}); err != nil {
@@ -491,8 +496,6 @@ const (
icmpV6LengthOffset = 25
)
-var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-
var _ stack.LinkAddressResolver = (*protocol)(nil)
// LinkAddressProtocol implements stack.LinkAddressResolver.
@@ -501,7 +504,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
snaddr := header.SolicitedNodeAddr(addr)
// TODO(b/148672031): Use stack.FindRoute instead of manually creating the
@@ -510,8 +513,12 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
r := &stack.Route{
LocalAddress: localAddr,
RemoteAddress: snaddr,
- RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr),
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr)
}
+
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
@@ -532,7 +539,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
Header: hdr,
})
}
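
The remoteLinkAddr parameter added above lets a caller steer the neighbor solicitation at a known unicast link address; only when it is empty does the route fall back to the Ethernet multicast address derived from the solicited-node address. A minimal standalone sketch of that mapping (RFC 2464), using a hypothetical helper etherMulticastForIPv6 in place of header.EthernetAddressFromMulticastIPv6Address:

package main

import "fmt"

// etherMulticastForIPv6 maps an IPv6 multicast address to its Ethernet
// multicast address: 33:33 followed by the low-order 32 bits of the
// address (RFC 2464). Hypothetical stand-in, not the package's API.
func etherMulticastForIPv6(ip [16]byte) [6]byte {
	return [6]byte{0x33, 0x33, ip[12], ip[13], ip[14], ip[15]}
}

func main() {
	// Solicited-node address ff02::1:ff00:1.
	snaddr := [16]byte{0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0xff, 0, 0, 1}
	var remote [6]byte // the zero value stands in for an empty remoteLinkAddr
	if remote == ([6]byte{}) {
		remote = etherMulticastForIPv6(snaddr)
	}
	fmt.Printf("% x\n", remote) // 33 33 ff 00 00 01
}

Passing an empty link address preserves the old multicast behavior, which is exactly what the new TestLinkAddressRequest below exercises.
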
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d412ff688..f86aaed1d 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -34,6 +34,9 @@ const (
linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
)
var (
@@ -57,7 +60,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
return nil
}
@@ -67,7 +70,7 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) {
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
}
type stubLinkAddressCache struct {
@@ -179,36 +182,32 @@ func TestICMPCounts(t *testing.T) {
},
}
- handleIPv6Payload := func(hdr buffer.Prependable) {
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ handleIPv6Payload := func(icmp header.ICMPv6) {
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
+ PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
+ ep.HandlePacket(&r, &stack.PacketBuffer{
+ NetworkHeader: buffer.View(ip),
+ Data: buffer.View(icmp).ToVectorisedView(),
})
}
for _, typ := range types {
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
-
- handleIPv6Payload(hdr)
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ handleIPv6Payload(icmp)
}
// Construct an empty ICMP packet so that
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
- handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize))
+ handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
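
The refactored helper now builds the ICMPv6 view directly and checksums it against the IPv6 pseudo-header. As a rough standalone sketch of the one's-complement folding that header.ICMPv6Checksum performs (RFC 1071; the pseudo-header summation is omitted here, and the checksum field is assumed to be zeroed before summing):

package main

import "fmt"

// checksum computes the 16-bit one's-complement checksum used by ICMPv6.
func checksum(b []byte) uint16 {
	var sum uint32
	for i := 0; i+1 < len(b); i += 2 {
		sum += uint32(b[i])<<8 | uint32(b[i+1])
	}
	if len(b)%2 == 1 {
		sum += uint32(b[len(b)-1]) << 8 // pad an odd trailing byte
	}
	for sum>>16 != 0 {
		sum = (sum & 0xffff) + sum>>16 // fold the carries back in
	}
	return ^uint16(sum)
}

func main() {
	// Echo request header with a zeroed checksum field.
	fmt.Printf("%#04x\n", checksum([]byte{0x80, 0, 0, 0})) // 0x7fff
}
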
@@ -261,8 +260,7 @@ func newTestContext(t *testing.T) *testContext {
}),
}
- const defaultMTU = 65536
- c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+ c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0)
wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
@@ -275,7 +273,7 @@ func newTestContext(t *testing.T) *testContext {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -328,7 +326,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
vv := buffer.NewVectorisedView(size, views)
- args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{
+ args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{
Data: vv,
})
}
@@ -546,25 +544,22 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
handleIPv6Payload := func(checksum bool) {
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
}
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(typ.size + extraDataLen),
+ PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
})
}
@@ -740,7 +735,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
}
@@ -918,7 +913,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
})
}
@@ -958,3 +953,47 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
})
}
}
+
+func TestLinkAddressRequest(t *testing.T) {
+ snaddr := header.SolicitedNodeAddr(lladdr0)
+ mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
+
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: linkAddr1,
+ expectLinkAddr: linkAddr1,
+ },
+ {
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectLinkAddr: mcaddr,
+ },
+ }
+
+ for _, test := range tests {
+ p := NewProtocol()
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
+ if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
+ }
+}

diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index daf1fcbc6..a0a5c9c01 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
- "gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -116,7 +115,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
@@ -128,7 +127,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, stack.PacketBuffer{
+ e.HandlePacket(&loopedR, &stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
@@ -163,30 +162,28 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// TODO(b/146666412): Support IPv6 header-included packets.
return tcpip.ErrNotSupported
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
- if !ok {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.IPv6(pkt.NetworkHeader)
+ if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
- h := header.IPv6(headerView)
- if !h.IsValid(pkt.Data.Size()) {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- return
- }
-
- pkt.NetworkHeader = headerView[:header.IPv6MinimumSize]
- pkt.Data.TrimFront(header.IPv6MinimumSize)
- pkt.Data.CapLength(int(h.PayloadLength()))
- it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data)
+ // vv consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView()
+ vv.AppendView(pkt.TransportHeader)
+ vv.Append(pkt.Data)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
hasFragmentHeader := false
for firstHeader := true; ; firstHeader = false {
@@ -262,9 +259,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
case header.IPv6FragmentExtHdr:
hasFragmentHeader = true
- fragmentOffset := extHdr.FragmentOffset()
- more := extHdr.More()
- if !more && fragmentOffset == 0 {
+ if extHdr.IsAtomic() {
// This fragment extension header indicates that this packet is an
// atomic fragment. An atomic fragment is a fragment that contains
// all the data required to reassemble a full packet. As per RFC 6946,
@@ -277,9 +272,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
// Don't consume the iterator if we have the first fragment because we
// will use it to validate that the first fragment holds the upper layer
// header.
- rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */)
+ rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */)
- if fragmentOffset == 0 {
+ if extHdr.FragmentOffset() == 0 {
// Check that the iterator ends with a raw payload as the first fragment
// should include all headers up to and including any upper layer
// headers, as per RFC 8200 section 4.5; only upper layer data
@@ -332,7 +327,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
}
// The packet is a fragment, let's try to reassemble it.
- start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
+ start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
last := start + uint16(fragmentPayloadLen) - 1
// Drop the packet if the fragmentOffset is incorrect, i.e. the
@@ -345,7 +340,21 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
}
var ready bool
- pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf)
+ // Note that pkt doesn't have its transport header set after reassembly,
+ // and won't until DeliverNetworkPacket sets it.
+ pkt.Data, ready, err = e.fragmentation.Process(
+ // IPv6 ignores the Protocol field since the ID only needs to be unique
+ // across source-destination pairs, as per RFC 8200 section 4.5.
+ fragmentation.FragmentID{
+ Source: h.SourceAddress(),
+ Destination: h.DestinationAddress(),
+ ID: extHdr.ID(),
+ },
+ start,
+ last,
+ extHdr.More(),
+ rawPayload.Buf,
+ )
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
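
Keying reassembly by an explicit (source, destination, ID) struct, rather than a hash that mixed in the protocol, follows RFC 8200 section 4.5: the Identification field only has to be unique per source/destination pair. Note also that the offset arithmetic above multiplies FragmentOffset() by the 8-byte fragment block size. A minimal sketch of why struct keys keep same-ID fragments from different sources apart (names are illustrative, not the package's types):

package main

import "fmt"

// fragKey mirrors the idea behind fragmentation.FragmentID: fragments are
// grouped per source/destination pair plus the 32-bit Identification
// field, so equal IDs from different sources never collide.
type fragKey struct {
	src, dst string
	id       uint32
}

func main() {
	queues := map[fragKey][][]byte{}
	a := fragKey{"addr1", "addr2", 1}
	b := fragKey{"addr3", "addr2", 1} // same ID, different source
	queues[a] = append(queues[a], []byte("first half"))
	queues[b] = append(queues[b], []byte("unrelated packet"))
	fmt.Println(len(queues)) // 2: distinct reassembly queues
}
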
@@ -394,10 +403,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
case header.IPv6RawPayloadHeader:
// If the last header in the payload isn't a known IPv6 extension header,
// handle it as if it is transport layer data.
+
+ // For unfragmented packets, extHdr still contains the transport header.
+ // Get rid of it.
+ //
+ // For reassembled fragments, pkt.TransportHeader is unset, so this is a
+ // no-op and pkt.Data begins with the transport header.
+ extHdr.Buf.TrimFront(len(pkt.TransportHeader))
pkt.Data = extHdr.Buf
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
- e.handleICMP(r, headerView, pkt, hasFragmentHeader)
+ e.handleICMP(r, pkt, hasFragmentHeader)
} else {
r.Stats().IP.PacketsDelivered.Increment()
// TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
@@ -462,7 +478,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
- fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
protocol: p,
}, nil
}
@@ -505,6 +521,79 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return 0, false, false
+ }
+ ipHdr := header.IPv6(hdr)
+
+ // dataClone consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ views := [8]buffer.View{}
+ dataClone := pkt.Data.Clone(views[:])
+ dataClone.TrimFront(header.IPv6MinimumSize)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
+
+ // Iterate over the IPv6 extensions to find their length.
+ //
+ // Parsing occurs again in HandlePacket because we don't track the
+ // extension headers in PacketBuffer. Unfortunately, that means
+ // HandlePacket has to repeat the parsing work.
+ var nextHdr tcpip.TransportProtocolNumber
+ foundNext := true
+ extensionsSize := 0
+traverseExtensions:
+ for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() {
+ if err != nil {
+ break
+ }
+ // If we exhaust the extension list, the entire packet is the IPv6 header
+ // and (possibly) extensions.
+ if done {
+ extensionsSize = dataClone.Size()
+ foundNext = false
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6FragmentExtHdr:
+ // If this is an atomic fragment, we don't have to treat it specially.
+ if !extHdr.More() && extHdr.FragmentOffset() == 0 {
+ continue
+ }
+ // This is a non-atomic fragment and has to be re-assembled before we can
+ // examine the payload for a transport header.
+ foundNext = false
+
+ case header.IPv6RawPayloadHeader:
+ // We've found the payload after any extensions.
+ extensionsSize = dataClone.Size() - extHdr.Buf.Size()
+ nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
+ break traverseExtensions
+
+ default:
+ // Any other extension is a no-op, keep looping until we find the payload.
+ }
+ }
+
+ // Put the IPv6 header with extensions in pkt.NetworkHeader.
+ hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize)
+ if !ok {
+ panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
+ }
+ ipHdr = header.IPv6(hdr)
+
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(len(hdr))
+ pkt.Data.CapLength(int(ipHdr.PayloadLength()))
+
+ return nextHdr, foundNext, true
+}
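
Parse's single pass boils down to a next-header walk: generic extension headers advance by (length+1)*8 bytes, while fragment headers are fixed at 8 bytes. A standalone sketch of that traversal, assuming well-formed input and using locally defined constants for the header types (bounds checks and atomic-fragment handling are elided):

package main

import "fmt"

const (
	hopByHop = 0
	routing  = 43
	fragment = 44
	destOpts = 60
)

// walkExtensions returns the final next-header value and the total size of
// the extension headers in b. Real parsing must also bounds-check.
func walkExtensions(next uint8, b []byte) (uint8, int) {
	off := 0
	for {
		switch next {
		case hopByHop, routing, destOpts:
			next = b[off]
			// Length byte counts 8-octet units beyond the first 8.
			off += (int(b[off+1]) + 1) * 8
		case fragment:
			next = b[off]
			off += 8 // fragment headers are fixed size
		default:
			return next, off
		}
	}
}

func main() {
	// A single fragment header whose payload is UDP (protocol 17).
	b := []byte{17, 0, 0, 0, 0, 0, 0, 1}
	proto, size := walkExtensions(fragment, b)
	fmt.Println(proto, size) // 17 8
}
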
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 841a0cb7a..3d65814de 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -65,7 +65,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -123,7 +123,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -637,7 +637,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
DstAddr: addr2,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -673,20 +673,27 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// fragmentData holds the IPv6 payload for a fragmented IPv6 packet.
type fragmentData struct {
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
nextHdr uint8
data buffer.VectorisedView
}
func TestReceiveIPv6Fragments(t *testing.T) {
- const nicID = 1
- const udpPayload1Length = 256
- const udpPayload2Length = 128
- const fragmentExtHdrLen = 8
- // Note, not all routing extension headers will be 8 bytes but this test
- // uses 8 byte routing extension headers for most sub tests.
- const routingExtHdrLen = 8
-
- udpGen := func(payload []byte, multiplier uint8) buffer.View {
+ const (
+ nicID = 1
+ udpPayload1Length = 256
+ udpPayload2Length = 128
+ // Used to test cases where the fragment payload is not a multiple of
+ // the fragment block size of 8 (RFC 8200 section 4.5).
+ udpPayload3Length = 127
+ fragmentExtHdrLen = 8
+ // Note that not all routing extension headers are 8 bytes, but this test
+ // uses 8-byte routing extension headers for most subtests.
+ routingExtHdrLen = 8
+ )
+
+ udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View {
payloadLen := len(payload)
for i := 0; i < payloadLen; i++ {
payload[i] = uint8(i) * multiplier
@@ -702,19 +709,27 @@ func TestReceiveIPv6Fragments(t *testing.T) {
Length: uint16(udpLength),
})
copy(u.Payload(), payload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
sum = header.Checksum(payload, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
return hdr.View()
}
- var udpPayload1Buf [udpPayload1Length]byte
- udpPayload1 := udpPayload1Buf[:]
- ipv6Payload1 := udpGen(udpPayload1, 1)
+ var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte
+ udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:]
+ ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2)
+
+ var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte
+ udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:]
+ ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2)
- var udpPayload2Buf [udpPayload2Length]byte
- udpPayload2 := udpPayload2Buf[:]
- ipv6Payload2 := udpGen(udpPayload2, 2)
+ var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte
+ udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:]
+ ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2)
+
+ var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte
+ udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:]
+ ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2)
tests := []struct {
name string
@@ -726,34 +741,98 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "No fragmentation",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: uint8(header.UDPProtocolNumber),
- data: ipv6Payload1.ToVectorisedView(),
+ data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
},
{
name: "Atomic fragment",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2),
+ []buffer.View{
+ // Fragment extension header.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+
+ ipv6Payload1Addr1ToAddr2,
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Atomic fragment with size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1),
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
- ipv6Payload1,
+ ipv6Payload3Addr1ToAddr2,
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
},
{
name: "Two fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with last fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -763,31 +842,73 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload3Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload3Addr1ToAddr2[64:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with first fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+63,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[:63],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[63:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
},
{
name: "Two fragments with different IDs",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -797,21 +918,23 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 2
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
@@ -822,6 +945,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with per-fragment routing header with zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -836,14 +961,16 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Routing extension header.
//
@@ -855,17 +982,19 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
},
{
name: "Two fragments with per-fragment routing header with non-zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -880,14 +1009,16 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Routing extension header.
//
@@ -899,7 +1030,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 9, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
@@ -910,6 +1041,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with routing header with zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -924,31 +1057,35 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Segments left = 0.
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
},
{
name: "Two fragments with routing header with non-zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -963,21 +1100,23 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Segments left = 1.
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
@@ -988,6 +1127,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with routing header with zero segments left across fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is fragmentExtHdrLen+8 because the
@@ -1008,12 +1149,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is
- // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
// the 16 byte routing extension header is in this fragment.
- fragmentExtHdrLen+8+len(ipv6Payload1),
+ fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
//
@@ -1023,7 +1166,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header (part 2)
buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
- ipv6Payload1,
+ ipv6Payload1Addr1ToAddr2,
},
),
},
@@ -1034,6 +1177,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with routing header with non-zero segments left across fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is fragmentExtHdrLen+8 because the
@@ -1054,12 +1199,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is
- // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
// the 16 byte routing extension header is in this fragment.
- fragmentExtHdrLen+8+len(ipv6Payload1),
+ fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
//
@@ -1069,7 +1216,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header (part 2)
buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
- ipv6Payload1,
+ ipv6Payload1Addr1ToAddr2,
},
),
},
@@ -1082,6 +1229,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with atomic",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -1091,47 +1240,53 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
// This fragment has the same ID as the other fragments but is an atomic
// fragment. It should not interfere with the other fragments.
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload2),
+ fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 0, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}),
- ipv6Payload2,
+ ipv6Payload2Addr1ToAddr2,
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload2, udpPayload1},
+ expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2},
},
{
name: "Two interleaved fragmented packets",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -1141,11 +1296,13 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+32,
@@ -1155,40 +1312,114 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 2
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}),
- ipv6Payload2[:32],
+ ipv6Payload2Addr1ToAddr2[:32],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload2)-32,
+ fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 4, More = false, ID = 2
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}),
- ipv6Payload2[32:],
+ ipv6Payload2Addr1ToAddr2[32:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
+ },
+ {
+ name: "Two interleaved fragmented packets from different sources but with same ID",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr3ToAddr2[:32],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 4, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr3ToAddr2[32:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1, udpPayload2},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
},
}
@@ -1231,14 +1462,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
PayloadLength: uint16(f.data.Size()),
NextHeader: f.nextHdr,
HopLimit: 255,
- SrcAddr: addr1,
- DstAddr: addr2,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
})
vv := hdr.View().ToVectorisedView()
vv.Append(f.data)
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: vv,
})
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 12b70f7e9..64239ce9a 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -136,7 +136,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -380,7 +380,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -497,7 +497,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -551,25 +551,29 @@ func TestNDPValidation(t *testing.T) {
return s, ep, r
}
- handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
nextHdr := uint8(header.ICMPv6ProtocolNumber)
+ var extensions buffer.View
if atomicFragment {
- bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength)
- bytes[0] = nextHdr
+ extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
+ extensions[0] = nextHdr
nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
}
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions)))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
+ PayloadLength: uint16(len(payload) + len(extensions)),
NextHeader: nextHdr,
HopLimit: hopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(r, stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
+ if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
+ t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
+ }
+ ep.HandlePacket(r, &stack.PacketBuffer{
+ NetworkHeader: buffer.View(ip),
+ Data: payload.ToVectorisedView(),
})
}
@@ -676,14 +680,11 @@ func TestNDPValidation(t *testing.T) {
invalid := stats.Invalid
typStat := typ.statCounter(stats)
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetCode(test.code)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetCode(test.code)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
// Rx count of the NDP message should initially be 0.
if got := typStat.Value(); got != 0 {
@@ -699,7 +700,7 @@ func TestNDPValidation(t *testing.T) {
t.FailNow()
}
- handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r)
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
// Rx count of the NDP packet should have increased.
if got := typStat.Value(); got != 1 {
@@ -884,7 +885,7 @@ func TestRouterAdvertValidation(t *testing.T) {
t.Fatalf("got rxRA = %d, want = 0", got)
}
- e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index b937cb84b..f6d592eb5 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -52,19 +52,35 @@ type Flags struct {
//
// LoadBalanced takes precedence over MostRecent.
LoadBalanced bool
+
+ // TupleOnly represents TCP SO_REUSEADDR.
+ TupleOnly bool
}
-func (f Flags) bits() reuseFlag {
- var rf reuseFlag
+// Bits converts the Flags to their bitset form.
+func (f Flags) Bits() BitFlags {
+ var rf BitFlags
if f.MostRecent {
- rf |= mostRecentFlag
+ rf |= MostRecentFlag
}
if f.LoadBalanced {
- rf |= loadBalancedFlag
+ rf |= LoadBalancedFlag
+ }
+ if f.TupleOnly {
+ rf |= TupleOnlyFlag
}
return rf
}
+// Effective returns the effective behavior of a flag config.
+func (f Flags) Effective() Flags {
+ e := f
+ if e.LoadBalanced && e.MostRecent {
+ e.MostRecent = false
+ }
+ return e
+}
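
Effective collapses contradictory flag combinations before comparison, since LoadBalanced (SO_REUSEPORT-style sharing) takes precedence over MostRecent. A standalone replica with a local stand-in type:

package main

import "fmt"

// flags is a local stand-in for ports.Flags.
type flags struct {
	MostRecent   bool
	LoadBalanced bool
	TupleOnly    bool
}

// effective mirrors Flags.Effective: LoadBalanced wins over MostRecent.
func (f flags) effective() flags {
	if f.LoadBalanced && f.MostRecent {
		f.MostRecent = false
	}
	return f
}

func main() {
	f := flags{MostRecent: true, LoadBalanced: true}
	fmt.Printf("%+v\n", f.effective()) // {MostRecent:false LoadBalanced:true TupleOnly:false}
}
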
+
// PortManager manages allocating, reserving and releasing ports.
type PortManager struct {
mu sync.RWMutex
@@ -78,83 +94,166 @@ type PortManager struct {
hint uint32
}
-type reuseFlag int
+// BitFlags is a bitset representation of Flags.
+type BitFlags uint32
const (
- mostRecentFlag reuseFlag = 1 << iota
- loadBalancedFlag
+ // MostRecentFlag represents Flags.MostRecent.
+ MostRecentFlag BitFlags = 1 << iota
+
+ // LoadBalancedFlag represents Flags.LoadBalanced.
+ LoadBalancedFlag
+
+ // TupleOnlyFlag represents Flags.TupleOnly.
+ TupleOnlyFlag
+
+ // nextFlag is the value that the next added flag will have.
+ //
+ // It is used to calculate FlagMask below. It is also the number of
+ // valid flag states.
nextFlag
- flagMask = nextFlag - 1
+ // FlagMask is a bit mask for BitFlags.
+ FlagMask = nextFlag - 1
+
+ // MultiBindFlagMask contains the flags that allow binding the same
+ // tuple multiple times.
+ MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag
)
-type portNode struct {
- // refs stores the count for each possible flag combination.
+// ToFlags converts the bitset into a Flags struct.
+func (f BitFlags) ToFlags() Flags {
+ return Flags{
+ MostRecent: f&MostRecentFlag != 0,
+ LoadBalanced: f&LoadBalancedFlag != 0,
+ TupleOnly: f&TupleOnlyFlag != 0,
+ }
+}
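
Bits and ToFlags are inverses over the three defined bits, which is what lets FlagCounter index its reference counts with a plain array of length nextFlag. A standalone sketch of the bitset layout, with the constants redefined locally:

package main

import "fmt"

type bitFlags uint32

const (
	mostRecent bitFlags = 1 << iota
	loadBalanced
	tupleOnly
	nextFlag
	flagMask = nextFlag - 1 // 0b111: every valid combination is below nextFlag
)

func main() {
	// All 8 combinations (0..flagMask) index a refs array of length nextFlag.
	for f := bitFlags(0); f <= flagMask; f++ {
		fmt.Printf("%03b reusePort=%v reuseAddr=%v\n", f, f&loadBalanced != 0, f&tupleOnly != 0)
	}
	fmt.Println(nextFlag) // 8 possible flag states
}
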
+
+// FlagCounter counts how many references each flag combination has.
+type FlagCounter struct {
+ // refs stores the count for each possible flag combination (0 through
+ // FlagMask).
refs [nextFlag]int
}
-func (p portNode) totalRefs() int {
+// AddRef increases the reference count for a specific flag combination.
+func (c *FlagCounter) AddRef(flags BitFlags) {
+ c.refs[flags]++
+}
+
+// DropRef decreases the reference count for a specific flag combination.
+func (c *FlagCounter) DropRef(flags BitFlags) {
+ c.refs[flags]--
+}
+
+// TotalRefs calculates the total number of references for all flag
+// combinations.
+func (c FlagCounter) TotalRefs() int {
var total int
- for _, r := range p.refs {
+ for _, r := range c.refs {
total += r
}
return total
}
-// flagRefs returns the number of references with all specified flags.
-func (p portNode) flagRefs(flags reuseFlag) int {
+// FlagRefs returns the number of references with all specified flags.
+func (c FlagCounter) FlagRefs(flags BitFlags) int {
var total int
- for i, r := range p.refs {
- if reuseFlag(i)&flags == flags {
+ for i, r := range c.refs {
+ if BitFlags(i)&flags == flags {
total += r
}
}
return total
}
-// allRefsHave returns if all references have all specified flags.
-func (p portNode) allRefsHave(flags reuseFlag) bool {
- for i, r := range p.refs {
- if reuseFlag(i)&flags == flags && r > 0 {
+// AllRefsHave reports whether all references have all specified flags.
+func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
+ for i, r := range c.refs {
+ if BitFlags(i)&flags != flags && r > 0 {
return false
}
}
return true
}
-// intersectionRefs returns the set of flags shared by all references.
-func (p portNode) intersectionRefs() reuseFlag {
- intersection := flagMask
- for i, r := range p.refs {
+// IntersectionRefs returns the set of flags shared by all references.
+func (c FlagCounter) IntersectionRefs() BitFlags {
+ intersection := FlagMask
+ for i, r := range c.refs {
if r > 0 {
- intersection &= reuseFlag(i)
+ intersection &= BitFlags(i)
}
}
return intersection
}
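
IntersectionRefs answers the question "which flags do all current holders of this port agree on?" by AND-ing together the indices of the non-empty buckets. A standalone sketch mirroring the logic (local names, not the package's API):

package main

import "fmt"

const flagMask = 0b111

// counter mirrors ports.FlagCounter: refs[i] counts holders whose flag
// combination is exactly i.
type counter struct {
	refs [8]int
}

func (c *counter) addRef(flags uint32) { c.refs[flags]++ }

// intersection ANDs together the flag sets that are actually in use.
func (c counter) intersection() uint32 {
	out := uint32(flagMask)
	for i, r := range c.refs {
		if r > 0 {
			out &= uint32(i)
		}
	}
	return out
}

func main() {
	var c counter
	c.addRef(0b011) // MostRecent|LoadBalanced
	c.addRef(0b010) // LoadBalanced only
	fmt.Printf("%03b\n", c.intersection()) // 010: only LoadBalanced is shared
}
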
+type destination struct {
+ addr tcpip.Address
+ port uint16
+}
+
+func makeDestination(a tcpip.FullAddress) destination {
+ return destination{
+ a.Addr,
+ a.Port,
+ }
+}
+
+// portNode is never empty. When it has no elements, it is removed from the
+// map that references it.
+type portNode map[destination]FlagCounter
+
+// intersectionRefs calculates the intersection of flag bit values which affect
+// the specified destination.
+//
+// If no destinations are present, all flag values are returned as there are no
+// entries to limit possible flag values of a new entry.
+//
+// In addition to the intersection, the number of intersecting refs is
+// returned.
+func (p portNode) intersectionRefs(dst destination) (BitFlags, int) {
+ intersection := FlagMask
+ var count int
+
+ for d, f := range p {
+ if d == dst {
+ intersection &= f.IntersectionRefs()
+ count++
+ continue
+ }
+ // Wildcard destinations affect all destinations for TupleOnly.
+ if d.addr == anyIPAddress || dst.addr == anyIPAddress {
+ // Only bitwise and the TupleOnlyFlag.
+ intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs())
+ count++
+ }
+ }
+
+ return intersection, count
+}
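
The masking in the wildcard branch deserves a close look: for a different destination, every bit except TupleOnly is treated as satisfied, so a wildcard entry can only veto the SO_REUSEADDR-style bit. A worked standalone example of that expression (constants are local stand-ins):

package main

import "fmt"

const (
	mostRecent   = 0b001
	loadBalanced = 0b010
	tupleOnly    = 0b100
	flagMask     = 0b111
)

func main() {
	// An existing wildcard bind whose only flag is tupleOnly. For a
	// different concrete destination, mask away everything but TupleOnly:
	existing := uint32(tupleOnly)
	intersection := uint32(flagMask)
	intersection &= (^uint32(tupleOnly) & flagMask) | existing
	fmt.Printf("%03b\n", intersection) // 111: the wildcard doesn't veto reuse
}

With existing set to 0 instead, the intersection drops to 0b011 and a TupleOnly-only bind finds no shared bit, which is how "bind wildcard, and then tuple with reuseaddr" ends in ErrPortInUse in the tests below.
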
+
// deviceNode is never empty. When it has no elements, it is removed from the
// map that references it.
type deviceNode map[tcpip.NICID]portNode
// isAvailable checks whether binding is possible by device. If not binding to a
-// device, check against all portNodes. If binding to a specific device, check
+// device, check against all FlagCounters. If binding to a specific device, check
// against the unspecified device and the provided device.
//
// If either of the port reuse flags is enabled on any of the nodes, all nodes
// sharing a port must share at least one reuse flag. This matches Linux's
// behavior.
-func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool {
- flagBits := flags.bits()
+func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
+ flagBits := flags.Bits()
if bindToDevice == 0 {
- // Trying to binding all devices.
- if flagBits == 0 {
- // Can't bind because the (addr,port) is already bound.
- return false
- }
- intersection := flagMask
+ intersection := FlagMask
for _, p := range d {
- i := p.intersectionRefs()
+ i, c := p.intersectionRefs(dst)
+ if c == 0 {
+ continue
+ }
intersection &= i
if intersection&flagBits == 0 {
// Can't bind because the (addr,port) was
@@ -165,19 +264,20 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool {
return true
}
- intersection := flagMask
+ intersection := FlagMask
if p, ok := d[0]; ok {
- intersection = p.intersectionRefs()
- if intersection&flagBits == 0 {
+ var c int
+ intersection, c = p.intersectionRefs(dst)
+ if c > 0 && intersection&flagBits == 0 {
return false
}
}
if p, ok := d[bindToDevice]; ok {
- i := p.intersectionRefs()
+ i, c := p.intersectionRefs(dst)
intersection &= i
- if intersection&flagBits == 0 {
+ if c > 0 && intersection&flagBits == 0 {
return false
}
}
@@ -191,12 +291,12 @@ type bindAddresses map[tcpip.Address]deviceNode
// isAvailable checks whether an IP address is available to bind to. If the
// address is the "any" address, check all other addresses. Otherwise, just
// check against the "any" address and the provided address.
-func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool {
+func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
if addr == anyIPAddress {
// If binding to the "any" address then check that there are no conflicts
// with all addresses.
for _, d := range b {
- if !d.isAvailable(flags, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice, dst) {
return false
}
}
@@ -205,14 +305,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice
// Check that there is no conflict with the "any" address.
if d, ok := b[anyIPAddress]; ok {
- if !d.isAvailable(flags, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice, dst) {
return false
}
}
// Check that there is no conflict with the provided address.
if d, ok := b[addr]; ok {
- if !d.isAvailable(flags, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice, dst) {
return false
}
}
@@ -278,17 +378,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui
}
// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
s.mu.Lock()
defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice)
+ return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest))
}
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, flags, bindToDevice) {
+ if !addrs.isAvailable(addr, flags, bindToDevice, dst) {
return false
}
}
@@ -300,14 +400,16 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
+ dst := makeDestination(dest)
+
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) {
+ if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -315,16 +417,17 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil
+ return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) {
return false
}
- flagBits := flags.bits()
+
+ flagBits := flags.Bits()
// Reserve port on all network protocols.
for _, network := range networks {
@@ -339,9 +442,65 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
d = make(deviceNode)
m[addr] = d
}
- n := d[bindToDevice]
- n.refs[flagBits]++
- d[bindToDevice] = n
+ p := d[bindToDevice]
+ if p == nil {
+ p = make(portNode)
+ }
+ n := p[dst]
+ n.AddRef(flagBits)
+ p[dst] = n
+ d[bindToDevice] = p
+ }
+
+ return true
+}
+
+// ReserveTuple adds a port reservation for the tuple on all given protocols.
+func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
+ flagBits := flags.Bits()
+ dst := makeDestination(dest)
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // It is easier to undo the entire reservation, so if we find that the
+ // tuple can't be fully added, finish and undo the whole thing.
+ undo := false
+
+ // Reserve port on all network protocols.
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ m, ok := s.allocatedPorts[desc]
+ if !ok {
+ m = make(bindAddresses)
+ s.allocatedPorts[desc] = m
+ }
+ d, ok := m[addr]
+ if !ok {
+ d = make(deviceNode)
+ m[addr] = d
+ }
+ p := d[bindToDevice]
+ if p == nil {
+ p = make(portNode)
+ }
+
+ n := p[dst]
+ if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 {
+ // Tuple already exists.
+ undo = true
+ }
+ n.AddRef(flagBits)
+ p[dst] = n
+ d[bindToDevice] = p
+ }
+
+ if undo {
+ // releasePortLocked decrements the counts (rather than setting
+ // them to zero), so it will undo the incorrect incrementing
+ // above.
+ s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst)
+ return false
}
return true
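
Rather than pre-validating the tuple against every network protocol, ReserveTuple increments optimistically and relies on a symmetric decrement to roll back a partial reservation. A generic sketch of the pattern (illustrative string keys, not the real portDescriptor):

package main

import "fmt"

func main() {
	counts := map[string]int{"ipv6/tcp/24": 1} // pre-existing reservation
	keys := []string{"ipv4/tcp/24", "ipv6/tcp/24"}

	undo := false
	for _, k := range keys {
		if counts[k] != 0 {
			undo = true // conflict found part-way through
		}
		counts[k]++ // increment anyway so the rollback is uniform
	}
	if undo {
		for _, k := range keys {
			counts[k]-- // symmetric decrement undoes the partial reservation
		}
	}
	fmt.Println(counts) // back to the starting state
}
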
@@ -349,12 +508,14 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) {
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) {
s.mu.Lock()
defer s.mu.Unlock()
- flagBits := flags.bits()
+ s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest))
+}
+func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
@@ -362,21 +523,32 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
if !ok {
continue
}
- n, ok := d[bindToDevice]
+ p, ok := d[bindToDevice]
if !ok {
continue
}
- n.refs[flagBits]--
- d[bindToDevice] = n
- if n.refs == [nextFlag]int{} {
- delete(d, bindToDevice)
+ n, ok := p[dst]
+ if !ok {
+ continue
}
- if len(d) == 0 {
- delete(m, addr)
+ n.DropRef(flags)
+ if n.TotalRefs() > 0 {
+ p[dst] = n
+ continue
}
- if len(m) == 0 {
- delete(s.allocatedPorts, desc)
+ delete(p, dst)
+ if len(p) > 0 {
+ continue
+ }
+ delete(d, bindToDevice)
+ if len(d) > 0 {
+ continue
+ }
+ delete(m, addr)
+ if len(m) > 0 {
+ continue
}
+ delete(s.allocatedPorts, desc)
}
}
}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index d6969d050..58db5868c 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -36,6 +36,7 @@ type portReserveTestAction struct {
flags Flags
release bool
device tcpip.NICID
+ dest tcpip.FullAddress
}
func TestPortReservation(t *testing.T) {
@@ -272,6 +273,54 @@ func TestPortReservation(t *testing.T) {
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
},
+ }, {
+ tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil},
+ },
+ }, {
+ tname: "bind tuple with reuseaddr, and then wildcard",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ },
+ }, {
+ tname: "bind tuple with reuseaddr, and then wildcard",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind two tuples with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil},
+ },
+ }, {
+ tname: "bind two tuples",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil},
+ },
+ }, {
+ tname: "bind wildcard, and then tuple with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind wildcard twice with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil},
+ },
},
} {
t.Run(test.tname, func(t *testing.T) {
@@ -280,19 +329,18 @@ func TestPortReservation(t *testing.T) {
for _, test := range test.actions {
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device)
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device)
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest)
if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want)
+ t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
- t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
+ t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
}
}
})
-
}
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index f71073207..1c58bed2d 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -16,6 +16,18 @@ go_template_instance(
)
go_template_instance(
+ name = "neighbor_entry_list",
+ out = "neighbor_entry_list.go",
+ package = "stack",
+ prefix = "neighborEntry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*neighborEntry",
+ "Linker": "*neighborEntry",
+ },
+)
+
+go_template_instance(
name = "packet_buffer_list",
out = "packet_buffer_list.go",
package = "stack",
@@ -27,6 +39,18 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "tuple_list",
+ out = "tuple_list.go",
+ package = "stack",
+ prefix = "tuple",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*tuple",
+ "Linker": "*tuple",
+ },
+)
+
go_library(
name = "stack",
srcs = [
@@ -35,12 +59,18 @@ go_library(
"forwarder.go",
"icmp_rate_limit.go",
"iptables.go",
+ "iptables_state.go",
"iptables_targets.go",
"iptables_types.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
"ndp.go",
+ "neighbor_cache.go",
+ "neighbor_entry.go",
+ "neighbor_entry_list.go",
+ "neighborstate_string.go",
"nic.go",
+ "nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
"rand.go",
@@ -48,7 +78,9 @@ go_library(
"route.go",
"stack.go",
"stack_global_state.go",
+ "stack_options.go",
"transport_demuxer.go",
+ "tuple_list.go",
],
visibility = ["//visibility:public"],
deps = [
@@ -74,10 +106,12 @@ go_test(
size = "medium",
srcs = [
"ndp_test.go",
+ "nud_test.go",
"stack_test.go",
"transport_demuxer_test.go",
"transport_test.go",
],
+ shard_count = 20,
deps = [
":stack",
"//pkg/rand",
@@ -89,10 +123,12 @@ go_test(
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
@@ -100,8 +136,11 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
+ "fake_time_test.go",
"forwarder_test.go",
"linkaddrcache_test.go",
+ "neighbor_cache_test.go",
+ "neighbor_entry_test.go",
"nic_test.go",
],
library = ":stack",
@@ -110,5 +149,9 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "@com_github_dpjacques_clockwork//:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 7d1ede1f2..470c265aa 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -20,332 +20,330 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
)
// Connection tracking is used to track and manipulate packets for NAT rules.
-// The connection is created for a packet if it does not exist. Every connection
-// contains two tuples (original and reply). The tuples are manipulated if there
-// is a matching NAT rule. The packet is modified by looking at the tuples in the
-// Prerouting and Output hooks.
+// The connection is created for a packet if it does not exist. Every
+// connection contains two tuples (original and reply). The tuples are
+// manipulated if there is a matching NAT rule. The packet is modified by
+// looking at the tuples in the Prerouting and Output hooks.
+//
+// Currently, only TCP tracking is supported.
+
+// Our hash table has 16K buckets.
+// TODO(gvisor.dev/issue/170): These should be tunable.
+const numBuckets = 1 << 14
// Direction of the tuple.
-type ctDirection int
+type direction int
const (
- dirOriginal ctDirection = iota
+ dirOriginal direction = iota
dirReply
)
-// Status of connection.
-// TODO(gvisor.dev/issue/170): Add other states of connection.
-type connStatus int
-
-const (
- connNew connStatus = iota
- connEstablished
-)
-
// Manipulation type for the connection.
type manipType int
const (
- manipDstPrerouting manipType = iota
+ manipNone manipType = iota
+ manipDstPrerouting
manipDstOutput
)
-// connTrackMutable is the manipulatable part of the tuple.
-type connTrackMutable struct {
- // addr is source address of the tuple.
- addr tcpip.Address
-
- // port is source port of the tuple.
- port uint16
-
- // protocol is network layer protocol.
- protocol tcpip.NetworkProtocolNumber
-}
-
-// connTrackImmutable is the non-manipulatable part of the tuple.
-type connTrackImmutable struct {
- // addr is destination address of the tuple.
- addr tcpip.Address
+// tuple holds a connection's identifying and manipulating data in one
+// direction. It is immutable.
+//
+// +stateify savable
+type tuple struct {
+ // tupleEntry is used to build an intrusive list of tuples.
+ tupleEntry
- // direction is direction (original or reply) of the tuple.
- direction ctDirection
+ tupleID
- // port is destination port of the tuple.
- port uint16
+ // conn is the connection tracking entry this tuple belongs to.
+ conn *conn
- // protocol is transport layer protocol.
- protocol tcpip.TransportProtocolNumber
+ // direction is the direction of the tuple.
+ direction direction
}
-// connTrackTuple represents the tuple which is created from the
-// packet.
-type connTrackTuple struct {
- // dst is non-manipulatable part of the tuple.
- dst connTrackImmutable
-
- // src is manipulatable part of the tuple.
- src connTrackMutable
+// tupleID uniquely identifies a connection in one direction. It currently
+// contains enough information to distinguish between any TCP or UDP
+// connection, and will need to be extended to support other protocols.
+//
+// +stateify savable
+type tupleID struct {
+ srcAddr tcpip.Address
+ srcPort uint16
+ dstAddr tcpip.Address
+ dstPort uint16
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
}
-// connTrackTupleHolder is the container of tuple and connection.
-type ConnTrackTupleHolder struct {
- // conn is pointer to the connection tracking entry.
- conn *connTrack
-
- // tuple is original or reply tuple.
- tuple connTrackTuple
+// reply creates the reply tupleID.
+func (ti tupleID) reply() tupleID {
+ return tupleID{
+ srcAddr: ti.dstAddr,
+ srcPort: ti.dstPort,
+ dstAddr: ti.srcAddr,
+ dstPort: ti.srcPort,
+ transProto: ti.transProto,
+ netProto: ti.netProto,
+ }
}
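
Since reply only swaps the source and destination pairs, applying it twice returns the original ID. A quick standalone check of that property (plain strings stand in for tcpip.Address, and the protocol fields are omitted):

package main

import "fmt"

type tupleID struct {
	srcAddr string
	srcPort uint16
	dstAddr string
	dstPort uint16
}

func (ti tupleID) reply() tupleID {
	return tupleID{
		srcAddr: ti.dstAddr,
		srcPort: ti.dstPort,
		dstAddr: ti.srcAddr,
		dstPort: ti.srcPort,
	}
}

func main() {
	tid := tupleID{srcAddr: "10.0.0.1", srcPort: 1234, dstAddr: "10.0.0.2", dstPort: 80}
	fmt.Println(tid.reply())                // {10.0.0.2 80 10.0.0.1 1234}
	fmt.Println(tid.reply().reply() == tid) // true: reply is an involution
}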
-// connTrack is the connection.
-type connTrack struct {
- // originalTupleHolder contains tuple in original direction.
- originalTupleHolder ConnTrackTupleHolder
-
- // replyTupleHolder contains tuple in reply direction.
- replyTupleHolder ConnTrackTupleHolder
-
- // status indicates connection is new or established.
- status connStatus
+// conn is a tracked connection.
+//
+// +stateify savable
+type conn struct {
+ // original is the tuple in original direction. It is immutable.
+ original tuple
- // timeout indicates the time connection should be active.
- timeout time.Duration
+ // reply is the tuple in reply direction. It is immutable.
+ reply tuple
- // manip indicates if the packet should be manipulated.
+ // manip indicates if the packet should be manipulated. It is immutable.
manip manipType
- // tcb is TCB control block. It is used to keep track of states
- // of tcp connection.
- tcb tcpconntrack.TCB
-
// tcbHook indicates if the packet is inbound or outbound to
- // update the state of tcb.
+ // update the state of tcb. It is immutable.
tcbHook Hook
-}
-// ConnTrackTable contains a map of all existing connections created for
-// NAT rules.
-type ConnTrackTable struct {
- // connMu protects connTrackTable.
- connMu sync.RWMutex
+ // mu protects all mutable state.
+ mu sync.Mutex `state:"nosave"`
+ // tcb is TCB control block. It is used to keep track of states
+ // of tcp connection and is protected by mu.
+ tcb tcpconntrack.TCB
+ // lastUsed is the last time the connection saw a relevant packet, and
+ // is updated by each packet on the connection. It is protected by mu.
+ lastUsed time.Time `state:".(unixTime)"`
+}
- // connTrackTable maintains a map of tuples needed for connection tracking
- // for iptables NAT rules. The key for the map is an integer calculated
- // using seed, source address, destination address, source port and
- // destination port.
- CtMap map[uint32]ConnTrackTupleHolder
+// timedOut returns whether the connection timed out based on its state.
+func (cn *conn) timedOut(now time.Time) bool {
+ const establishedTimeout = 5 * 24 * time.Hour
+ const defaultTimeout = 120 * time.Second
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+ if cn.tcb.State() == tcpconntrack.ResultAlive {
+ // Use the same default as Linux, which doesn't delete
+ // established connections for 5(!) days.
+ return now.Sub(cn.lastUsed) > establishedTimeout
+ }
+ // Use the same default as Linux, which lets connections in most states
+ // other than established remain for <= 120 seconds.
+ return now.Sub(cn.lastUsed) > defaultTimeout
+}
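
Both constants mirror Linux's conntrack defaults. The decision reduces to a pure function, sketched standalone below with a boolean standing in for the TCB state check:

package main

import (
	"fmt"
	"time"
)

const (
	establishedTimeout = 5 * 24 * time.Hour // Linux default for ESTABLISHED
	defaultTimeout     = 120 * time.Second  // Linux default for most other states
)

// timedOut mirrors conn.timedOut: the allowed idle time depends on whether
// the TCP state machine still considers the connection alive.
func timedOut(established bool, lastUsed, now time.Time) bool {
	if established {
		return now.Sub(lastUsed) > establishedTimeout
	}
	return now.Sub(lastUsed) > defaultTimeout
}

func main() {
	now := time.Now()
	fmt.Println(timedOut(false, now.Add(-3*time.Minute), now)) // true: past 120s
	fmt.Println(timedOut(true, now.Add(-3*time.Minute), now))  // false: well under 5 days
}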
- // seed is a one-time random value initialized at stack startup
- // and is used in calculation of hash key for connection tracking
- // table.
- Seed uint32
+// updateLocked updates the connection tracking state.
+//
+// Precondition: ct.mu must be held.
+func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+ // Update the state of tcb. tcb assumes it's always initialized on the
+ // client. However, we only need to know whether the connection is
+ // established or not, so the client/server distinction isn't important.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other tcp states.
+ if ct.tcb.IsEmpty() {
+ ct.tcb.Init(tcpHeader)
+ } else if hook == ct.tcbHook {
+ ct.tcb.UpdateStateOutbound(tcpHeader)
+ } else {
+ ct.tcb.UpdateStateInbound(tcpHeader)
+ }
}
-// parseHeaders sets headers in the packet.
-func parseHeaders(pkt *PacketBuffer) {
- newPkt := pkt.Clone()
+// ConnTrack tracks all connections created for NAT rules. Most users are
+// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
+//
+// ConnTrack keeps all connections in a slice of buckets, each of which holds a
+// linked list of tuples. This gives us some desirable properties:
+// - Each bucket has its own lock, lessening lock contention.
+// - The slice is large enough that lists stay short (<10 elements on average).
+// Thus traversal is fast.
+// - During linked list traversal we reap expired connections. This amortizes
+// the cost of reaping them and makes reapUnused faster.
+//
+// Locks are ordered by their location in the buckets slice. That is, a
+// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
+//
+// +stateify savable
+type ConnTrack struct {
+ // seed is a one-time random value initialized at stack startup
+ // and is used in the calculation of hash keys for the list of buckets.
+ // It is immutable.
+ seed uint32
- // Set network header.
- hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- return
- }
- netHeader := header.IPv4(hdr)
- newPkt.NetworkHeader = hdr
- length := int(netHeader.HeaderLength())
+ // mu protects the buckets slice, but not buckets' contents. Only take
+ // the write lock if you are modifying the slice or saving for S/R.
+ mu sync.RWMutex `state:"nosave"`
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- // Set transport header.
- switch protocol := netHeader.TransportProtocol(); protocol {
- case header.UDPProtocolNumber:
- if newPkt.TransportHeader == nil {
- h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize)
- if !ok {
- return
- }
- newPkt.TransportHeader = buffer.View(header.UDP(h[length:]))
- }
- case header.TCPProtocolNumber:
- if newPkt.TransportHeader == nil {
- h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize)
- if !ok {
- return
- }
- newPkt.TransportHeader = buffer.View(header.TCP(h[length:]))
- }
- }
- pkt.NetworkHeader = newPkt.NetworkHeader
- pkt.TransportHeader = newPkt.TransportHeader
+ // buckets is protected by mu.
+ buckets []bucket
}
-// packetToTuple converts packet to a tuple in original direction.
-func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) {
- var tuple connTrackTuple
+// +stateify savable
+type bucket struct {
+ // mu protects tuples.
+ mu sync.Mutex `state:"nosave"`
+ tuples tupleList
+}
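
The lock-ordering rule stated above (ascending bucket index, lock once when the indices coincide) is the discipline insertConn and reapTupleLocked follow below. A standalone sketch of just that discipline:

package main

import "sync"

type bucket struct {
	mu sync.Mutex
	// tuples elided
}

// lockPair locks buckets i and j in ascending index order, the only order
// permitted by the lock-ordering rule, and returns an unlock func. When
// both indices refer to the same bucket, it locks only once.
func lockPair(buckets []bucket, i, j int) (unlock func()) {
	switch {
	case i < j:
		buckets[i].mu.Lock()
		buckets[j].mu.Lock()
	case j < i:
		buckets[j].mu.Lock()
		buckets[i].mu.Lock()
	default: // same bucket: lock once
		buckets[i].mu.Lock()
		return func() { buckets[i].mu.Unlock() }
	}
	return func() {
		// Unlocking can happen in any order.
		buckets[i].mu.Unlock()
		buckets[j].mu.Unlock()
	}
}

func main() {
	buckets := make([]bucket, 4)
	unlock := lockPair(buckets, 3, 1) // locks bucket 1, then bucket 3
	unlock()
}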
- netHeader := header.IPv4(pkt.NetworkHeader)
+// packetToTupleID converts a packet into a tuple ID. It fails when pkt lacks
+// a valid TCP header.
+func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
// TODO(gvisor.dev/issue/170): Need to support for other
// protocols as well.
+ netHeader := header.IPv4(pkt.NetworkHeader)
if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tuple, tcpip.ErrUnknownProtocol
+ return tupleID{}, tcpip.ErrUnknownProtocol
}
tcpHeader := header.TCP(pkt.TransportHeader)
if tcpHeader == nil {
- return tuple, tcpip.ErrUnknownProtocol
+ return tupleID{}, tcpip.ErrUnknownProtocol
}
- tuple.src.addr = netHeader.SourceAddress()
- tuple.src.port = tcpHeader.SourcePort()
- tuple.src.protocol = header.IPv4ProtocolNumber
-
- tuple.dst.addr = netHeader.DestinationAddress()
- tuple.dst.port = tcpHeader.DestinationPort()
- tuple.dst.protocol = netHeader.TransportProtocol()
-
- return tuple, nil
-}
-
-// getReplyTuple creates reply tuple for the given tuple.
-func getReplyTuple(tuple connTrackTuple) connTrackTuple {
- var replyTuple connTrackTuple
- replyTuple.src.addr = tuple.dst.addr
- replyTuple.src.port = tuple.dst.port
- replyTuple.src.protocol = tuple.src.protocol
- replyTuple.dst.addr = tuple.src.addr
- replyTuple.dst.port = tuple.src.port
- replyTuple.dst.protocol = tuple.dst.protocol
- replyTuple.dst.direction = dirReply
-
- return replyTuple
+ return tupleID{
+ srcAddr: netHeader.SourceAddress(),
+ srcPort: tcpHeader.SourcePort(),
+ dstAddr: netHeader.DestinationAddress(),
+ dstPort: tcpHeader.DestinationPort(),
+ transProto: netHeader.TransportProtocol(),
+ netProto: header.IPv4ProtocolNumber,
+ }, nil
}
-// makeNewConn creates new connection.
-func makeNewConn(tuple, replyTuple connTrackTuple) connTrack {
- var conn connTrack
- conn.status = connNew
- conn.originalTupleHolder.tuple = tuple
- conn.originalTupleHolder.conn = &conn
- conn.replyTupleHolder.tuple = replyTuple
- conn.replyTupleHolder.conn = &conn
-
- return conn
+// newConn creates a new connection.
+func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+ conn := conn{
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
+ }
+ conn.original = tuple{conn: &conn, tupleID: orig}
+ conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
+ return &conn
}
-// getTupleHash returns hash of the tuple. The fields used for
-// generating hash are seed (generated once for stack), source address,
-// destination address, source port and destination ports.
-func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 {
- h := jenkins.Sum32(ct.Seed)
- h.Write([]byte(tuple.src.addr))
- h.Write([]byte(tuple.dst.addr))
- portBuf := make([]byte, 2)
- binary.LittleEndian.PutUint16(portBuf, tuple.src.port)
- h.Write([]byte(portBuf))
- binary.LittleEndian.PutUint16(portBuf, tuple.dst.port)
- h.Write([]byte(portBuf))
-
- return h.Sum32()
+// connFor gets the conn for pkt if it exists, or returns nil if it does not.
+// It also returns nil when pkt does not contain a valid TCP header.
+// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support
+// other transport protocols.
+func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil, dirOriginal
+ }
+ return ct.connForTID(tid)
}
-// connTrackForPacket returns connTrack for packet.
-// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other
-// transport protocols.
-func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) {
- if hook == Prerouting {
- // Headers will not be set in Prerouting.
- // TODO(gvisor.dev/issue/170): Change this after parsing headers
- // code is added.
- parseHeaders(pkt)
+func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
+ bucket := ct.bucket(tid)
+ now := time.Now()
+
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ ct.buckets[bucket].mu.Lock()
+ defer ct.buckets[bucket].mu.Unlock()
+
+ // Iterate over the tuples in a bucket, cleaning up any unused
+ // connections we find.
+ for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
+ // Clean up any timed-out connections we happen to find.
+ if ct.reapTupleLocked(other, bucket, now) {
+ // The tuple expired.
+ continue
+ }
+ if tid == other.tupleID {
+ return other.conn, other.direction
+ }
}
- var dir ctDirection
- tuple, err := packetToTuple(*pkt, hook)
- if err != nil {
- return nil, dir
- }
-
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
-
- connTrackTable := ct.CtMap
- hash := ct.getTupleHash(tuple)
-
- var conn *connTrack
- switch createConn {
- case true:
- // If connection does not exist for the hash, create a new
- // connection.
- replyTuple := getReplyTuple(tuple)
- replyHash := ct.getTupleHash(replyTuple)
- newConn := makeNewConn(tuple, replyTuple)
- conn = &newConn
-
- // Add tupleHolders to the map.
- // TODO(gvisor.dev/issue/170): Need to support collisions using linked list.
- ct.CtMap[hash] = conn.originalTupleHolder
- ct.CtMap[replyHash] = conn.replyTupleHolder
- default:
- tupleHolder, ok := connTrackTable[hash]
- if !ok {
- return nil, dir
- }
+ return nil, dirOriginal
+}
- // If this is the reply of new connection, set the connection
- // status as ESTABLISHED.
- conn = tupleHolder.conn
- if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply {
- conn.status = connEstablished
- }
- if tupleHolder.conn == nil {
- panic("tupleHolder has null connection tracking entry")
- }
+// insertRedirectConn creates a new connection for pkt, with its reply
+// direction rewritten per the redirect target, inserts it into the table,
+// and returns it. It returns nil if pkt can't be tracked or hook is neither
+// Prerouting nor Output.
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil
+ }
+ if hook != Prerouting && hook != Output {
+ return nil
+ }
- dir = tupleHolder.tuple.dst.direction
+ // Create a new connection and change the port as per the iptables
+ // rule. This tuple will be used to manipulate the packet in
+ // handlePacket.
+ replyTID := tid.reply()
+ replyTID.srcAddr = rt.MinIP
+ replyTID.srcPort = rt.MinPort
+ var manip manipType
+ switch hook {
+ case Prerouting:
+ manip = manipDstPrerouting
+ case Output:
+ manip = manipDstOutput
}
- return conn, dir
+ conn := newConn(tid, replyTID, manip, hook)
+ ct.insertConn(conn)
+ return conn
}
-// SetNatInfo will manipulate the tuples according to iptables NAT rules.
-func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) {
- // Get the connection. Connection is always created before this
- // function is called.
- conn, _ := ct.connTrackForPacket(pkt, hook, false)
- if conn == nil {
- panic("connection should be created to manipulate tuples.")
+// insertConn inserts conn into the appropriate table bucket.
+func (ct *ConnTrack) insertConn(conn *conn) {
+ // Lock the buckets in the correct order.
+ tupleBucket := ct.bucket(conn.original.tupleID)
+ replyBucket := ct.bucket(conn.reply.tupleID)
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ if tupleBucket < replyBucket {
+ ct.buckets[tupleBucket].mu.Lock()
+ ct.buckets[replyBucket].mu.Lock()
+ } else if tupleBucket > replyBucket {
+ ct.buckets[replyBucket].mu.Lock()
+ ct.buckets[tupleBucket].mu.Lock()
+ } else {
+ // Both tuples are in the same bucket.
+ ct.buckets[tupleBucket].mu.Lock()
}
- replyTuple := conn.replyTupleHolder.tuple
- replyHash := ct.getTupleHash(replyTuple)
- // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to
- // support changing of address for Prerouting.
-
- // Change the port as per the iptables rule. This tuple will be used
- // to manipulate the packet in HandlePacket.
- conn.replyTupleHolder.tuple.src.addr = rt.MinIP
- conn.replyTupleHolder.tuple.src.port = rt.MinPort
- newHash := ct.getTupleHash(conn.replyTupleHolder.tuple)
+ // Now that we hold the locks, ensure the tuple hasn't been inserted by
+ // another goroutine.
+ alreadyInserted := false
+ for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
+ if other.tupleID == conn.original.tupleID {
+ alreadyInserted = true
+ break
+ }
+ }
- // Add the changed tuple to the map.
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
- ct.CtMap[newHash] = conn.replyTupleHolder
- if hook == Output {
- conn.replyTupleHolder.conn.manip = manipDstOutput
+ if !alreadyInserted {
+ // Add both tuples to their buckets.
+ ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
+ ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
}
- // Delete the old tuple.
- delete(ct.CtMap, replyHash)
+ // Unlocking can happen in any order.
+ ct.buckets[tupleBucket].mu.Unlock()
+ if tupleBucket != replyBucket {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
-// TODO(gvisor.dev/issue/170): Change address for Prerouting hook..
-func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) {
+// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
+func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -354,21 +352,31 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection)
// modified.
switch dir {
case dirOriginal:
- port := conn.replyTupleHolder.tuple.src.port
+ port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
case dirReply:
- port := conn.originalTupleHolder.tuple.dst.port
+ port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
}
+ // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
+ // on inbound packets, so we don't recalculate them. However, we should
+ // support cases when they are validated, e.g. when we can't offload
+ // receive checksumming.
+
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
// handlePacketOutput manipulates ports for packets in Output hook.
-func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) {
+func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -377,13 +385,13 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route,
// modified. For prerouting redirection, we only reach this point
// when replying, so packet sources are modified.
if conn.manip == manipDstOutput && dir == dirOriginal {
- port := conn.replyTupleHolder.tuple.src.port
+ port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
} else {
- port := conn.originalTupleHolder.tuple.dst.port
+ port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
}
// Calculate the TCP checksum and set it.
@@ -402,33 +410,32 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route,
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
-// HandlePacket will manipulate the port and address of the packet if the
-// connection exists.
-func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+// handlePacket manipulates the port and address of the packet if a tracked
+// connection exists. It returns whether a new conntrack entry should be
+// created for the packet after it traverses the tables.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
if pkt.NatDone {
- return
+ return false
}
if hook != Prerouting && hook != Output {
- return
+ return false
}
- conn, dir := ct.connTrackForPacket(pkt, hook, false)
- // Connection or Rule not found for the packet.
- if conn == nil {
- return
+ // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return false
}
- netHeader := header.IPv4(pkt.NetworkHeader)
- // TODO(gvisor.dev/issue/170): Need to support for other transport
- // protocols as well.
- if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return
+ conn, dir := ct.connFor(pkt)
+ // No connection or rule found for the packet.
+ if conn == nil {
+ return true
}
tcpHeader := header.TCP(pkt.TransportHeader)
if tcpHeader == nil {
- return
+ return false
}
switch hook {
@@ -442,39 +449,184 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r
// Update the state of tcb.
// TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
// other tcp states.
- var st tcpconntrack.Result
- if conn.tcb.IsEmpty() {
- conn.tcb.Init(tcpHeader)
- conn.tcbHook = hook
- } else {
- switch hook {
- case conn.tcbHook:
- st = conn.tcb.UpdateStateOutbound(tcpHeader)
- default:
- st = conn.tcb.UpdateStateInbound(tcpHeader)
- }
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ // Mark the connection as having been used recently so it isn't reaped.
+ conn.lastUsed = time.Now()
+ // Update connection state.
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+
+ return false
+}
+
+// maybeInsertNoop tries to insert a no-op connection entry to keep connections
+// from getting clobbered when replies arrive. It only inserts if there isn't
+// already a connection for pkt.
+//
+// This should only be called after traversing iptables rules, to ensure that
+// pkt.NatDone is set correctly.
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+ // If there were a rule applying to this packet, it would be marked
+ // with NatDone.
+ if pkt.NatDone {
+ return
}
- // Delete conntrack if tcp connection is closed.
- if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset {
- ct.deleteConnTrack(conn)
+ // We only track TCP connections.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return
}
-}
-// deleteConnTrack deletes the connection.
-func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) {
- if conn == nil {
+ // This is the first packet we're seeing for the TCP connection. Insert
+ // the noop entry (an identity mapping) so that the response doesn't
+ // get NATed, breaking the connection.
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
return
}
+ conn := newConn(tid, tid.reply(), manipNone, hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+ ct.insertConn(conn)
+}
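
handlePacket's boolean result is what drives maybeInsertNoop. A hypothetical caller fragment, assuming only names from this file (the real call sites live in the iptables code and are not shown here):

// natAndTrack is a sketch, not the actual call site: handlePacket reports
// whether the packet still needs a conntrack entry, and maybeInsertNoop may
// only run after rule traversal, once pkt.NatDone is final.
func natAndTrack(ct *ConnTrack, pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
	shouldTrack := ct.handlePacket(pkt, hook, gso, r)

	// ... traverse iptables rules here; a matching NAT rule sets pkt.NatDone ...

	if shouldTrack {
		ct.maybeInsertNoop(pkt, hook)
	}
}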
+
+// bucket gets the conntrack bucket for a tupleID.
+func (ct *ConnTrack) bucket(id tupleID) int {
+ h := jenkins.Sum32(ct.seed)
+ h.Write([]byte(id.srcAddr))
+ h.Write([]byte(id.dstAddr))
+ shortBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(shortBuf, id.srcPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, id.dstPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto))
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto))
+ h.Write([]byte(shortBuf))
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ return int(h.Sum32()) % len(ct.buckets)
+}
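
bucket's only job is to map a tuple's identifying fields onto a stable index. A self-contained version of the same idea, substituting the standard library's hash/fnv for the jenkins package used here:

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// bucketFor hashes the identifying fields of a connection into one of n
// buckets. fnv stands in for the jenkins hash; the seed makes bucket
// placement unpredictable across restarts.
func bucketFor(seed uint32, srcAddr, dstAddr string, srcPort, dstPort uint16, n int) int {
	h := fnv.New32a()
	var buf [4]byte
	binary.LittleEndian.PutUint32(buf[:], seed)
	h.Write(buf[:4])
	h.Write([]byte(srcAddr))
	h.Write([]byte(dstAddr))
	binary.LittleEndian.PutUint16(buf[:2], srcPort)
	h.Write(buf[:2])
	binary.LittleEndian.PutUint16(buf[:2], dstPort)
	h.Write(buf[:2])
	return int(h.Sum32()) % n
}

func main() {
	// The same tuple always lands in the same bucket.
	fmt.Println(bucketFor(42, "10.0.0.1", "10.0.0.2", 1234, 80, 1<<14))
	fmt.Println(bucketFor(42, "10.0.0.1", "10.0.0.2", 1234, 80, 1<<14))
}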
+
+// reapUnused deletes timed out entries from the conntrack map. The rules for
+// reaping are:
+// - Most reaping occurs in connFor, which is called on each packet. connFor
+// cleans up the bucket the packet's connection maps to. Thus calls to
+// reapUnused should be fast.
+// - Each call to reapUnused traverses a fraction of the conntrack table.
+// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
+// - After reaping, reapUnused decides when it should next run based on the
+// ratio of expired connections to examined connections. If the ratio is
+// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it
+// slightly increases the interval between runs.
+// - maxFullTraversal caps the time it takes to traverse the entire table.
+//
+// reapUnused returns the next bucket that should be checked and the time after
+// which it should be called again.
+func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
+ // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
+ // it is in Linux via sysctl.
+ const fractionPerReaping = 128
+ const maxExpiredPct = 50
+ const maxFullTraversal = 60 * time.Second
+ const minInterval = 10 * time.Millisecond
+ const maxInterval = maxFullTraversal / fractionPerReaping
+
+ now := time.Now()
+ checked := 0
+ expired := 0
+ var idx int
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
+ idx = (i + start) % len(ct.buckets)
+ ct.buckets[idx].mu.Lock()
+ for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ checked++
+ if ct.reapTupleLocked(tuple, idx, now) {
+ expired++
+ }
+ }
+ ct.buckets[idx].mu.Unlock()
+ }
+ // We already checked buckets[idx].
+ idx++
+
+ // If half or more of the connections are expired, the table has gotten
+ // stale. Reschedule quickly.
+ expiredPct := 0
+ if checked != 0 {
+ expiredPct = expired * 100 / checked
+ }
+ if expiredPct > maxExpiredPct {
+ return idx, minInterval
+ }
+ if interval := prevInterval + minInterval; interval <= maxInterval {
+ // Increase the interval between runs.
+ return idx, interval
+ }
+ // We've hit the maximum interval.
+ return idx, maxInterval
+}
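
The rescheduling policy reduces to a small pure function of the reap statistics. Extracted as a standalone sketch with the same constants:

package main

import (
	"fmt"
	"time"
)

const (
	fractionPerReaping = 128
	maxExpiredPct      = 50
	maxFullTraversal   = 60 * time.Second
	minInterval        = 10 * time.Millisecond
	maxInterval        = maxFullTraversal / fractionPerReaping
)

// nextInterval mirrors reapUnused's scheduling decision: reschedule quickly
// when the table is stale, otherwise back off gently toward maxInterval.
func nextInterval(checked, expired int, prev time.Duration) time.Duration {
	expiredPct := 0
	if checked != 0 {
		expiredPct = expired * 100 / checked
	}
	if expiredPct > maxExpiredPct {
		return minInterval
	}
	if next := prev + minInterval; next <= maxInterval {
		return next
	}
	return maxInterval
}

func main() {
	fmt.Println(nextInterval(100, 60, 50*time.Millisecond)) // stale table: 10ms
	fmt.Println(nextInterval(100, 10, 50*time.Millisecond)) // back off: 60ms
}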
- tuple := conn.originalTupleHolder.tuple
- hash := ct.getTupleHash(tuple)
- replyTuple := conn.replyTupleHolder.tuple
- replyHash := ct.getTupleHash(replyTuple)
+// reapTupleLocked tries to remove tuple and its reply from the table. It
+// returns whether the tuple's connection has timed out.
+//
+// Preconditions: ct.mu is locked for reading and bucket is locked.
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+ if !tuple.conn.timedOut(now) {
+ return false
+ }
+
+ // To maintain lock order, we can only reap these tuples if the reply
+ // appears later in the table.
+ replyBucket := ct.bucket(tuple.reply())
+ if bucket > replyBucket {
+ return true
+ }
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
+ // Don't re-lock if both tuples are in the same bucket.
+ differentBuckets := bucket != replyBucket
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Lock()
+ }
+
+ // We have the buckets locked and can remove both tuples.
+ if tuple.direction == dirOriginal {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ } else {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
+ }
+ ct.buckets[bucket].tuples.Remove(tuple)
+
+ // Don't re-unlock if both tuples are in the same bucket.
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
+
+ return true
+}
+
+// originalDst returns the original (pre-NAT) destination of a tracked
+// connection whose destination was manipulated by a redirect rule.
+func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+	// Look up the connection. The original direction's destination is the
+	// pre-NAT address we want.
+ tid := tupleID{
+ srcAddr: epID.LocalAddress,
+ srcPort: epID.LocalPort,
+ dstAddr: epID.RemoteAddress,
+ dstPort: epID.RemotePort,
+ transProto: header.TCPProtocolNumber,
+ netProto: header.IPv4ProtocolNumber,
+ }
+ conn, _ := ct.connForTID(tid)
+ if conn == nil {
+ // Not a tracked connection.
+ return "", 0, tcpip.ErrNotConnected
+ } else if conn.manip == manipNone {
+ // Unmanipulated connection.
+ return "", 0, tcpip.ErrInvalidOptionValue
+ }
- delete(ct.CtMap, hash)
- delete(ct.CtMap, replyHash)
+ return conn.original.dstAddr, conn.original.dstPort, nil
}
diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/stack/fake_time_test.go
new file mode 100644
index 000000000..92c8cb534
--- /dev/null
+++ b/pkg/tcpip/stack/fake_time_test.go
@@ -0,0 +1,209 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "container/heap"
+ "sync"
+ "time"
+
+ "github.com/dpjacques/clockwork"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+type fakeClock struct {
+ clock clockwork.FakeClock
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ // times is a min-heap of times. A heap allows quick retrieval of the next
+ // upcoming time of scheduled work.
+ times *timeHeap
+
+ // waitGroups stores one WaitGroup for all work scheduled to execute at the
+ // same time via AfterFunc. This allows parallel execution of all functions
+ // passed to AfterFunc scheduled for the same time.
+ waitGroups map[time.Time]*sync.WaitGroup
+}
+
+func newFakeClock() *fakeClock {
+ return &fakeClock{
+ clock: clockwork.NewFakeClock(),
+ times: &timeHeap{},
+ waitGroups: make(map[time.Time]*sync.WaitGroup),
+ }
+}
+
+var _ tcpip.Clock = (*fakeClock)(nil)
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (fc *fakeClock) NowNanoseconds() int64 {
+ return fc.clock.Now().UnixNano()
+}
+
+// NowMonotonic implements tcpip.Clock.NowMonotonic.
+func (fc *fakeClock) NowMonotonic() int64 {
+ return fc.NowNanoseconds()
+}
+
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ until := fc.clock.Now().Add(d)
+ wg := fc.addWait(until)
+ return &fakeTimer{
+ clock: fc,
+ until: until,
+ timer: fc.clock.AfterFunc(d, func() {
+ defer wg.Done()
+ f()
+ }),
+ }
+}
+
+// addWait adds an additional wait to the WaitGroup for parallel execution of
+// all work scheduled for t. It returns a reference to the modified WaitGroup.
+func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup {
+ fc.mu.RLock()
+ wg, ok := fc.waitGroups[t]
+ fc.mu.RUnlock()
+
+ if ok {
+ wg.Add(1)
+ return wg
+ }
+
+ fc.mu.Lock()
+ heap.Push(fc.times, t)
+ fc.mu.Unlock()
+
+ wg = &sync.WaitGroup{}
+ wg.Add(1)
+
+ fc.mu.Lock()
+ fc.waitGroups[t] = wg
+ fc.mu.Unlock()
+
+ return wg
+}
+
+// removeWait removes a wait from the WaitGroup for parallel execution of all
+// work scheduled for t.
+func (fc *fakeClock) removeWait(t time.Time) {
+ fc.mu.RLock()
+ defer fc.mu.RUnlock()
+
+ wg := fc.waitGroups[t]
+ wg.Done()
+}
+
+// advance executes all work that has been scheduled to execute within d of
+// the current fake time. It blocks until all work has completed execution.
+func (fc *fakeClock) advance(d time.Duration) {
+ // Block until all the work is done
+ until := fc.clock.Now().Add(d)
+ for {
+ fc.mu.Lock()
+ if fc.times.Len() == 0 {
+ fc.mu.Unlock()
+ return
+ }
+
+ t := heap.Pop(fc.times).(time.Time)
+ if t.After(until) {
+ // No work due within d; put the time back.
+ heap.Push(fc.times, t)
+ fc.mu.Unlock()
+ return
+ }
+ fc.mu.Unlock()
+
+ diff := t.Sub(fc.clock.Now())
+ fc.clock.Advance(diff)
+
+ fc.mu.RLock()
+ wg := fc.waitGroups[t]
+ fc.mu.RUnlock()
+
+ wg.Wait()
+
+ fc.mu.Lock()
+ delete(fc.waitGroups, t)
+ fc.mu.Unlock()
+ }
+}
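
A sketch of how a test in this package might drive the fake clock, assuming clockwork fires timers when the clock is advanced. Note that advance only moves the clock up to scheduled work times, so the deadline must fall within a single call:

// TestFakeClockSketch is an illustrative test, not part of the suite.
func TestFakeClockSketch(t *testing.T) {
	fc := newFakeClock()
	fired := make(chan struct{})
	fc.AfterFunc(5*time.Second, func() { close(fired) })

	fc.advance(4 * time.Second) // nothing is due within 4s, so f must not run
	select {
	case <-fired:
		t.Fatal("timer fired early")
	default:
	}

	fc.advance(6 * time.Second) // crosses the 5s deadline; advance waits for f
	<-fired                     // does not block: advance already waited
}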
+
+type fakeTimer struct {
+ clock *fakeClock
+ timer clockwork.Timer
+
+ mu sync.RWMutex
+ until time.Time
+}
+
+var _ tcpip.Timer = (*fakeTimer)(nil)
+
+// Reset implements tcpip.Timer.Reset.
+func (ft *fakeTimer) Reset(d time.Duration) {
+ if !ft.timer.Reset(d) {
+ return
+ }
+
+ ft.mu.Lock()
+ defer ft.mu.Unlock()
+
+ ft.clock.removeWait(ft.until)
+ ft.until = ft.clock.clock.Now().Add(d)
+ ft.clock.addWait(ft.until)
+}
+
+// Stop implements tcpip.Timer.Stop.
+func (ft *fakeTimer) Stop() bool {
+ if !ft.timer.Stop() {
+ return false
+ }
+
+ ft.mu.RLock()
+ defer ft.mu.RUnlock()
+
+ ft.clock.removeWait(ft.until)
+ return true
+}
+
+type timeHeap []time.Time
+
+var _ heap.Interface = (*timeHeap)(nil)
+
+func (h timeHeap) Len() int {
+ return len(h)
+}
+
+func (h timeHeap) Less(i, j int) bool {
+ return h[i].Before(h[j])
+}
+
+func (h timeHeap) Swap(i, j int) {
+ h[i], h[j] = h[j], h[i]
+}
+
+func (h *timeHeap) Push(x interface{}) {
+ *h = append(*h, x.(time.Time))
+}
+
+func (h *timeHeap) Pop() interface{} {
+ last := (*h)[len(*h)-1]
+ *h = (*h)[:len(*h)-1]
+ return last
+}
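
timeHeap only supplies heap.Interface; ordering is maintained by the container/heap functions. A standalone sketch of the same type in use, which is how fakeClock always finds the earliest scheduled work:

package main

import (
	"container/heap"
	"fmt"
	"time"
)

// timeHeap as defined above: a min-heap of time.Time values.
type timeHeap []time.Time

func (h timeHeap) Len() int            { return len(h) }
func (h timeHeap) Less(i, j int) bool  { return h[i].Before(h[j]) }
func (h timeHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *timeHeap) Push(x interface{}) { *h = append(*h, x.(time.Time)) }
func (h *timeHeap) Pop() interface{} {
	last := (*h)[len(*h)-1]
	*h = (*h)[:len(*h)-1]
	return last
}

func main() {
	now := time.Now()
	h := &timeHeap{}
	// Push in arbitrary order; heap.Pop always yields the earliest time.
	heap.Push(h, now.Add(3*time.Second))
	heap.Push(h, now.Add(1*time.Second))
	heap.Push(h, now.Add(2*time.Second))
	for h.Len() > 0 {
		fmt.Println(heap.Pop(h).(time.Time).Sub(now)) // 1s, 2s, 3s
	}
}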
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go
index 6b64cd37f..3eff141e6 100644
--- a/pkg/tcpip/stack/forwarder.go
+++ b/pkg/tcpip/stack/forwarder.go
@@ -32,7 +32,7 @@ type pendingPacket struct {
nic *NIC
route *Route
proto tcpip.NetworkProtocolNumber
- pkt PacketBuffer
+ pkt *PacketBuffer
}
type forwardQueue struct {
@@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue {
return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
}
-func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
shouldWait := false
f.Lock()
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 8084d50bc..c962693f5 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
@@ -33,6 +34,10 @@ const (
// except where another value is explicitly used. It is chosen to match
// the MTU of loopback interfaces on linux systems.
fwdTestNetDefaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
)
// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
@@ -68,16 +73,9 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
return &f.id
}
-func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) {
- // Consume the network header.
- b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
- if !ok {
- return
- }
- pkt.Data.TrimFront(fwdTestNetHeaderLen)
-
+func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -96,13 +94,13 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu
return f.proto.Number()
}
-func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.Header.Prepend(fwdTestNetHeaderLen)
- b[0] = r.RemoteAddress[0]
- b[1] = f.id.LocalAddress[0]
- b[2] = byte(params.Protocol)
+ b[dstAddrOffset] = r.RemoteAddress[0]
+ b[srcAddrOffset] = f.id.LocalAddress[0]
+ b[protocolNumberOffset] = byte(params.Protocol)
return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
@@ -112,7 +110,7 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf
panic("not implemented")
}
-func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error {
+func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -123,10 +121,12 @@ func (*fwdTestNetworkEndpoint) Close() {}
type fwdTestNetworkProtocol struct {
addrCache *linkAddrCache
addrResolveDelay time.Duration
- onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address)
+ onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
}
+var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
+
func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
@@ -140,7 +140,17 @@ func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
}
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
+}
+
+func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = netHeader
+ pkt.Data.TrimFront(fwdTestNetHeaderLen)
+ return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true
}
func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
@@ -166,10 +176,10 @@ func (f *fwdTestNetworkProtocol) Close() {}
func (f *fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
if f.addrCache != nil && f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
- f.onLinkAddressResolved(f.addrCache, addr)
+ f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr)
})
}
return nil
@@ -190,7 +200,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb
type fwdTestPacketInfo struct {
RemoteLinkAddress tcpip.LinkAddress
LocalLinkAddress tcpip.LinkAddress
- Pkt PacketBuffer
+ Pkt *PacketBuffer
}
type fwdTestLinkEndpoint struct {
@@ -203,13 +213,13 @@ type fwdTestLinkEndpoint struct {
}
// InjectInbound injects an inbound packet.
-func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) {
- e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt)
+func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt)
}
// Attach saves the stack network-layer dispatcher for use later when packets
@@ -251,7 +261,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error {
+func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
RemoteLinkAddress: r.RemoteLinkAddress,
LocalLinkAddress: r.LocalLinkAddress,
@@ -270,7 +280,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw
func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.WritePacket(r, gso, protocol, *pkt)
+ e.WritePacket(r, gso, protocol, pkt)
n++
}
@@ -280,7 +290,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := fwdTestPacketInfo{
- Pkt: PacketBuffer{Data: vv},
+ Pkt: &PacketBuffer{Data: vv},
}
select {
@@ -294,6 +304,16 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
@@ -361,8 +381,8 @@ func TestForwardingWithStaticResolver(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -387,7 +407,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any address will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -398,8 +418,8 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -429,8 +449,8 @@ func TestForwardingWithNoResolver(t *testing.T) {
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -445,7 +465,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Only packets to address 3 will be resolved to the
// link address "c".
if addr == "\x03" {
@@ -459,16 +479,16 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Inject an inbound packet to address 4 on NIC 1. This packet should
// not be forwarded.
buf := buffer.NewView(30)
- buf[0] = 4
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 4
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf = buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -480,9 +500,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -498,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -509,8 +528,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
// Inject two inbound packets to address 3 on NIC 1.
for i := 0; i < 2; i++ {
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -524,9 +543,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -543,7 +561,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -554,10 +572,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
// Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
buf := buffer.NewView(30)
- buf[0] = 3
+ buf[dstAddrOffset] = 3
// Set the packet sequence number.
binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -571,14 +589,18 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 {
+ t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
}
- // The first 5 packets should not be forwarded so the the
- // sequemnce number should start with 5.
+ seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes).
+ if !ok {
+ t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size())
+ }
+
+ // The first 5 packets should not be forwarded so the sequence number should
+ // start with 5.
want := uint16(i + 5)
- if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want {
+ if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
t.Fatalf("got the packet #%d, want = #%d", n, want)
}
@@ -596,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -609,8 +631,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// Each packet has a different destination address (3 to
// maxPendingResolutions + 7).
buf := buffer.NewView(30)
- buf[0] = byte(3 + i)
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = byte(3 + i)
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -626,9 +648,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// The first 5 packets (address 3 to 7) should not be forwarded
// because their address resolutions are interrupted.
- b := p.Pkt.Data.ToView()
- if b[0] < 8 {
- t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] < 8 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 443423b3c..110ba073d 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -16,40 +16,49 @@ package stack
import (
"fmt"
- "strings"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// Table names.
+// tableID is an index into IPTables.tables.
+type tableID int
+
const (
- TablenameNat = "nat"
- TablenameMangle = "mangle"
- TablenameFilter = "filter"
+ natID tableID = iota
+ mangleID
+ filterID
+ numTables
)
-// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
+// Table names.
const (
- ChainNamePrerouting = "PREROUTING"
- ChainNameInput = "INPUT"
- ChainNameForward = "FORWARD"
- ChainNameOutput = "OUTPUT"
- ChainNamePostrouting = "POSTROUTING"
+ NATTable = "nat"
+ MangleTable = "mangle"
+ FilterTable = "filter"
)
+// nameToID is immutable.
+var nameToID = map[string]tableID{
+ NATTable: natID,
+ MangleTable: mangleID,
+ FilterTable: filterID,
+}
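The constants above trade string-keyed map lookups for dense array indexing: strings survive only at the API boundary, where nameToID translates them once. A self-contained sketch of the pattern, using illustrative names rather than the real stack types:

package main

import "fmt"

type tableID int

const (
	natID tableID = iota
	mangleID
	filterID
	numTables
)

var nameToID = map[string]tableID{
	"nat":    natID,
	"mangle": mangleID,
	"filter": filterID,
}

type table struct{ rules []string }

// tables is indexed by tableID, so hot paths never hash a string.
var tables [numTables]table

func getTable(name string) (table, bool) {
	id, ok := nameToID[name]
	if !ok {
		return table{}, false
	}
	return tables[id], true
}

func main() {
	tables[filterID] = table{rules: []string{"ACCEPT"}}
	fmt.Println(getTable("filter"))
	fmt.Println(getTable("raw")) // unknown name: zero table, false
}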
+
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
+// reaperDelay is how long to wait before starting to reap connections.
+const reaperDelay = 5 * time.Second
+
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables() IPTables {
- // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
- // iotas.
- return IPTables{
- Tables: map[string]Table{
- TablenameNat: Table{
+func DefaultTables() *IPTables {
+ return &IPTables{
+ tables: [numTables]Table{
+ natID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
@@ -57,65 +66,71 @@ func DefaultTables() IPTables {
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- Underflows: map[Hook]int{
+ Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- UserChains: map[string]int{},
},
- TablenameMangle: Table{
+ mangleID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
- Underflows: map[Hook]int{
- Prerouting: 0,
- Output: 1,
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
- TablenameFilter: Table{
+ filterID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
},
- Priorities: map[Hook][]string{
- Input: []string{TablenameNat, TablenameFilter},
- Prerouting: []string{TablenameMangle, TablenameNat},
- Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
+ priorities: [NumHooks][]tableID{
+ Prerouting: []tableID{mangleID, natID},
+ Input: []tableID{natID, filterID},
+ Output: []tableID{mangleID, natID, filterID},
},
- connections: ConnTrackTable{
- CtMap: make(map[uint32]ConnTrackTupleHolder),
- Seed: generateRandUint32(),
+ connections: ConnTrack{
+ seed: generateRandUint32(),
},
+ reaperDone: make(chan struct{}, 1),
}
}
@@ -124,41 +139,61 @@ func DefaultTables() IPTables {
func EmptyFilterTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// EmptyNatTable returns a Table with no rules and the filter table chains
+// EmptyNATTable returns a Table with no rules and the filter-only chains
// mapped to HookUnset.
-func EmptyNatTable() Table {
+func EmptyNATTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Forward: HookUnset,
},
- Underflows: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ Underflows: [NumHooks]int{
+ Forward: HookUnset,
},
- UserChains: map[string]int{},
}
}
+// GetTable returns a table by name.
+func (it *IPTables) GetTable(name string) (Table, bool) {
+ id, ok := nameToID[name]
+ if !ok {
+ return Table{}, false
+ }
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ return it.tables[id], true
+}
+
+// ReplaceTable replaces the builtin table with the given name.
+func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error {
+ id, ok := nameToID[name]
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
+ it.mu.Lock()
+ defer it.mu.Unlock()
+ // If iptables is being enabled, initialize the conntrack table and
+ // reaper.
+ if !it.modified {
+ it.connections.buckets = make([]bucket, numBuckets)
+ it.startReaper(reaperDelay)
+ }
+ it.modified = true
+ it.tables[id] = table
+ return nil
+}
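ReplaceTable doubles as the enable switch: the first successful call allocates the conntrack buckets and starts the reaper, and the modified flag lets Check skip rule traversal entirely for stacks that never configure iptables. A hedged reduction of that lazy-enable-under-lock idiom (all names invented for illustration):

package main

import (
	"fmt"
	"sync"
)

type engine struct {
	mu       sync.RWMutex
	modified bool
	buckets  []int
}

// configure is the write path; the first call initializes expensive state.
func (e *engine) configure() {
	e.mu.Lock()
	defer e.mu.Unlock()
	if !e.modified {
		e.buckets = make([]int, 512) // allocated lazily, exactly once
	}
	e.modified = true
}

// check is the hot read path; unconfigured engines return immediately.
func (e *engine) check() bool {
	e.mu.RLock()
	defer e.mu.RUnlock()
	if !e.modified {
		return true // nothing configured: accept without traversal
	}
	return len(e.buckets) > 0 // stand-in for the real rule walk
}

func main() {
	var e engine
	fmt.Println(e.check()) // true via the fast path
	e.configure()
	fmt.Println(e.check()) // true via the full path
}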
+
// A chainVerdict is what a table decides should be done with a packet.
type chainVerdict int
@@ -180,13 +215,27 @@ const (
//
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool {
+ // Many users never configure iptables. Spare them the cost of rule
+ // traversal if rules have never been set.
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ if !it.modified {
+ return true
+ }
+
// Packets are manipulated only if connection and matching
// NAT rule exists.
- it.connections.HandlePacket(pkt, hook, gso, r)
+ shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
// Go through each table containing the hook.
- for _, tablename := range it.Priorities[hook] {
- table := it.Tables[tablename]
+ priorities := it.priorities[hook]
+ for _, tableID := range priorities {
+ // If handlePacket already NATed the packet, we don't need to
+ // check the NAT table.
+ if tableID == natID && pkt.NatDone {
+ continue
+ }
+ table := it.tables[tableID]
ruleIdx := table.BuiltinChains[hook]
switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
// If the table returns Accept, move on to the next table.
@@ -215,17 +264,59 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
}
}
+ // If this connection should be tracked, try to add an entry for it. If
+ // traversing the nat table didn't end in adding an entry,
+ // maybeInsertNoop will add a no-op entry for the connection. This is
+ // needeed when establishing connections so that the SYN/ACK reply to an
+ // outgoing SYN is delivered to the correct endpoint rather than being
+ // redirected by a prerouting rule.
+ //
+ // From the iptables documentation: "If there is no rule, a `null'
+ // binding is created: this usually does not map the packet, but exists
+ // to ensure we don't map another stream over an existing one."
+ if shouldTrack {
+ it.connections.maybeInsertNoop(pkt, hook)
+ }
+
// Every table returned Accept.
return true
}
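Seen from a distance, Check is a short priority walk: visit the hook's tables in order, skip a table whose effect was already applied (the NAT short-circuit above), and stop on the first terminal verdict. A reduced sketch of that control flow, with invented names:

package main

import "fmt"

type verdict int

const (
	accept verdict = iota
	drop
)

// walkTables visits ids in priority order. skip lets the caller elide a
// table whose work is already done; a drop verdict ends the walk.
func walkTables(ids []int, skip func(int) bool, check func(int) verdict) bool {
	for _, id := range ids {
		if skip(id) {
			continue
		}
		if check(id) == drop {
			return false // terminal verdict
		}
	}
	return true // every table accepted
}

func main() {
	const natID = 0
	natDone := true // pretend the packet was already NATed
	ok := walkTables(
		[]int{natID, 1, 2},
		func(id int) bool { return id == natID && natDone },
		func(id int) verdict { fmt.Println("checking table", id); return accept },
	)
	fmt.Println(ok) // tables 1 and 2 checked; nat skipped
}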
+// beforeSave is invoked by stateify.
+func (it *IPTables) beforeSave() {
+ // Ensure the reaper exits cleanly.
+ it.reaperDone <- struct{}{}
+ // Prevent others from modifying the connection table.
+ it.connections.mu.Lock()
+}
+
+// afterLoad is invoked by stateify.
+func (it *IPTables) afterLoad() {
+ it.startReaper(reaperDelay)
+}
+
+// startReaper starts a goroutine that wakes up periodically to reap timed out
+// connections.
+func (it *IPTables) startReaper(interval time.Duration) {
+ go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved.
+ bucket := 0
+ for {
+ select {
+ case <-it.reaperDone:
+ return
+ case <-time.After(interval):
+ bucket, interval = it.connections.reapUnused(bucket, interval)
+ }
+ }
+ }()
+}
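startReaper is the standard stoppable-ticker shape: one goroutine multiplexing a stop channel against a timer, with reaperDone buffered so the save path can signal it without blocking. The same loop in a standalone form, assuming a stand-in work function:

package main

import (
	"fmt"
	"time"
)

// startReaper runs work periodically until done is signalled.
func startReaper(interval time.Duration, done chan struct{}, work func()) {
	go func() {
		for {
			select {
			case <-done:
				return
			case <-time.After(interval):
				work()
			}
		}
	}()
}

func main() {
	done := make(chan struct{}, 1) // buffered: signalling never blocks
	startReaper(10*time.Millisecond, done, func() { fmt.Println("reap") })
	time.Sleep(35 * time.Millisecond)
	done <- struct{}{}
	time.Sleep(20 * time.Millisecond) // no further "reap" lines print
}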
+
// CheckPackets runs pkts through the rules for hook and returns a map of packets that
// should not go forward.
//
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-//
-// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
@@ -249,9 +340,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
return drop, natPkts
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
@@ -296,25 +387,14 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
return chainDrop
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
- // If pkt.NetworkHeader hasn't been set yet, it will be contained in
- // pkt.Data.
- if pkt.NetworkHeader == nil {
- var ok bool
- pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- // Precondition has been violated.
- panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize))
- }
- }
-
// Check whether the packet matches the IP header filter.
- if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader), hook, nicName) {
+ if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -322,7 +402,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// Go through each rule matcher. If they all match, run
// the rule target.
for _, matcher := range rule.Matchers {
- matches, hotdrop := matcher.Match(hook, *pkt, "")
+ matches, hotdrop := matcher.Match(hook, pkt, "")
if hotdrop {
return RuleDrop, 0
}
@@ -336,46 +416,8 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
return rule.Target.Action(pkt, &it.connections, hook, gso, r, address)
}
-func filterMatch(filter IPHeaderFilter, hdr header.IPv4, hook Hook, nicName string) bool {
- // TODO(gvisor.dev/issue/170): Support other fields of the filter.
- // Check the transport protocol.
- if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() {
- return false
- }
-
- // Check the destination IP.
- dest := hdr.DestinationAddress()
- matches := true
- for i := range filter.Dst {
- if dest[i]&filter.DstMask[i] != filter.Dst[i] {
- matches = false
- break
- }
- }
- if matches == filter.DstInvert {
- return false
- }
-
- // Check the output interface.
- // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
- // hooks after supported.
- if hook == Output {
- n := len(filter.OutputInterface)
- if n == 0 {
- return true
- }
-
- // If the interface name ends with '+', any interface which begins
- // with the name should be matched.
- ifName := filter.OutputInterface
- matches = true
- if strings.HasSuffix(ifName, "+") {
- matches = strings.HasPrefix(nicName, ifName[:n-1])
- } else {
- matches = nicName == ifName
- }
- return filter.OutputInterfaceInvert != matches
- }
-
- return true
+// OriginalDst returns the original destination of redirected connections. It
+// returns an error if the connection doesn't exist or isn't redirected.
+func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+ return it.connections.originalDst(epID)
}
diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go
new file mode 100644
index 000000000..529e02a07
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_state.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "time"
+)
+
+// +stateify savable
+type unixTime struct {
+ second int64
+ nano int64
+}
+
+// saveLastUsed is invoked by stateify.
+func (cn *conn) saveLastUsed() unixTime {
+ return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
+}
+
+// loadLastUsed is invoked by stateify.
+func (cn *conn) loadLastUsed(unix unixTime) {
+ cn.lastUsed = time.Unix(unix.second, unix.nano)
+}
+
+// beforeSave is invoked by stateify.
+func (ct *ConnTrack) beforeSave() {
+ ct.mu.Lock()
+}
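stateify cannot serialize a time.Time directly, so the new file shells it out into two plain int64s. One subtlety for anyone reusing the shape: time.Unix adds its seconds and nanoseconds arguments, so a lossless round-trip should rebuild from the nanosecond field alone (or from seconds plus only the fractional nanoseconds), as in this sketch:

package main

import (
	"fmt"
	"time"
)

type unixTime struct {
	second int64
	nano   int64
}

func save(t time.Time) unixTime {
	return unixTime{t.Unix(), t.UnixNano()}
}

// load rebuilds the instant from the nanosecond field; passing both full
// values to time.Unix would double-count the seconds.
func load(u unixTime) time.Time {
	return time.Unix(0, u.nano)
}

func main() {
	now := time.Now()
	fmt.Println(now.Equal(load(save(now)))) // true
}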
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 36cc6275d..dc88033c7 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -24,7 +24,7 @@ import (
type AcceptTarget struct{}
// Action implements Target.Action.
-func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -32,7 +32,7 @@ func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, t
type DropTarget struct{}
// Action implements Target.Action.
-func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -41,7 +41,7 @@ func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcp
type ErrorTarget struct{}
// Action implements Target.Action.
-func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -61,7 +61,7 @@ func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route
type ReturnTarget struct{}
// Action implements Target.Action.
-func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -92,17 +92,12 @@ type RedirectTarget struct {
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for PREROUTING and calls pkt.Clone(), neither
// of which should be the case.
-func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
}
- // Set network header.
- if hook == Prerouting {
- parseHeaders(pkt)
- }
-
// Drop the packet if network and transport header are not set.
if pkt.NetworkHeader == nil || pkt.TransportHeader == nil {
return RuleDrop, 0
@@ -155,12 +150,11 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook
return RuleAccept, 0
}
- // Set up conection for matching NAT rule.
- // Only the first packet of the connection comes here.
- // Other packets will be manipulated in connection tracking.
- if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil {
- ct.SetNatInfo(pkt, rt, hook)
- ct.HandlePacket(pkt, hook, gso, r)
+ // Set up connection for matching NAT rule. Only the first
+ // packet of the connection comes here. Other packets will be
+ // manipulated in connection tracking.
+ if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
+ ct.handlePacket(pkt, hook, gso, r)
}
default:
return RuleDrop, 0
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index fe06007ae..73274ada9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -15,7 +15,11 @@
package stack
import (
+ "strings"
+ "sync"
+
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
// A Hook specifies one of the hooks built into the network stack.
@@ -74,63 +78,65 @@ const (
)
// IPTables holds all the tables for a netstack.
+//
+// +stateify savable
type IPTables struct {
- // Tables maps table names to tables. User tables have arbitrary names.
- Tables map[string]Table
+ // mu protects tables, priorities, and modified.
+ mu sync.RWMutex
- // Priorities maps each hook to a list of table names. The order of the
+ // tables maps tableIDs to tables. Holds builtin tables only, not user
+ // tables. mu must be held while accessing tables.
+ tables [numTables]Table
+
+ // priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
- // hook.
- Priorities map[Hook][]string
+ // hook. mu must be held while accessing priorities.
+ priorities [NumHooks][]tableID
+
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ modified bool
- connections ConnTrackTable
+ connections ConnTrack
+
+ // reaperDone can be signalled to stop the reaper goroutine.
+ reaperDone chan struct{}
}
// A Table defines a set of chains and hooks into the network stack. It is
-// really just a list of rules with some metadata for entrypoints and such.
+// really just a list of rules.
+//
+// +stateify savable
type Table struct {
// Rules holds the rules that make up the table.
Rules []Rule
// BuiltinChains maps builtin chains to their entrypoint rule in Rules.
- BuiltinChains map[Hook]int
+ BuiltinChains [NumHooks]int
// Underflows maps builtin chains to their underflow rule in Rules
// (i.e. the rule to execute if the chain returns without a verdict).
- Underflows map[Hook]int
-
- // UserChains holds user-defined chains for the keyed by name. Users
- // can give their chains arbitrary names.
- UserChains map[string]int
-
- // Metadata holds information about the Table that is useful to users
- // of IPTables, but not to the netstack IPTables code itself.
- metadata interface{}
+ Underflows [NumHooks]int
}
// ValidHooks returns a bitmap of the builtin hooks for the given table.
func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
- for hook := range table.BuiltinChains {
- hooks |= 1 << hook
+ for hook, ruleIdx := range table.BuiltinChains {
+ if ruleIdx != HookUnset {
+ hooks |= 1 << hook
+ }
}
return hooks
}
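With arrays in place of maps, "hook not present" can no longer be expressed by a missing key: an array slot's zero value is 0, which is a valid rule index. That is why DefaultTables now spells out Forward: HookUnset and friends, and why ValidHooks must test for the sentinel. A compact sketch:

package main

import "fmt"

const (
	HookUnset = -1
	NumHooks  = 5
)

// validHooks builds a bitmap of hooks that have an entrypoint, skipping
// slots that hold the HookUnset sentinel.
func validHooks(chains [NumHooks]int) uint32 {
	var hooks uint32
	for hook, ruleIdx := range chains {
		if ruleIdx != HookUnset {
			hooks |= 1 << uint(hook)
		}
	}
	return hooks
}

func main() {
	// Prerouting, Input, Output, Postrouting set; Forward unset.
	chains := [NumHooks]int{0, 1, HookUnset, 2, 3}
	fmt.Printf("%05b\n", validHooks(chains)) // 11011
}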
-// Metadata returns the metadata object stored in table.
-func (table *Table) Metadata() interface{} {
- return table.metadata
-}
-
-// SetMetadata sets the metadata object stored in table.
-func (table *Table) SetMetadata(metadata interface{}) {
- table.metadata = metadata
-}
-
// A Rule is a packet processing rule. It consists of two pieces. First it
// contains zero or more matchers, each of which is a specification of which
// packets this rule applies to. If there are no matchers in the rule, it
// applies to any packet.
+//
+// +stateify savable
type Rule struct {
// Filter holds basic IP filtering fields common to every rule.
Filter IPHeaderFilter
@@ -143,6 +149,8 @@ type Rule struct {
}
// IPHeaderFilter holds basic IP filtering data common to every rule.
+//
+// +stateify savable
type IPHeaderFilter struct {
// Protocol matches the transport protocol.
Protocol tcpip.TransportProtocolNumber
@@ -159,6 +167,16 @@ type IPHeaderFilter struct {
// comparison.
DstInvert bool
+ // Src matches the source IP address.
+ Src tcpip.Address
+
+ // SrcMask masks bits of the source IP address when comparing with Src.
+ SrcMask tcpip.Address
+
+ // SrcInvert inverts the meaning of the source IP check, i.e. when true the
+ // filter will match packets that fail the source comparison.
+ SrcInvert bool
+
// OutputInterface matches the name of the outgoing interface for the
// packet.
OutputInterface string
@@ -173,6 +191,55 @@ type IPHeaderFilter struct {
OutputInterfaceInvert bool
}
+// match returns whether hdr matches the filter.
+func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool {
+ // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+ // Check the transport protocol.
+ if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() {
+ return false
+ }
+
+ // Check the source and destination IPs.
+ if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) {
+ return false
+ }
+
+ // Check the output interface.
+ // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
+ // hooks once they are supported.
+ if hook == Output {
+ n := len(fl.OutputInterface)
+ if n == 0 {
+ return true
+ }
+
+ // If the interface name ends with '+', any interface which begins
+ // with the name should be matched.
+ ifName := fl.OutputInterface
+ matches := true
+ if strings.HasSuffix(ifName, "+") {
+ matches = strings.HasPrefix(nicName, ifName[:n-1])
+ } else {
+ matches = nicName == ifName
+ }
+ return fl.OutputInterfaceInvert != matches
+ }
+
+ return true
+}
+
+// filterAddress returns whether addr matches the filter.
+func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool {
+ matches := true
+ for i := range filterAddr {
+ if addr[i]&mask[i] != filterAddr[i] {
+ matches = false
+ break
+ }
+ }
+ return matches != invert
+}
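filterAddress generalizes the old destination-only logic to either direction: a byte-wise masked comparison whose result is XORed with the invert flag. The same check in standalone form, with byte strings standing in for tcpip.Address:

package main

import "fmt"

// matchMasked reports whether addr matches filter under mask, honoring the
// iptables-style invert flag.
func matchMasked(addr, mask, filter string, invert bool) bool {
	matches := true
	for i := 0; i < len(filter); i++ {
		if addr[i]&mask[i] != filter[i] {
			matches = false
			break
		}
	}
	return matches != invert
}

func main() {
	// 10.0.0.5 against 10.0.0.0/24.
	addr, mask, filter := "\x0a\x00\x00\x05", "\xff\xff\xff\x00", "\x0a\x00\x00\x00"
	fmt.Println(matchMasked(addr, mask, filter, false)) // true
	fmt.Println(matchMasked(addr, mask, filter, true))  // false: inverted
}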
+
// A Matcher is the interface for matching packets.
type Matcher interface {
// Name returns the name of the Matcher.
@@ -183,7 +250,7 @@ type Matcher interface {
// used for suspicious packets.
//
// Precondition: packet.NetworkHeader is set.
- Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+ Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
// A Target is the interface for taking an action for a packet.
@@ -191,5 +258,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int)
+ Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 403557fd7..6f73a0ce4 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+ linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 1baa498d0..b15b8d1cb 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -48,7 +48,7 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
if f := r.onLinkAddressRequest; f != nil {
f()
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 526c7d6ff..5174e639c 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -33,12 +33,6 @@ const (
// Default = 1 (from RFC 4862 section 5.1)
defaultDupAddrDetectTransmits = 1
- // defaultRetransmitTimer is the default amount of time to wait between
- // sending NDP Neighbor solicitation messages.
- //
- // Default = 1s (from RFC 4861 section 10).
- defaultRetransmitTimer = time.Second
-
// defaultMaxRtrSolicitations is the default number of Router
// Solicitation messages to send when a NIC becomes enabled.
//
@@ -79,16 +73,6 @@ const (
// Default = true.
defaultAutoGenGlobalAddresses = true
- // minimumRetransmitTimer is the minimum amount of time to wait between
- // sending NDP Neighbor solicitation messages. Note, RFC 4861 does
- // not impose a minimum Retransmit Timer, but we do here to make sure
- // the messages are not sent all at once. We also come to this value
- // because in the RetransmitTimer field of a Router Advertisement, a
- // value of 0 means unspecified, so the smallest valid value is 1.
- // Note, the unit of the RetransmitTimer field in the Router
- // Advertisement is milliseconds.
- minimumRetransmitTimer = time.Millisecond
-
// minimumRtrSolicitationInterval is the minimum amount of time to wait
// between sending Router Solicitation messages. This limit is imposed
// to make sure that Router Solicitation messages are not sent all at
@@ -467,8 +451,17 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
- // The timer used to send the next router solicitation message.
- rtrSolicitTimer *time.Timer
+ rtrSolicit struct {
+ // The timer used to send the next router solicitation message.
+ timer tcpip.Timer
+
+ // Used to let the Router Solicitation timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the NIC this ndpState is associated with. MUST be set when the timer is
+ // set.
+ done *bool
+ }
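The done pointer exists because stopping a timer cannot guarantee its callback has not already fired and is now waiting on the NIC lock; the flag, checked under that same lock, lets the callback notice a stop that raced with it. A stripped-down sketch of the pattern, assuming an ordinary mutex in place of the NIC lock:

package main

import (
	"fmt"
	"sync"
	"time"
)

type solicitState struct {
	mu    sync.Mutex
	timer *time.Timer
	done  *bool
}

func (s *solicitState) start(d time.Duration, f func()) {
	done := new(bool)
	s.done = done
	s.timer = time.AfterFunc(d, func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		if *done {
			return // stop() won the race after the timer fired
		}
		f()
	})
}

// stop must be called with mu held, mirroring the "NIC must be locked"
// precondition in the commit.
func (s *solicitState) stop() {
	*s.done = true
	s.timer.Stop()
	s.timer = nil
	s.done = nil
}

func main() {
	var s solicitState
	s.start(10*time.Millisecond, func() { fmt.Println("fired") })
	s.mu.Lock()
	s.stop()
	s.mu.Unlock()
	time.Sleep(20 * time.Millisecond) // nothing prints: the flag was seen
}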
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
@@ -494,7 +487,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the DAD timer know that it has been stopped.
//
@@ -506,38 +499,38 @@ type dadState struct {
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
- // Timer to invalidate the default router.
+ // Job to invalidate the default router.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// onLinkPrefixState holds data associated with an on-link prefix discovered by
// a Router Advertisement's Prefix Information option (PI) when the NDP
// configurations was configured to do so.
type onLinkPrefixState struct {
- // Timer to invalidate the on-link prefix.
+ // Job to invalidate the on-link prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// tempSLAACAddrState holds state associated with a temporary SLAAC address.
type tempSLAACAddrState struct {
- // Timer to deprecate the temporary SLAAC address.
+ // Job to deprecate the temporary SLAAC address.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the temporary SLAAC address.
+ // Job to invalidate the temporary SLAAC address.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
- // Timer to regenerate the temporary SLAAC address.
+ // Job to regenerate the temporary SLAAC address.
//
// Must not be nil.
- regenTimer *tcpip.CancellableTimer
+ regenJob *tcpip.Job
createdAt time.Time
@@ -552,15 +545,15 @@ type tempSLAACAddrState struct {
// slaacPrefixState holds state associated with a SLAAC prefix.
type slaacPrefixState struct {
- // Timer to deprecate the prefix.
+ // Job to deprecate the prefix.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the prefix.
+ // Job to invalidate the prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
// Nonzero only when the address is not valid forever.
validUntil time.Time
@@ -642,19 +635,20 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
var done bool
- var timer *time.Timer
+ var timer tcpip.Timer
// We initially start a timer to fire immediately because some of the DAD work
// cannot be done while holding the NIC's lock. This is effectively the same
// as starting a goroutine but we use a timer that fires immediately so we can
// reset it for the next DAD iteration.
- timer = time.AfterFunc(0, func() {
- ndp.nic.mu.RLock()
+ timer = ndp.nic.stack.Clock().AfterFunc(0, func() {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
if done {
// If we reach this point, it means that the DAD timer fired after
// another goroutine already obtained the NIC lock and stopped DAD
// before this function obtained the NIC lock. Simply return here and do
// nothing further.
- ndp.nic.mu.RUnlock()
return
}
@@ -665,15 +659,23 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
dadDone := remaining == 0
- ndp.nic.mu.RUnlock()
var err *tcpip.Error
if !dadDone {
- err = ndp.sendDADPacket(addr)
+ // Use the unspecified address as the source address when performing DAD.
+ ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+
+ // Do not hold the lock when sending packets which may be a long running
+ // task or may block link address resolution. We know this is safe
+ // because immediately after obtaining the lock again, we check if DAD
+ // has been stopped before doing any work with the NIC. Note, DAD would be
+ // stopped if the NIC was disabled or removed, or if the address was
+ // removed.
+ ndp.nic.mu.Unlock()
+ err = ndp.sendDADPacket(addr, ref)
+ ndp.nic.mu.Lock()
}
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
if done {
// If we reach this point, it means that DAD was stopped after we released
// the NIC's read lock and before we obtained the write lock.
@@ -721,17 +723,24 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// addr.
//
// addr must be a tentative IPv6 address on ndp's NIC.
-func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
+//
+// The NIC ndp belongs to MUST NOT be locked.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- // Use the unspecified address as the source address when performing DAD.
- ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
- r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
// Route should resolve immediately since snmc is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
+ return err
+ }
+
panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err))
} else if c != nil {
panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
@@ -750,7 +759,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, PacketBuffer{Header: hdr},
+ }, &PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
return err
@@ -846,9 +855,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
case ok && rl != 0:
// This is an already discovered default router. Update
- // the invalidation timer.
- rtr.invalidationTimer.StopLocked()
- rtr.invalidationTimer.Reset(rl)
+ // the invalidation job.
+ rtr.invalidationJob.Cancel()
+ rtr.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = rtr
case ok && rl == 0:
@@ -925,7 +934,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
return
}
- rtr.invalidationTimer.StopLocked()
+ rtr.invalidationJob.Cancel()
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
@@ -954,12 +963,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
}
state := defaultRouterState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
- state.invalidationTimer.Reset(rl)
+ state.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = state
}
@@ -984,13 +993,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
state := onLinkPrefixState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
if l < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(l)
+ state.invalidationJob.Schedule(l)
}
ndp.onLinkPrefixes[prefix] = state
@@ -1008,7 +1017,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
return
}
- s.invalidationTimer.StopLocked()
+ s.invalidationJob.Cancel()
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
@@ -1057,14 +1066,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// This is an already discovered on-link prefix with a
// new non-zero valid lifetime.
//
- // Update the invalidation timer.
+ // Update the invalidation job.
- prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationJob.Cancel()
if vl < header.NDPInfiniteLifetime {
- // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // Prefix is valid for a finite lifetime, schedule the job to execute after
// the new valid lifetime.
- prefixState.invalidationTimer.Reset(vl)
+ prefixState.invalidationJob.Schedule(vl)
}
ndp.onLinkPrefixes[prefix] = prefixState
@@ -1129,7 +1138,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
@@ -1137,7 +1146,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.deprecateSLAACAddress(state.stableAddr.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1159,19 +1168,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
if !ndp.generateSLAACAddr(prefix, &state) {
// We were unable to generate an address for the prefix; we do nothing
- // further as there is no reason to maintain state or timers for a prefix we
+ // further as there is no reason to maintain state or jobs for a prefix we
// do not have an address for.
return
}
- // Setup the initial timers to deprecate and invalidate prefix.
+ // Setup the initial jobs to deprecate and invalidate prefix.
if pl < header.NDPInfiniteLifetime && pl != 0 {
- state.deprecationTimer.Reset(pl)
+ state.deprecationJob.Schedule(pl)
}
if vl < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(vl)
+ state.invalidationJob.Schedule(vl)
state.validUntil = now.Add(vl)
}
@@ -1403,7 +1412,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
state := tempSLAACAddrState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1416,7 +1425,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.deprecateSLAACAddress(tempAddrState.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1429,7 +1438,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1456,9 +1465,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ref: ref,
}
- state.deprecationTimer.Reset(pl)
- state.invalidationTimer.Reset(vl)
- state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration)
+ state.deprecationJob.Schedule(pl)
+ state.invalidationJob.Schedule(vl)
+ state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration)
prefixState.generationAttempts++
prefixState.tempAddrs[generatedAddr.Address] = state
@@ -1493,16 +1502,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
prefixState.stableAddr.ref.deprecated = false
}
- // If prefix was preferred for some finite lifetime before, stop the
- // deprecation timer so it can be reset.
- prefixState.deprecationTimer.StopLocked()
+ // If prefix was preferred for some finite lifetime before, cancel the
+ // deprecation job so it can be reset.
+ prefixState.deprecationJob.Cancel()
now := time.Now()
- // Reset the deprecation timer if prefix has a finite preferred lifetime.
+ // Schedule the deprecation job if prefix has a finite preferred lifetime.
if pl < header.NDPInfiniteLifetime {
if !deprecated {
- prefixState.deprecationTimer.Reset(pl)
+ prefixState.deprecationJob.Schedule(pl)
}
prefixState.preferredUntil = now.Add(pl)
} else {
@@ -1521,9 +1530,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
if vl >= header.NDPInfiniteLifetime {
- // Handle the infinite valid lifetime separately as we do not keep a timer
- // in this case.
- prefixState.invalidationTimer.StopLocked()
+ // Handle the infinite valid lifetime separately as we do not schedule a
+ // job in this case.
+ prefixState.invalidationJob.Cancel()
prefixState.validUntil = time.Time{}
} else {
var effectiveVl time.Duration
@@ -1544,8 +1553,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
if effectiveVl != 0 {
- prefixState.invalidationTimer.StopLocked()
- prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.invalidationJob.Cancel()
+ prefixState.invalidationJob.Schedule(effectiveVl)
prefixState.validUntil = now.Add(effectiveVl)
}
}
@@ -1557,7 +1566,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// Note, we do not need to update the entries in the temporary address map
- // after updating the timers because the timers are held as pointers.
+ // after updating the jobs because the jobs are held as pointers.
var regenForAddr tcpip.Address
allAddressesRegenerated := true
for tempAddr, tempAddrState := range prefixState.tempAddrs {
@@ -1571,14 +1580,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer valid, invalidate it immediately. Otherwise,
- // reset the invalidation timer.
+ // reset the invalidation job.
newValidLifetime := validUntil.Sub(now)
if newValidLifetime <= 0 {
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
continue
}
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.invalidationTimer.Reset(newValidLifetime)
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.invalidationJob.Schedule(newValidLifetime)
// As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
// address is the lower of the preferred lifetime of the stable address or
@@ -1591,17 +1600,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer preferred, deprecate it immediately.
- // Otherwise, reset the deprecation timer.
+ // Otherwise, schedule the deprecation job again.
newPreferredLifetime := preferredUntil.Sub(now)
- tempAddrState.deprecationTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
if newPreferredLifetime <= 0 {
ndp.deprecateSLAACAddress(tempAddrState.ref)
} else {
tempAddrState.ref.deprecated = false
- tempAddrState.deprecationTimer.Reset(newPreferredLifetime)
+ tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
}
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.regenJob.Cancel()
if tempAddrState.regenerated {
} else {
allAddressesRegenerated = false
@@ -1612,7 +1621,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// immediately after we finish iterating over the temporary addresses.
regenForAddr = tempAddr
} else {
- tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
}
}
}
@@ -1692,7 +1701,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
ndp.cleanupSLAACPrefixResources(prefix, state)
}
-// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry.
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry.
//
// Panics if the SLAAC prefix is not known.
//
@@ -1704,8 +1713,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
}
state.stableAddr.ref = nil
- state.deprecationTimer.StopLocked()
- state.invalidationTimer.StopLocked()
+ state.deprecationJob.Cancel()
+ state.invalidationJob.Cancel()
delete(ndp.slaacPrefixes, prefix)
}
@@ -1750,13 +1759,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi
}
// cleanupTempSLAACAddrResources cleans up a temporary SLAAC address's
-// timers and entry.
+// jobs and entry.
//
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
- tempAddrState.deprecationTimer.StopLocked()
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.regenJob.Cancel()
delete(tempAddrs, tempAddr)
}
@@ -1816,7 +1825,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
//
// The NIC ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
- if ndp.rtrSolicitTimer != nil {
+ if ndp.rtrSolicit.timer != nil {
// We are already soliciting routers.
return
}
@@ -1833,14 +1842,27 @@ func (ndp *ndpState) startSolicitingRouters() {
delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
- ndp.rtrSolicitTimer = time.AfterFunc(delay, func() {
+ var done bool
+ ndp.rtrSolicit.done = &done
+ ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() {
+ ndp.nic.mu.Lock()
+ if done {
+ // If we reach this point, it means that the RS timer fired after another
+ // goroutine already obtained the NIC lock and stopped solicitations.
+ // Simply return here and do nothing further.
+ ndp.nic.mu.Unlock()
+ return
+ }
+
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress)
+ ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress)
if ref == nil {
- ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
+ ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
}
+ ndp.nic.mu.Unlock()
+
localAddr := ref.ep.ID().LocalAddress
r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
@@ -1849,6 +1871,13 @@ func (ndp *ndpState) startSolicitingRouters() {
// header.IPv6AllRoutersMulticastAddress is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
+ return
+ }
+
panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err))
} else if c != nil {
panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()))
@@ -1881,7 +1910,7 @@ func (ndp *ndpState) startSolicitingRouters() {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, PacketBuffer{Header: hdr},
+ }, &PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
@@ -1893,17 +1922,18 @@ func (ndp *ndpState) startSolicitingRouters() {
}
ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
- if remaining == 0 {
- ndp.rtrSolicitTimer = nil
- } else if ndp.rtrSolicitTimer != nil {
+ if done || remaining == 0 {
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+ } else if ndp.rtrSolicit.timer != nil {
// Note, we need to explicitly check to make sure that
// the timer field is not nil because if it was nil but
// we still reached this point, then we know the NIC
// was requested to stop soliciting routers so we don't
// need to send the next Router Solicitation message.
- ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval)
+ ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
}
+ ndp.nic.mu.Unlock()
})
}
@@ -1913,13 +1943,15 @@ func (ndp *ndpState) startSolicitingRouters() {
//
// The NIC ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
- if ndp.rtrSolicitTimer == nil {
+ if ndp.rtrSolicit.timer == nil {
// Nothing to do.
return
}
- ndp.rtrSolicitTimer.Stop()
- ndp.rtrSolicitTimer = nil
+ *ndp.rtrSolicit.done = true
+ ndp.rtrSolicit.timer.Stop()
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
}
// initializeTempAddrState initializes state related to temporary SLAAC
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index b3d174cdd..644ba7c33 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -36,15 +36,24 @@ import (
)
const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
- linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
- linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
- defaultTimeout = 100 * time.Millisecond
- defaultAsyncEventTimeout = time.Second
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 10 * time.Second
+
+ // Extra time to use when waiting for an async event to not occur.
+ //
+ // Since a negative check is used to make sure an event did not happen, it is
+ // okay to use a smaller timeout than in the positive case: an execution
+ // stall relative to the monotonic clock will not affect the expected
+ // outcome.
+ defaultAsyncNegativeEventTimeout = time.Second
)
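The split replaces a single defaultTimeout with asymmetric bounds, encoding a testing rule of thumb: a wait for an event that must happen has to ride out scheduler stalls, so it gets the generous bound, while a wait to confirm that nothing happens only gets more expensive the longer it runs, and a stall cannot turn a true negative into a false one. Both shapes, sketched:

package main

import (
	"fmt"
	"time"
)

func main() {
	events := make(chan string, 1)
	events <- "resolved"

	// Positive check: generous timeout, failing only on expiry.
	select {
	case e := <-events:
		fmt.Println("got event:", e)
	case <-time.After(10 * time.Second):
		fmt.Println("FAIL: timed out waiting for event")
	}

	// Negative check: short timeout, failing if anything arrives.
	select {
	case e := <-events:
		fmt.Println("FAIL: unexpected event:", e)
	case <-time.After(time.Second):
		fmt.Println("no spurious event, as expected")
	}
}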
var (
@@ -421,45 +430,90 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
+ // We add a default route so the call to FindRoute below will succeed
+ // once we have an assigned address.
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: addr3,
+ NIC: nicID,
+ }})
+
if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
// Make sure the address does not resolve before the resolution time has
// passed.
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout)
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout)
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+ // Should not get a route even if we specify the local address as the
+ // tentative address.
+ {
+ r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
+ if err != tcpip.ErrNoRoute {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ }
+ r.Release()
}
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ {
+ r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
+ if err != tcpip.ErrNoRoute {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ }
+ r.Release()
+ }
+
+ if t.Failed() {
+ t.FailNow()
}
// Wait for DAD to resolve.
select {
- case <-time.After(2 * defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if addr.Address != addr1 {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
}
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ // Should get a route using the address now that it is resolved.
+ {
+ r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err)
+ } else if r.LocalAddress != addr1 {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
+ }
+ r.Release()
+ }
+ {
+ r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err)
+ } else if r.LocalAddress != addr1 {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
+ }
+ r.Release()
+ }
+
+ if t.Failed() {
+ t.FailNow()
}
// Should not have sent any more NS messages.
@@ -613,7 +667,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate multiple nodes owning or
// attempting to own the same address.
hdr := test.makeBuf(addr1)
- e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -935,7 +989,7 @@ func TestSetNDPConfigurations(t *testing.T) {
// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
// and DHCPv6 configurations specified.
-func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
+func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
@@ -970,14 +1024,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
DstAddr: header.IPv6AllNodesMulticastAddress,
})
- return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
+ return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
}
// raBufWithOpts returns a valid NDP Router Advertisement with options.
//
// Note, raBufWithOpts does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
}
@@ -986,7 +1040,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
//
// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
// DHCPv6 related ones.
-func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) stack.PacketBuffer {
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfigurations, header.NDPOptionsSerializer{})
}
@@ -994,7 +1048,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo
//
// Note, raBuf does not populate any of the RA fields other than the
// Router Lifetime.
-func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer {
+func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
}
@@ -1003,7 +1057,7 @@ func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer {
//
// Note, raBufWithPI does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer {
+func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer {
flags := uint8(0)
if onLink {
// The OnLink flag is the 7th bit in the flags byte.
@@ -1124,7 +1178,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
select {
case <-ndpDisp.routerC:
t.Fatal("should not have received any router events")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -1200,14 +1254,14 @@ func TestRouterDiscovery(t *testing.T) {
default:
}
- // Wait for lladdr2's router invalidation timer to fire. The lifetime
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
// Rx an RA from lladdr2 with huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
@@ -1217,14 +1271,14 @@ func TestRouterDiscovery(t *testing.T) {
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
expectRouterEvent(llAddr2, false)
- // Wait for lladdr3's router invalidation timer to fire. The lifetime
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1373,7 +1427,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("should not have received any prefix events")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -1448,14 +1502,14 @@ func TestPrefixDiscovery(t *testing.T) {
default:
}
- // Wait for prefix2's most recent invalidation timer plus some buffer to
+ // Wait for prefix2's most recent invalidation job plus some buffer to
// expire.
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for prefix discovery event")
}
@@ -1520,7 +1574,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultTimeout):
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with finite lifetime.
@@ -1545,7 +1599,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultTimeout):
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with a prefix with a lifetime value greater than the
@@ -1554,7 +1608,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout):
+ case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with 0 lifetime.
@@ -1790,7 +1844,7 @@ func TestAutoGenAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
@@ -1917,7 +1971,7 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -1930,7 +1984,7 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -2036,10 +2090,10 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- case <-time.After(newMinVLDuration + defaultTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
@@ -2135,7 +2189,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
@@ -2143,7 +2197,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Errorf("got unxpected auto gen addr event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -2220,7 +2274,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
select {
@@ -2228,7 +2282,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -2318,13 +2372,13 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" {
t.Fatal(mismatch)
}
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -2341,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
for _, addr := range tempAddrs {
// Wait for a deprecation then invalidation event, or just an invalidation
// event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation timers could fire in any
+ // cases because the deprecation and invalidation jobs could execute in any
// order.
select {
case e := <-ndpDisp.autoGenAddrC:
@@ -2353,7 +2407,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" {
@@ -2362,12 +2416,12 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated event = %+v", e)
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
} else {
t.Fatalf("got unexpected auto-generated event = %+v", e)
}
- case <-time.After(invalidateAfter + defaultAsyncEventTimeout):
+ case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
@@ -2378,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
}
-// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's
-// regeneration timer gets updated when refreshing the address's lifetimes.
-func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
+// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's
+// regeneration job gets updated when refreshing the address's lifetimes.
+func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
const (
nicID = 1
regenAfter = 2 * time.Second
@@ -2472,14 +2526,14 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncEventTimeout):
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
}
// Prefer the prefix again.
//
// A new temporary address should immediately be generated since the
// regeneration time has already passed since the last address was generated
- // - this regeneration does not depend on a timer.
+ // - this regeneration does not depend on a job.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(tempAddr2, newAddr)
@@ -2501,24 +2555,24 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncEventTimeout):
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
}
// Set the maximum lifetimes for temporary addresses such that on the next
- // RA, the regeneration timer gets reset.
+ // RA, the regeneration job gets scheduled again.
//
// The maximum lifetime is the sum of the minimum lifetimes for temporary
// addresses + the time that has already passed since the last address was
- // generated so that the regeneration timer is needed to generate the next
+ // generated so that the regeneration job is needed to generate the next
// address.
- newLifetimes := newMinVLDuration + regenAfter + defaultAsyncEventTimeout
+ newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
}
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
}
// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response
@@ -2666,7 +2720,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -2679,7 +2733,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -2939,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
expectPrimaryAddr(addr2)
}
-// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
// when its preferred lifetime expires.
-func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
@@ -3025,7 +3079,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3065,7 +3119,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3079,7 +3133,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3111,7 +3165,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
@@ -3120,12 +3174,12 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
} else {
t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3250,7 +3304,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -3394,7 +3448,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout):
}
// Wait for the invalidation event.
@@ -3403,7 +3457,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(2 * defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -3459,12 +3513,12 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
expectAutoGenAddrEvent(addr, invalidatedAddr)
- // Wait for the original valid lifetime to make sure the original timer
- // got stopped/cleaned up.
+ // Wait for the original valid lifetime to make sure the original job got
+ // cancelled/cleaned up.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -3627,7 +3681,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -3725,7 +3779,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3792,7 +3846,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -3818,7 +3872,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -3985,7 +4039,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -4104,7 +4158,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -4206,7 +4260,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
@@ -4232,7 +4286,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation")
}
} else {
@@ -4240,7 +4294,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
}
- case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for auto gen addr event")
}
}
@@ -4824,7 +4878,7 @@ func TestCleanupNDPState(t *testing.T) {
// Should not get any more events (invalidation timers should have been
// cancelled when the NDP state was cleaned up).
- time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
+ time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout)
select {
case <-ndpDisp.routerC:
t.Error("unexpected router event")
@@ -5127,24 +5181,24 @@ func TestRouterSolicitation(t *testing.T) {
// Make sure each RS is sent at the right time.
remaining := test.maxRtrSolicit
if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
+ waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
remaining--
}
for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncEventTimeout)
- waitForPkt(2 * defaultAsyncEventTimeout)
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
+ waitForPkt(defaultAsyncPositiveEventTimeout)
} else {
- waitForPkt(test.effectiveRtrSolicitInt * defaultAsyncEventTimeout)
+ waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
}
}
// Make sure no more RS.
if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncEventTimeout)
+ waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
+ waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
}
// Make sure the counter got properly updated.
@@ -5260,11 +5314,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stop soliciting routers.
test.stopFn(t, s, true /* first */)
- ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
// A single RS may have been sent before solicitations were stopped.
- ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok = e.ReadContext(ctx); ok {
t.Fatal("should not have sent more than one RS message")
@@ -5274,7 +5328,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stopping router solicitations after it has already been stopped should
// do nothing.
test.stopFn(t, s, false /* first */)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
@@ -5287,10 +5341,10 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Start soliciting routers.
test.startFn(t, s)
- waitForPkt(delay + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
+ waitForPkt(delay + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
@@ -5299,7 +5353,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Starting router solicitations after it has already completed should do
// nothing.
test.startFn(t, s)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after finishing router solicitations")
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
new file mode 100644
index 000000000..1d37716c2
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -0,0 +1,335 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const neighborCacheSize = 512 // max entries per interface
+
+// neighborCache maps IP addresses to link addresses. It uses the Least
+// Recently Used (LRU) eviction strategy to implement a bounded cache for
+// dynamically acquired entries. It contains the state machine and configuration
+// for running Neighbor Unreachability Detection (NUD).
+//
+// There are two types of entries in the neighbor cache:
+// 1. Dynamic entries are discovered automatically by neighbor discovery
+// protocols (e.g. ARP, NDP). These protocols will attempt to reconfirm
+// reachability with the device once the entry's state becomes Stale.
+// 2. Static entries are explicitly added by a user and have no expiration.
+// Their state is always Static. The number of static entries stored in the
+// cache is unbounded.
+//
+// neighborCache implements NUDHandler.
+type neighborCache struct {
+ nic *NIC
+ state *NUDState
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ cache map[tcpip.Address]*neighborEntry
+ dynamic struct {
+ lru neighborEntryList
+
+ // count tracks the number of dynamic entries in the cache. It is
+ // needed because static entries do not count towards the LRU cache
+ // eviction strategy.
+ count uint16
+ }
+}
+
+var _ NUDHandler = (*neighborCache)(nil)
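To make the two entry classes concrete, here is a minimal sketch (hypothetical code, written as if inside package stack; the addresses are placeholders) of how they interact with the LRU bound:

	func sketchEntryClasses(n *neighborCache, linkRes LinkAddressResolver) {
		// Dynamic: created on demand. Once neighborCacheSize dynamic entries
		// exist, the least recently used one is evicted to make room.
		dyn := n.getOrCreateEntry(tcpip.Address("\x0a\x00\x00\x01"), tcpip.Address("\x0a\x00\x00\x02"), linkRes)
		_ = dyn

		// Static: added explicitly, never expires, and does not count against
		// the LRU bound.
		n.addStaticEntry(tcpip.Address("\x0a\x00\x00\x03"), tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"))
	}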
+
+// getOrCreateEntry retrieves a cache entry associated with addr. The
+// returned entry is always refreshed in the cache (it is reachable via the
+// map, and its place is bumped in LRU).
+//
+// If a matching entry exists in the cache, it is returned. If no matching
+// entry exists and the cache is full, an existing entry is evicted via LRU,
+// reset to state incomplete, and returned. If no matching entry exists and the
+// cache is not full, a new entry with state incomplete is allocated and
+// returned.
+func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if entry, ok := n.cache[remoteAddr]; ok {
+ entry.mu.RLock()
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.lru.PushFront(entry)
+ }
+ entry.mu.RUnlock()
+ return entry
+ }
+
+ // The entry that needs to be created must be dynamic since all static
+ // entries are directly added to the cache via addStaticEntry.
+ entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes)
+ if n.dynamic.count == neighborCacheSize {
+ e := n.dynamic.lru.Back()
+ e.mu.Lock()
+
+ delete(n.cache, e.neigh.Addr)
+ n.dynamic.lru.Remove(e)
+ n.dynamic.count--
+
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Unknown)
+ e.notifyWakersLocked()
+ e.mu.Unlock()
+ }
+ n.cache[remoteAddr] = entry
+ n.dynamic.lru.PushFront(entry)
+ n.dynamic.count++
+ return entry
+}
+
+// entry looks up the neighbor cache to translate an address to a link address
+// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there
+// is a LinkAddressResolver registered with the network protocol, the cache
+// attempts to resolve the address and returns ErrWouldBlock. If a Waker is
+// provided, it will be notified when address resolution is complete (success
+// or not).
+//
+// If address resolution is required, ErrWouldBlock and a notification channel
+// are returned for the top-level caller to block on. The channel is closed
+// once address resolution is complete (success or not).
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+ if linkRes != nil {
+ if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
+ e := NeighborEntry{
+ Addr: remoteAddr,
+ LocalAddr: localAddr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: time.Now(),
+ }
+ return e, nil, nil
+ }
+ }
+
+ entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
+ switch s := entry.neigh.State; s {
+ case Reachable, Static:
+ return entry.neigh, nil, nil
+
+ case Unknown, Incomplete, Stale, Delay, Probe:
+ entry.addWakerLocked(w)
+
+ if entry.done == nil {
+ // Address resolution needs to be initiated.
+ if linkRes == nil {
+ return entry.neigh, nil, tcpip.ErrNoLinkAddress
+ }
+ entry.done = make(chan struct{})
+ }
+
+ entry.handlePacketQueuedLocked()
+ return entry.neigh, entry.done, tcpip.ErrWouldBlock
+
+ case Failed:
+ return entry.neigh, nil, tcpip.ErrNoLinkAddress
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", s))
+ }
+}
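A minimal caller sketch (hypothetical) of the blocking contract described above: block on the notification channel, then retry the lookup once it is closed:

	func sketchResolve(n *neighborCache, remote, local tcpip.Address, linkRes LinkAddressResolver) (NeighborEntry, *tcpip.Error) {
		neigh, ch, err := n.entry(remote, local, linkRes, nil)
		if err == tcpip.ErrWouldBlock {
			<-ch // closed once address resolution completes (success or not)
			neigh, _, err = n.entry(remote, local, linkRes, nil)
		}
		return neigh, err
	}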
+
+// removeWaker removes a waker that was added when link resolution for addr
+// was requested.
+func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
+ n.mu.Lock()
+ if entry, ok := n.cache[addr]; ok {
+ delete(entry.wakers, waker)
+ }
+ n.mu.Unlock()
+}
+
+// entries returns all entries in the neighbor cache.
+func (n *neighborCache) entries() []NeighborEntry {
+ entries := make([]NeighborEntry, 0, len(n.cache))
+ n.mu.RLock()
+ for _, entry := range n.cache {
+ entry.mu.RLock()
+ entries = append(entries, entry.neigh)
+ entry.mu.RUnlock()
+ }
+ n.mu.RUnlock()
+ return entries
+}
+
+// addStaticEntry adds a static entry to the neighbor cache, mapping an IP
+// address to a link address. If a dynamic entry exists in the neighbor cache
+// with the same address, it will be replaced with this static entry. If a
+// static entry exists with the same address but different link address, it
+// will be updated with the new link address. If a static entry exists with the
+// same address and link address, nothing will happen.
+func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if entry, ok := n.cache[addr]; ok {
+ entry.mu.Lock()
+ if entry.neigh.State != Static {
+ // Dynamic entry found with the same address.
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ } else if entry.neigh.LinkAddr == linkAddr {
+ // Static entry found with the same address and link address.
+ entry.mu.Unlock()
+ return
+ } else {
+ // Static entry found with the same address but different link address.
+ entry.neigh.LinkAddr = linkAddr
+ entry.dispatchChangeEventLocked(entry.neigh.State)
+ entry.mu.Unlock()
+ return
+ }
+
+ // Notify that resolution has been interrupted, just in case the entry was
+ // in the Incomplete or Probe state.
+ entry.dispatchRemoveEventLocked()
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+ entry.mu.Unlock()
+ }
+
+ entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
+ n.cache[addr] = entry
+}
+
+// removeEntryLocked removes the specified entry from the neighbor cache.
+func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ }
+ if entry.neigh.State != Failed {
+ entry.dispatchRemoveEventLocked()
+ }
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+
+ delete(n.cache, entry.neigh.Addr)
+}
+
+// removeEntry removes a dynamic or static entry by address from the neighbor
+// cache. Returns true if the entry was found and deleted.
+func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ entry, ok := n.cache[addr]
+ if !ok {
+ return false
+ }
+
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
+ n.removeEntryLocked(entry)
+ return true
+}
+
+// clear removes all dynamic and static entries from the neighbor cache.
+func (n *neighborCache) clear() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ for _, entry := range n.cache {
+ entry.mu.Lock()
+ entry.dispatchRemoveEventLocked()
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+ entry.mu.Unlock()
+ }
+
+ n.dynamic.lru = neighborEntryList{}
+ n.cache = make(map[tcpip.Address]*neighborEntry)
+ n.dynamic.count = 0
+}
+
+// config returns the NUD configuration.
+func (n *neighborCache) config() NUDConfigurations {
+ return n.state.Config()
+}
+
+// setConfig changes the NUD configuration.
+//
+// If config contains invalid NUD configuration values, the erroneous values
+// are replaced with defaults.
+func (n *neighborCache) setConfig(config NUDConfigurations) {
+ config.resetInvalidFields()
+ n.state.SetConfig(config)
+}
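A usage sketch (hypothetical; it assumes, per the comment above, that resetInvalidFields replaces out-of-range values with defaults rather than rejecting the configuration):

	func sketchSetConfig(n *neighborCache) {
		c := DefaultNUDConfigurations()
		c.MinRandomFactor = -1 // invalid; assumed to be reset to a default
		n.setConfig(c)         // the sanitized configuration is applied
	}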
+
+// HandleProbe implements NUDHandler.HandleProbe by following the logic defined
+// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled
+// by the caller.
+func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) {
+ entry := n.getOrCreateEntry(remoteAddr, localAddr, nil)
+ entry.mu.Lock()
+ entry.handleProbeLocked(remoteLinkAddr)
+ entry.mu.Unlock()
+}
+
+// HandleConfirmation implements NUDHandler.HandleConfirmation by following the
+// logic defined in RFC 4861 section 7.2.5.
+//
+// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other
+// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol
+// should be deployed where preventing access to the broadcast segment might
+// not be possible. SEND uses RSA key pairs to produce cryptographically
+// generated addresses, as defined in RFC 3972, Cryptographically Generated
+// Addresses (CGA). This ensures that the claimed source of an NDP message is
+// the owner of the claimed address.
+func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
+ n.mu.RLock()
+ entry, ok := n.cache[addr]
+ n.mu.RUnlock()
+ if ok {
+ entry.mu.Lock()
+ entry.handleConfirmationLocked(linkAddr, flags)
+ entry.mu.Unlock()
+ }
+ // The confirmation SHOULD be silently discarded if the recipient did not
+ // initiate any communication with the target. This is indicated if there is
+ // no matching entry for the remote address.
+}
+
+// HandleUpperLevelConfirmation implements
+// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in
+// RFC 4861 section 7.3.1.
+func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) {
+ n.mu.RLock()
+ entry, ok := n.cache[addr]
+ n.mu.RUnlock()
+ if ok {
+ entry.mu.Lock()
+ entry.handleUpperLevelConfirmationLocked()
+ entry.mu.Unlock()
+ }
+}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
new file mode 100644
index 000000000..4cb2c9c6b
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -0,0 +1,1752 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math"
+ "math/rand"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // entryStoreSize is the default number of entries that will be generated and
+ // added to the entry store. This number needs to be larger than the size of
+ // the neighbor cache to give ample opportunity for verifying behavior during
+ // cache overflows. Four times the size of the neighbor cache allows for
+ // three complete cache overflows.
+ entryStoreSize = 4 * neighborCacheSize
+
+ // typicalLatency is the typical latency for an ARP or NDP packet to travel
+ // to a router and back.
+ typicalLatency = time.Millisecond
+
+ // testEntryBroadcastAddr is a special address that indicates a packet should
+ // be sent to all nodes.
+ testEntryBroadcastAddr = tcpip.Address("broadcast")
+
+ // testEntryLocalAddr is the source address of neighbor probes.
+ testEntryLocalAddr = tcpip.Address("local_addr")
+
+ // testEntryBroadcastLinkAddr is a special link address sent back to
+ // multicast neighbor probes.
+ testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast")
+
+ // infiniteDuration indicates that a task will not occur in our lifetime.
+ infiniteDuration = time.Duration(math.MaxInt64)
+)
+
+// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor
+// entries. The UpdatedAt field is ignored due to a lack of a deterministic
+// method to predict the time that an event will be dispatched.
+func entryDiffOpts() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
+ }
+}
+
+// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to
+// sort slices of entries for cases where ordering must be ignored.
+func entryDiffOptsWithSort() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
+ cmpopts.SortSlices(func(a, b NeighborEntry) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }),
+ }
+}
+
+func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
+ config.resetInvalidFields()
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
+ return &neighborCache{
+ nic: &NIC{
+ stack: &Stack{
+ clock: clock,
+ nudDisp: nudDisp,
+ },
+ id: 1,
+ },
+ state: NewNUDState(config, rng),
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+}
+
+// testEntryStore contains a set of IP to NeighborEntry mappings.
+type testEntryStore struct {
+ mu sync.RWMutex
+ entriesMap map[tcpip.Address]NeighborEntry
+}
+
+func toAddress(i int) tcpip.Address {
+ buf := new(bytes.Buffer)
+ binary.Write(buf, binary.BigEndian, uint8(1))
+ binary.Write(buf, binary.BigEndian, uint8(0))
+ binary.Write(buf, binary.BigEndian, uint16(i))
+ return tcpip.Address(buf.String())
+}
+
+func toLinkAddress(i int) tcpip.LinkAddress {
+ buf := new(bytes.Buffer)
+ binary.Write(buf, binary.BigEndian, uint8(1))
+ binary.Write(buf, binary.BigEndian, uint8(0))
+ binary.Write(buf, binary.BigEndian, uint32(i))
+ return tcpip.LinkAddress(buf.String())
+}
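A worked example of the encodings above (computed from the binary.Write calls, not taken from the source):

	// toAddress(5) packs uint8(1), uint8(0), uint16(5) big-endian:
	//   tcpip.Address("\x01\x00\x00\x05") // 4 bytes
	// toLinkAddress(5) packs uint8(1), uint8(0), uint32(5) big-endian:
	//   tcpip.LinkAddress("\x01\x00\x00\x00\x00\x05") // 6 bytes, MAC-sized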
+
+// newTestEntryStore returns a testEntryStore pre-populated with entries.
+func newTestEntryStore() *testEntryStore {
+ store := &testEntryStore{
+ entriesMap: make(map[tcpip.Address]NeighborEntry),
+ }
+ for i := 0; i < entryStoreSize; i++ {
+ addr := toAddress(i)
+ linkAddr := toLinkAddress(i)
+
+ store.entriesMap[addr] = NeighborEntry{
+ Addr: addr,
+ LocalAddr: testEntryLocalAddr,
+ LinkAddr: linkAddr,
+ }
+ }
+ return store
+}
+
+// size returns the number of entries in the store.
+func (s *testEntryStore) size() int {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return len(s.entriesMap)
+}
+
+// entry returns the entry at index i. Returns an empty entry and false if i is
+// out of bounds.
+func (s *testEntryStore) entry(i int) (NeighborEntry, bool) {
+ return s.entryByAddr(toAddress(i))
+}
+
+// entryByAddr returns the entry matching addr for situations when the index is
+// not available. Returns an empty entry and false if no entries match addr.
+func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ entry, ok := s.entriesMap[addr]
+ return entry, ok
+}
+
+// entries returns all entries in the store.
+func (s *testEntryStore) entries() []NeighborEntry {
+ entries := make([]NeighborEntry, 0, len(s.entriesMap))
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for i := 0; i < entryStoreSize; i++ {
+ addr := toAddress(i)
+ if entry, ok := s.entriesMap[addr]; ok {
+ entries = append(entries, entry)
+ }
+ }
+ return entries
+}
+
+// set modifies the link addresses of an entry.
+func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) {
+ addr := toAddress(i)
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if entry, ok := s.entriesMap[addr]; ok {
+ entry.LinkAddr = linkAddr
+ s.entriesMap[addr] = entry
+ }
+}
+
+// testNeighborResolver implements LinkAddressResolver to emulate sending a
+// neighbor probe.
+type testNeighborResolver struct {
+ clock tcpip.Clock
+ neigh *neighborCache
+ entries *testEntryStore
+ delay time.Duration
+ onLinkAddressRequest func()
+}
+
+var _ LinkAddressResolver = (*testNeighborResolver)(nil)
+
+func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+ // Delay handling the request to emulate network latency.
+ r.clock.AfterFunc(r.delay, func() {
+ r.fakeRequest(addr)
+ })
+
+ // Execute post address resolution action, if available.
+ if f := r.onLinkAddressRequest; f != nil {
+ f()
+ }
+ return nil
+}
+
+// fakeRequest emulates handling a response for a link address request.
+func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) {
+ if entry, ok := r.entries.entryByAddr(addr); ok {
+ r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ }
+}
+
+func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == testEntryBroadcastAddr {
+ return testEntryBroadcastLinkAddr, true
+ }
+ return "", false
+}
+
+func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return 0
+}
+
+type entryEvent struct {
+ nicID tcpip.NICID
+ address tcpip.Address
+ linkAddr tcpip.LinkAddress
+ state NeighborState
+}
+
+func TestNeighborCacheGetConfig(t *testing.T) {
+ nudDisp := testNUDDispatcher{}
+ c := DefaultNUDConfigurations()
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+
+ if got, want := neigh.config(), c; got != want {
+ t.Errorf("got neigh.config() = %+v, want = %+v", got, want)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheSetConfig(t *testing.T) {
+ nudDisp := testNUDDispatcher{}
+ c := DefaultNUDConfigurations()
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+ neigh.setConfig(c)
+
+ if got, want := neigh.config(), c; got != want {
+ t.Errorf("got neigh.config() = %+v, want = %+v", got, want)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheEntry(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+
+ clock.advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+
+ // No more events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+// TestNeighborCacheEntryNoLinkAddress verifies calling entry() without a
+// LinkAddressResolver returns ErrNoLinkAddress.
+func TestNeighborCacheEntryNoLinkAddress(t *testing.T) {
+ nudDisp := testNUDDispatcher{}
+ c := DefaultNUDConfigurations()
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+ store := newTestEntryStore()
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, nil, nil)
+ if err != tcpip.ErrNoLinkAddress {
+ t.Errorf("got neigh.entry(%s, %s, nil, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheRemoveEntry(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+
+ clock.advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ neigh.removeEntry(entry.Addr)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+}
+
+type testContext struct {
+ clock *fakeClock
+ neigh *neighborCache
+ store *testEntryStore
+ linkRes *testNeighborResolver
+ nudDisp *testNUDDispatcher
+}
+
+func newTestContext(c NUDConfigurations) testContext {
+ nudDisp := &testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(nudDisp, c, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ return testContext{
+ clock: clock,
+ neigh: neigh,
+ store: store,
+ linkRes: linkRes,
+ nudDisp: nudDisp,
+ }
+}
+
+type overflowOptions struct {
+ startAtEntryIndex int
+ wantStaticEntries []NeighborEntry
+}
+
+func (c *testContext) overflowCache(opts overflowOptions) error {
+ // Fill the neighbor cache to capacity to verify the LRU eviction strategy is
+ // working properly after the entry removal.
+ for i := opts.startAtEntryIndex; i < c.store.size(); i++ {
+ // Add a new entry
+ entry, ok := c.store.entry(i)
+ if !ok {
+ return fmt.Errorf("c.store.entry(%d) not found", i)
+ }
+ if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(c.neigh.config().RetransmitTimer)
+
+ var wantEvents []testEntryEventInfo
+
+ // When beyond the full capacity, the cache will evict an entry as per the
+ // LRU eviction strategy. Note that the number of static entries should not
+ // affect the total number of dynamic entries that can be added.
+ if i >= neighborCacheSize+opts.startAtEntryIndex {
+ removedEntry, ok := c.store.entry(i - neighborCacheSize)
+ if !ok {
+ return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize)
+ }
+ wantEvents = append(wantEvents, testEntryEventInfo{
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ })
+ }
+
+ wantEvents = append(wantEvents, testEntryEventInfo{
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ }, testEntryEventInfo{
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ })
+
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Expect to find only the most recent entries. The order of entries reported
+ // by entries() is nondeterministic, so entries have to be sorted before
+ // comparison.
+ wantUnsortedEntries := opts.wantStaticEntries
+ for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ {
+ entry, ok := c.store.entry(i)
+ if !ok {
+ return fmt.Errorf("c.store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No more events should have been dispatched.
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ return nil
+}
+
+// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy
+// respects the dynamic entry count.
+func TestNeighborCacheOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache
+// eviction strategy respects the dynamic entry count when an entry is removed.
+func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(c.neigh.config().RetransmitTimer)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Remove the entry
+ c.neigh.removeEntry(entry.Addr)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that
+// adding a duplicate static entry with the same link address does not dispatch
+// any events.
+func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Add the same static entry again; this should be a no-op.
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+
+ // No more events should have been dispatched.
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that
+// adding a duplicate static entry with a different link address dispatches a
+// change event.
+func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Add a duplicate entry with a different link address
+ staticLinkAddr += "duplicate"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+}
+
+// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache
+// eviction strategy respects the dynamic entry count when a static entry is
+// added then removed. In this case, the dynamic entry count shouldn't have
+// been touched.
+func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Remove the static entry that was just added
+ c.neigh.removeEntry(entry.Addr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU
+// cache eviction strategy tracks the dynamic entry count correctly when an
+// entry is overwritten by a static entry. Static entries should not count
+// towards the size of the LRU cache.
+func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(typicalLatency)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Override the entry with a static one using the same address
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 1,
+ wantStaticEntries: []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ },
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheNotifiesWaker(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ const wakerID = 1
+ s.AddWaker(&w, wakerID)
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh == nil {
+ t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+ clock.advance(typicalLatency)
+
+ select {
+ case <-doneCh:
+ default:
+ t.Fatal("expected notification from done channel")
+ }
+
+ id, ok := s.Fetch(false /* block */)
+ if !ok {
+ t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+ if id != wakerID {
+ t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheRemoveWaker(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ const wakerID = 1
+ s.AddWaker(&w, wakerID)
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh == nil {
+ t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+
+ // Remove the waker before the neighbor cache has the opportunity to send a
+ // notification.
+ neigh.removeWaker(entry.Addr, &w)
+ clock.advance(typicalLatency)
+
+ select {
+ case <-doneCh:
+ default:
+ t.Fatal("expected notification from done channel")
+ }
+
+ if id, ok := s.Fetch(false /* block */); ok {
+ t.Errorf("unexpected notification from waker with id %d", id)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
+ e, _, err := c.neigh.entry(entry.Addr, "", nil, nil)
+ if err != nil {
+ t.Errorf("unexpected error from c.neigh.entry(%s, \"\", nil nil): %s", entry.Addr, err)
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("c.neigh.entry(%s, \"\", nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 1,
+ wantStaticEntries: []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
+ },
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheClear(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ // Add a dynamic entry.
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Add a static entry.
+ neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Clear should remove both dynamic and static entries.
+ neigh.clear()
+
+ // Remove events dispatched from clear() are not deterministically ordered,
+ // so they need to be sorted before comparison.
+ wantUnsortedEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction
+// strategy tracks the dynamic entry count correctly when all entries are
+// cleared.
+func TestNeighborCacheClearThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(typicalLatency)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Clear the cache.
+ c.neigh.clear()
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ frequentlyUsedEntry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+
+ // The following logic is very similar to overflowCache, but
+ // periodically refreshes the frequently used entry.
+
+ // Fill the neighbor cache to capacity
+ for i := 0; i < neighborCacheSize; i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Keep adding more entries
+ for i := neighborCacheSize; i < store.size(); i++ {
+ // Periodically refresh the frequently used entry
+ if i%(neighborCacheSize/2) == 0 {
+ _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err)
+ }
+ }
+
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+
+ // An entry should have been removed, as per the LRU eviction strategy
+ removedEntry, ok := store.entry(i - neighborCacheSize + 1)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1)
+ }
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Expect to find only the frequently used entry and the most recent entries.
+ // The order of entries reported by entries() is nondeterministic, so entries
+ // have to be sorted before comparison.
+ wantUnsortedEntries := []NeighborEntry{
+ {
+ Addr: frequentlyUsedEntry.Addr,
+ LocalAddr: frequentlyUsedEntry.LocalAddr,
+ LinkAddr: frequentlyUsedEntry.LinkAddr,
+ State: Reachable,
+ },
+ }
+
+ for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No more events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheConcurrent(t *testing.T) {
+ const concurrentProcesses = 16
+
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ storeEntries := store.entries()
+ for _, entry := range storeEntries {
+ var wg sync.WaitGroup
+ for r := 0; r < concurrentProcesses; r++ {
+ wg.Add(1)
+ go func(entry NeighborEntry) {
+ defer wg.Done()
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil && err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock)
+ }
+ }(entry)
+ }
+
+ // Wait for all goroutines to send a request
+ wg.Wait()
+
+ // Process all the requests for a single entry concurrently
+ clock.advance(typicalLatency)
+ }
+
+ // All goroutines add in the same order and add more values than can fit in
+ // the cache. Our eviction strategy requires that the last entries are
+ // present, up to the size of the neighbor cache, and the rest are missing.
+ // The order of entries reported by entries() is nondeterministic, so entries
+ // have to be sorted before comparison.
+ var wantUnsortedEntries []NeighborEntry
+ for i := store.size() - neighborCacheSize; i < store.size(); i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Errorf("store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheReplace(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ // Add an entry
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+
+ // Verify the entry exists
+ e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ if doneCh != nil {
+ t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ }
+
+ // Notify of a link address change
+ var updatedLinkAddr tcpip.LinkAddress
+ {
+ entry, ok := store.entry(1)
+ if !ok {
+ t.Fatalf("store.entry(1) not found")
+ }
+ updatedLinkAddr = entry.LinkAddr
+ }
+ store.set(0, updatedLinkAddr)
+ neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+
+ // Requesting the entry again should start address resolution
+ {
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(config.DelayFirstProbeTime + typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+ }
+
+ // Verify the entry's new link address
+ {
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ clock.advance(typicalLatency)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ want = NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: updatedLinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+ }
+}
+
+func TestNeighborCacheResolutionFailed(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+
+ var requestCount uint32
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ onLinkAddressRequest: func() {
+ atomic.AddUint32(&requestCount, 1)
+ },
+ }
+
+ // First, sanity check that resolution is working
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+
+ // Verify that address resolution for an unknown address returns ErrNoLinkAddress
+ before := atomic.LoadUint32(&requestCount)
+
+ entry.Addr += "2"
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
+ clock.advance(waitFor)
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ }
+
+ maxAttempts := neigh.config().MaxUnicastProbes
+ if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want {
+ t.Errorf("got link address request count = %d, want = %d", got, want)
+ }
+}
+
+// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes
+// probes without receiving a confirmation within the window defined by
+// MaxMulticastProbes * RetransmitTimer.
+func TestNeighborCacheResolutionTimeout(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ config.RetransmitTimer = time.Millisecond // small enough to cause timeout
+
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: time.Minute, // large enough to cause timeout
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
+ clock.advance(waitFor)
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ }
+}
+
+// TestNeighborCacheStaticResolution checks that static link addresses are
+// resolved immediately and don't send resolution requests.
+func TestNeighborCacheStaticResolution(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err)
+ }
+ want := NeighborEntry{
+ Addr: testEntryBroadcastAddr,
+ LocalAddr: testEntryLocalAddr,
+ LinkAddr: testEntryBroadcastLinkAddr,
+ State: Static,
+ }
+ if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff)
+ }
+}
+
+func BenchmarkCacheClear(b *testing.B) {
+ b.StopTimer()
+ config := DefaultNUDConfigurations()
+ clock := &tcpip.StdClock{}
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: 0,
+ }
+
+ // Clear for every possible size of the cache
+ for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ {
+ // Fill the neighbor cache to capacity.
+ for i := 0; i < cacheSize; i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ b.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh != nil {
+ <-doneCh
+ }
+ }
+
+ b.StartTimer()
+ neigh.clear()
+ b.StopTimer()
+ }
+}
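+
+// Note: assuming the standard Go toolchain, the benchmark above can be run in
+// isolation with `go test -run='^$' -bench=BenchmarkCacheClear` from this
+// package's directory.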
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
new file mode 100644
index 000000000..0068cacb8
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -0,0 +1,482 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// NeighborEntry describes a neighboring device in the local network.
+type NeighborEntry struct {
+ Addr tcpip.Address
+ LocalAddr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAt time.Time
+}
+
+// NeighborState defines the state of a NeighborEntry within the Neighbor
+// Unreachability Detection state machine, as per RFC 4861 section 7.3.2.
+type NeighborState uint8
+
+const (
+ // Unknown means reachability has not been verified yet. This is the initial
+ // state of entries that have been created automatically by the Neighbor
+ // Unreachability Detection state machine.
+ Unknown NeighborState = iota
+ // Incomplete means that there is an outstanding request to resolve the
+ // address.
+ Incomplete
+ // Reachable means the path to the neighbor is functioning properly for both
+ // receive and transmit paths.
+ Reachable
+ // Stale means reachability to the neighbor is unknown, but packets are still
+ // able to be transmitted to the possibly stale link address.
+ Stale
+ // Delay means reachability to the neighbor is unknown and pending
+ // confirmation from an upper-level protocol like TCP, but packets are still
+ // able to be transmitted to the possibly stale link address.
+ Delay
+ // Probe means a reachability confirmation is actively being sought by
+ // periodically retransmitting reachability probes until a reachability
+ // confirmation is received, or until the max amount of probes has been sent.
+ Probe
+ // Static describes entries that have been explicitly added by the user. They
+ // do not expire and are not deleted until explicitly removed.
+ Static
+ // Failed means traffic should not be sent to this neighbor since attempts of
+ // reachability have returned inconclusive.
+ Failed
+)
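+
+// For illustration only (a sketch, not normative): a dynamic entry created by
+// an outgoing packet typically moves through the states as follows, assuming
+// timely solicited confirmations:
+//
+//	Unknown -> Incomplete (packet queued; multicast probe sent)
+//	Incomplete -> Reachable (solicited confirmation received)
+//	Reachable -> Stale (reachable timer expires)
+//	Stale -> Delay -> Probe (another packet sent; unicast probes follow)
+//	Probe -> Reachable (solicited confirmation) or Probe -> Failed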
+
+// neighborEntry implements a neighbor entry's individual node behavior, as per
+// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in
+// parallel with the sending of packets to a neighbor, necessitating the
+// entry's lock to be acquired for all operations.
+type neighborEntry struct {
+ neighborEntryEntry
+
+ nic *NIC
+ protocol tcpip.NetworkProtocolNumber
+
+ // linkRes provides the functionality to send reachability probes, used in
+ // Neighbor Unreachability Detection.
+ linkRes LinkAddressResolver
+
+ // nudState points to the Neighbor Unreachability Detection configuration.
+ nudState *NUDState
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ neigh NeighborEntry
+
+ // wakers is a set of waiters for the address resolution result. Anytime the
+ // state transitions out of Incomplete, these waiters are notified. It is
+ // nil when no clients are waiting on an in-progress address resolution.
+ wakers map[*sleep.Waker]struct{}
+
+ // done is used to allow callers to wait on address resolution. It is nil
+ // iff nudState is not Reachable and address resolution is not yet in
+ // progress.
+ done chan struct{}
+
+ isRouter bool
+ job *tcpip.Job
+}
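+
+// A minimal usage sketch (illustrative only): all of the *Locked methods
+// below assume the caller holds e.mu, e.g.
+//
+//	e.mu.Lock()
+//	e.handlePacketQueuedLocked()
+//	e.mu.Unlock()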
+
+// newNeighborEntry creates a neighbor cache entry starting at the default
+// state, Unknown. Transition out of Unknown by calling either
+// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created
+// neighborEntry.
+func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
+ return &neighborEntry{
+ nic: nic,
+ linkRes: linkRes,
+ nudState: nudState,
+ neigh: NeighborEntry{
+ Addr: remoteAddr,
+ LocalAddr: localAddr,
+ State: Unknown,
+ },
+ }
+}
+
+// newStaticNeighborEntry creates a neighbor cache entry starting at the Static
+// state. The entry can only transition out of Static by directly calling
+// `setStateLocked`.
+func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
+ if nic.stack.nudDisp != nil {
+ nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now())
+ }
+ return &neighborEntry{
+ nic: nic,
+ nudState: state,
+ neigh: NeighborEntry{
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: time.Now(),
+ },
+ }
+}
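+
+// Note that, unlike newNeighborEntry, newStaticNeighborEntry dispatches the
+// Added event itself: a static entry is immediately usable and never passes
+// through the Incomplete state.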
+
+// addWakerLocked adds w to the list of wakers waiting for address resolution.
+// Assumes the entry has already been appropriately locked.
+func (e *neighborEntry) addWakerLocked(w *sleep.Waker) {
+ if w == nil {
+ return
+ }
+ if e.wakers == nil {
+ e.wakers = make(map[*sleep.Waker]struct{})
+ }
+ e.wakers[w] = struct{}{}
+}
+
+// notifyWakersLocked notifies those waiting for address resolution, whether it
+// succeeded or failed. Assumes the entry has already been appropriately locked.
+func (e *neighborEntry) notifyWakersLocked() {
+ for w := range e.wakers {
+ w.Assert()
+ }
+ e.wakers = nil
+ if ch := e.done; ch != nil {
+ close(ch)
+ e.done = nil
+ }
+}
+
+// dispatchAddEventLocked signals to the stack's NUD dispatcher that the entry
+// has been added.
+func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ }
+}
+
+// dispatchChangeEventLocked signals to the stack's NUD dispatcher that the
+// entry has changed state or link-layer address.
+func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ }
+}
+
+// dispatchRemoveEventLocked signals to the stack's NUD dispatcher that the
+// entry has been removed.
+func (e *neighborEntry) dispatchRemoveEventLocked() {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now())
+ }
+}
+
+// setStateLocked transitions the entry to the specified state immediately.
+//
+// Follows the logic defined in RFC 4861 section 7.3.3.
+//
+// e.mu MUST be locked.
+func (e *neighborEntry) setStateLocked(next NeighborState) {
+ // Cancel the previously scheduled action, if there is one. Entries in
+ // Unknown, Stale, or Static state do not have scheduled actions.
+ if timer := e.job; timer != nil {
+ timer.Cancel()
+ }
+
+ prev := e.neigh.State
+ e.neigh.State = next
+ e.neigh.UpdatedAt = time.Now()
+ config := e.nudState.Config()
+
+ switch next {
+ case Incomplete:
+ var retryCounter uint32
+ var sendMulticastProbe func()
+
+ sendMulticastProbe = func() {
+ if retryCounter == config.MaxMulticastProbes {
+ // "If no Neighbor Advertisement is received after
+ // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
+ // The sender MUST return ICMP destination unreachable indications with
+ // code 3 (Address Unreachable) for each packet queued awaiting address
+ // resolution." - RFC 4861 section 7.2.2
+ //
+ // There is no need to send an ICMP destination unreachable indication
+ // since the failure to resolve the address is expected to only occur
+ // on this node. Thus, redirecting traffic is currently not supported.
+ //
+ // "If the error occurs on a node other than the node originating the
+ // packet, an ICMP error message is generated. If the error occurs on
+ // the originating node, an implementation is not required to actually
+ // create and send an ICMP error packet to the source, as long as the
+ // upper-layer sender is notified through an appropriate mechanism
+ // (e.g. return value from a procedure call). Note, however, that an
+ // implementation may find it convenient in some cases to return errors
+ // to the sender by taking the offending packet, generating an ICMP
+ // error message, and then delivering it (locally) through the generic
+ // error-handling routines." - RFC 4861 section 2.1
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil {
+ // There is no need to log the error here; the NUD implementation may
+ // assume a working link. A valid link should be the responsibility of
+ // the NIC/stack.LinkEndpoint.
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ sendMulticastProbe()
+
+ case Reachable:
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ })
+ e.job.Schedule(e.nudState.ReachableTime())
+
+ case Delay:
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.dispatchChangeEventLocked(Probe)
+ e.setStateLocked(Probe)
+ })
+ e.job.Schedule(config.DelayFirstProbeTime)
+
+ case Probe:
+ var retryCounter uint32
+ var sendUnicastProbe func()
+
+ sendUnicastProbe = func() {
+ if retryCounter == config.MaxUnicastProbes {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ if retryCounter == config.MaxUnicastProbes {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ sendUnicastProbe()
+
+ case Failed:
+ e.notifyWakersLocked()
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.nic.neigh.removeEntryLocked(e)
+ })
+ e.job.Schedule(config.UnreachableTime)
+
+ case Unknown, Stale, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid state transition from %q to %q", prev, next))
+ }
+}
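+
+// Every timed transition above is driven by a tcpip.Job created against e.mu,
+// so timer callbacks and direct state changes are serialized on the same
+// lock.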
+
+// handlePacketQueuedLocked advances the state machine according to a packet
+// being queued for outgoing transmission.
+//
+// Follows the logic defined in RFC 4861 section 7.3.3.
+func (e *neighborEntry) handlePacketQueuedLocked() {
+ switch e.neigh.State {
+ case Unknown:
+ e.dispatchAddEventLocked(Incomplete)
+ e.setStateLocked(Incomplete)
+
+ case Stale:
+ e.dispatchChangeEventLocked(Delay)
+ e.setStateLocked(Delay)
+
+ case Incomplete, Reachable, Delay, Probe, Static, Failed:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleProbeLocked processes an incoming neighbor probe (e.g. an ARP request
+// or an NDP Neighbor Solicitation).
+//
+// Follows the logic defined in RFC 4861 section 7.2.3.
+func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
+ // Probes MUST be silently discarded if the target address is tentative, does
+ // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
+ // checks MUST be done by the NetworkEndpoint.
+
+ switch e.neigh.State {
+ case Unknown, Incomplete, Failed:
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchAddEventLocked(Stale)
+ e.setStateLocked(Stale)
+ e.notifyWakersLocked()
+
+ case Reachable, Delay, Probe:
+ if e.neigh.LinkAddr != remoteLinkAddr {
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+
+ case Stale:
+ if e.neigh.LinkAddr != remoteLinkAddr {
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchChangeEventLocked(Stale)
+ }
+
+ case Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleConfirmationLocked processes an incoming neighbor confirmation
+// (e.g. an ARP reply or an NDP Neighbor Advertisement).
+//
+// Follows the state machine defined by RFC 4861 section 7.2.5.
+//
+// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other
+// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol
+// should be deployed where preventing access to the broadcast segment might
+// not be possible. SEND uses RSA key pairs to produce Cryptographically
+// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the
+// claimed source of an NDP message is the owner of the claimed address.
+func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
+ switch e.neigh.State {
+ case Incomplete:
+ if len(linkAddr) == 0 {
+ // "If the link layer has addresses and no Target Link-Layer Address
+ // option is included, the receiving node SHOULD silently discard the
+ // received advertisement." - RFC 4861 section 7.2.5
+ break
+ }
+
+ e.neigh.LinkAddr = linkAddr
+ if flags.Solicited {
+ e.dispatchChangeEventLocked(Reachable)
+ e.setStateLocked(Reachable)
+ } else {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+ e.isRouter = flags.IsRouter
+ e.notifyWakersLocked()
+
+ // "Note that the Override flag is ignored if the entry is in the
+ // INCOMPLETE state." - RFC 4861 section 7.2.5
+
+ case Reachable, Stale, Delay, Probe:
+ sameLinkAddr := e.neigh.LinkAddr == linkAddr
+
+ if !sameLinkAddr {
+ if !flags.Override {
+ if e.neigh.State == Reachable {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+ break
+ }
+
+ e.neigh.LinkAddr = linkAddr
+
+ if !flags.Solicited {
+ if e.neigh.State != Stale {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ } else {
+ // Notify of the LinkAddr change, even though the NUD state hasn't changed.
+ e.dispatchChangeEventLocked(e.neigh.State)
+ }
+ break
+ }
+ }
+
+ if flags.Solicited && (flags.Override || sameLinkAddr) {
+ if e.neigh.State != Reachable {
+ e.dispatchChangeEventLocked(Reachable)
+ }
+ // Set state to Reachable again to refresh timers.
+ e.setStateLocked(Reachable)
+ e.notifyWakersLocked()
+ }
+
+ if e.isRouter && !flags.IsRouter {
+ // "In those cases where the IsRouter flag changes from TRUE to FALSE as
+ // a result of this update, the node MUST remove that router from the
+ // Default Router List and update the Destination Cache entries for all
+ // destinations using that neighbor as a router as specified in Section
+ // 7.3.3. This is needed to detect when a node that is used as a router
+ // stops forwarding packets due to being configured as a host."
+ // - RFC 4861 section 7.2.5
+ e.nic.mu.Lock()
+ e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr)
+ e.nic.mu.Unlock()
+ }
+ e.isRouter = flags.IsRouter
+
+ case Unknown, Failed, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleUpperLevelConfirmationLocked processes an incoming reachability
+// confirmation from an upper-level protocol (e.g. a TCP acknowledgment).
+func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
+ switch e.neigh.State {
+ case Reachable, Stale, Delay, Probe:
+ if e.neigh.State != Reachable {
+ e.dispatchChangeEventLocked(Reachable)
+ }
+ // Set the state to Reachable again to refresh its timers.
+ e.setStateLocked(Reachable)
+
+ case Unknown, Incomplete, Failed, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
new file mode 100644
index 000000000..08c9ccd25
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -0,0 +1,2770 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "math"
+ "math/rand"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+
+ entryTestNICID tcpip.NICID = 1
+ entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+
+ entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01")
+ entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02")
+
+ // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
+ // except where another value is explicitly used. It is chosen to match the
+ // MTU of loopback interfaces on Linux systems.
+ entryTestNetDefaultMTU = 65536
+)
+
+// eventDiffOpts are the options passed to cmp.Diff to compare entry events.
+// The UpdatedAt field is ignored because there is no deterministic way to
+// predict the time at which an event will be dispatched.
+func eventDiffOpts() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ }
+}
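+
+// For example (a sketch): cmp.Diff(gotEvents, wantEvents, eventDiffOpts()...)
+// reports a mismatch on any field other than UpdatedAt.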
+
+// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to
+// sort slices of events for cases where ordering must be ignored.
+func eventDiffOptsWithSort() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }),
+ }
+}
+
+// The following unit tests exercise every state transition and verify its
+// behavior against RFC 4861.
+//
+// | From | To | Cause | Action | Event |
+// | ========== | ========== | ========================================== | =============== | ======= |
+// | Unknown | Unknown | Confirmation w/ unknown address | | Added |
+// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added |
+// | Unknown | Stale | Probe w/ unknown address | | Added |
+// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed |
+// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed |
+// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed |
+// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | |
+// | Reachable | Stale | Reachable timer expired | | Changed |
+// | Reachable | Stale | Probe or confirmation w/ different address | | Changed |
+// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
+// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
+// | Stale | Delay | Packet sent | | Changed |
+// | Delay | Reachable | Upper-layer confirmation | | Changed |
+// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Delay | Stale | Probe or confirmation w/ different address | | Changed |
+// | Delay | Probe | Delay timer expired | Send probe | Changed |
+// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed |
+// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
+// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
+// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Failed | | Unreachability timer expired | Delete entry | |
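+//
+// For example, reading the table above: a solicited confirmation received
+// while an entry is Incomplete moves it to Reachable, notifies any wakers
+// blocked on resolution, and dispatches a Changed event; this path is
+// implemented by handleConfirmationLocked in neighbor_entry.go.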
+
+type testEntryEventType uint8
+
+const (
+ entryTestAdded testEntryEventType = iota
+ entryTestChanged
+ entryTestRemoved
+)
+
+func (t testEntryEventType) String() string {
+ switch t {
+ case entryTestAdded:
+ return "add"
+ case entryTestChanged:
+ return "change"
+ case entryTestRemoved:
+ return "remove"
+ default:
+ return fmt.Sprintf("unknown (%d)", t)
+ }
+}
+
+// Fields are exported for use with cmp.Diff.
+type testEntryEventInfo struct {
+ EventType testEntryEventType
+ NICID tcpip.NICID
+ Addr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAt time.Time
+}
+
+func (e testEntryEventInfo) String() string {
+ return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State)
+}
+
+// testNUDDispatcher implements NUDDispatcher to record the events dispatched
+// by the NUD state machine, so tests can validate them.
+type testNUDDispatcher struct {
+ mu sync.Mutex
+ events []testEntryEventInfo
+}
+
+var _ NUDDispatcher = (*testNUDDispatcher)(nil)
+
+func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.events = append(d.events, e)
+}
+
+func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestAdded,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestChanged,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestRemoved,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+type entryTestLinkResolver struct {
+ mu sync.Mutex
+ probes []entryTestProbeInfo
+}
+
+var _ LinkAddressResolver = (*entryTestLinkResolver)(nil)
+
+type entryTestProbeInfo struct {
+ RemoteAddress tcpip.Address
+ RemoteLinkAddress tcpip.LinkAddress
+ LocalAddress tcpip.Address
+}
+
+func (p entryTestProbeInfo) String() string {
+ return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress)
+}
+
+// LinkAddressRequest records a probe for the LinkAddress of addr instead of
+// sending a real request to the network. A zero-value linkAddr denotes a
+// broadcast request.
+func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+ p := entryTestProbeInfo{
+ RemoteAddress: addr,
+ RemoteLinkAddress: linkAddr,
+ LocalAddress: localAddr,
+ }
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.probes = append(r.probes, p)
+ return nil
+}
+
+// ResolveStaticAddress attempts to resolve addr without sending requests.
+// This implementation never resolves statically, always returning the zero
+// LinkAddress, so the tests exercise dynamic resolution.
+func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ return "", false
+}
+
+// LinkAddressProtocol returns the network protocol of the addresses this
+// resolver can resolve.
+func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return entryTestNetNumber
+}
+
+func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) {
+ clock := newFakeClock()
+ disp := testNUDDispatcher{}
+ nic := NIC{
+ id: entryTestNICID,
+ linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+ stack: &Stack{
+ clock: clock,
+ nudDisp: &disp,
+ },
+ }
+
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
+ nudState := NewNUDState(c, rng)
+ linkRes := entryTestLinkResolver{}
+ entry := newNeighborEntry(&nic, entryTestAddr1, entryTestAddr2, nudState, &linkRes)
+
+ // Stub out ndpState to verify modification of default routers.
+ nic.mu.ndp = ndpState{
+ nic: &nic,
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ }
+
+ // Stub out the neighbor cache to verify deletion from the cache.
+ nic.neigh = &neighborCache{
+ nic: &nic,
+ state: nudState,
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+ nic.neigh.cache[entryTestAddr1] = entry
+
+ return entry, &disp, &linkRes, clock
+}
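+
+// Note: newFakeClock and clock.advance are defined elsewhere in this package;
+// the sketch below captures the behavior these tests rely on (an assumption
+// about the fake clock's interface, not its actual implementation):
+//
+//	clock := newFakeClock()          // manual clock; no wall time passes
+//	clock.advance(c.RetransmitTimer) // runs any timers that fall due
+//
+// which makes timer-driven state transitions deterministic.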
+
+// TestEntryInitiallyUnknown verifies that the state of a newly created
+// neighborEntry is Unknown.
+func TestEntryInitiallyUnknown(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Unknown; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.RetransmitTimer)
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Unknown; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(time.Hour)
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryUnknownToIncomplete(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ {
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+}
+
+func TestEntryUnknownToStale(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ updatedAt := e.neigh.UpdatedAt
+ e.mu.Unlock()
+
+ clock.advance(c.RetransmitTimer)
+
+ // UpdatedAt should remain the same during address resolution.
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.UpdatedAt, updatedAt; got != want {
+ t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.RetransmitTimer)
+
+ // UpdatedAt should change after failing address resolution. Timing out after
+ // sending the last probe transitions the entry to Failed.
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ clock.advance(c.RetransmitTimer)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant {
+ t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryIncompleteToReachable(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+// TestEntryAddsAndClearsWakers verifies that wakers are added when
+// addWakerLocked is called and cleared when address resolution finishes. In
+// this case, address resolution will finish when transitioning from Incomplete
+// to Reachable.
+func TestEntryAddsAndClearsWakers(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ s.AddWaker(&w, 123)
+ defer s.Done()
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got := e.wakers; got != nil {
+ t.Errorf("got e.wakers = %v, want = nil", got)
+ }
+ e.addWakerLocked(&w)
+ if got, want := w.IsAsserted(), false; got != want {
+ t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ }
+ if e.wakers == nil {
+ t.Error("expected e.wakers to be non-nil")
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.wakers != nil {
+ t.Errorf("got e.wakers = %v, want = nil", e.wakers)
+ }
+ if got, want := w.IsAsserted(), true; got != want {
+ t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.isRouter, true; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+	linkRes.mu.Lock()
+	diff := cmp.Diff(linkRes.probes, wantProbes)
+	linkRes.mu.Unlock()
+	if diff != "" {
+		t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+	}
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToStale(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
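+	// Each probe is separated by RetransmitTimer, and the entry transitions to
+	// Failed after the last probe expires unanswered.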
+ waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
+ clock.advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The Incomplete-to-Incomplete state transition is tested here by
+ // verifying that 3 reachability probes were sent.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Failed; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
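+// testLocker is a no-op sync.Locker used to satisfy the locker parameter of
+// (*Stack).newJob in tests that stub out NDP default router state.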
+type testLocker struct{}
+
+var _ sync.Locker = (*testLocker)(nil)
+
+func (*testLocker) Lock() {}
+func (*testLocker) Unlock() {}
+
+func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.isRouter, true; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+ e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{
+ invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}),
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.isRouter, false; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+ if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok {
+ t.Errorf("unexpected defaultRouter for %s", entryTestAddr1)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
+ c := DefaultNUDConfigurations()
+	// Eliminate random factors from the ReachableTime computation so that an
+	// entry remains Reachable for exactly BaseReachableTime before
+	// transitioning to Stale.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToDelay(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+	// Eliminate random factors from the ReachableTime computation so that an
+	// entry remains Reachable for exactly BaseReachableTime before
+	// transitioning to Stale.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleUpperLevelConfirmationLocked()
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 1
+	// Eliminate random factors from the ReachableTime computation so that an
+	// entry remains Reachable for exactly BaseReachableTime before
+	// transitioning to Stale.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToProbe(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+	// Eliminate random factors from ReachableTime computation so the transition
+	// from Reachable to Stale will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+	// Eliminate random factors from ReachableTime computation so the transition
+	// from Reachable to Stale will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
+ clock.advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+		// The next three probes are caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Failed; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryFailedGetsDeleted(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ // Verify the cache contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
+ t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
+ }
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
+ clock.advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+		// The next three probes are caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ // Verify the cache no longer contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok {
+ t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1)
+ }
+}
diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go
new file mode 100644
index 000000000..aa7311ec6
--- /dev/null
+++ b/pkg/tcpip/stack/neighborstate_string.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type NeighborState"; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[Unknown-0]
+ _ = x[Incomplete-1]
+ _ = x[Reachable-2]
+ _ = x[Stale-3]
+ _ = x[Delay-4]
+ _ = x[Probe-5]
+ _ = x[Static-6]
+ _ = x[Failed-7]
+}
+
+const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed"
+
+var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53}
+
+func (i NeighborState) String() string {
+ if i >= NeighborState(len(_NeighborState_index)-1) {
+ return "NeighborState(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _NeighborState_name[_NeighborState_index[i]:_NeighborState_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 54103fdb3..f21066fce 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -16,6 +16,7 @@ package stack
import (
"fmt"
+ "math/rand"
"reflect"
"sort"
"strings"
@@ -45,6 +46,7 @@ type NIC struct {
context NICContext
stats NICStats
+ neigh *neighborCache
mu struct {
sync.RWMutex
@@ -141,6 +143,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
}
+ // Check for Neighbor Unreachability Detection support.
+ if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 {
+ rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds()))
+ nic.neigh = &neighborCache{
+ nic: nic,
+ state: NewNUDState(stack.nudConfigs, rng),
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+ }
+
nic.linkEP.Attach(nic)
return nic
@@ -181,7 +193,7 @@ func (n *NIC) disableLocked() *tcpip.Error {
return nil
}
- // TODO(b/147015577): Should Routes that are currently bound to n be
+ // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be
// invalidated? Currently, Routes will continue to work when a NIC is enabled
// again, and applications may not know that the underlying NIC was ever
// disabled.
@@ -457,8 +469,20 @@ type ipv6AddrCandidate struct {
// remoteAddr must be a valid IPv6 address.
func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
n.mu.RLock()
- defer n.mu.RUnlock()
+ ref := n.primaryIPv6EndpointRLocked(remoteAddr)
+ n.mu.RUnlock()
+ return ref
+}
+
+// primaryIPv6EndpointRLocked returns an IPv6 endpoint following Source Address
+// Selection (RFC 6724 section 5).
+//
+// Note, only rules 1-3 and 7 are followed.
+//
+// remoteAddr must be a valid IPv6 address.
+//
+// n.mu MUST be read locked.
+func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber]
if len(primaryAddrs) == 0 {
@@ -568,11 +592,6 @@ const (
// promiscuous indicates that the NIC's promiscuous flag should be observed
// when getting a NIC's referenced network endpoint.
promiscuous
-
- // forceSpoofing indicates that the NIC should be assumed to be spoofing,
- // regardless of what the NIC's spoofing flag is when getting a NIC's
- // referenced network endpoint.
- forceSpoofing
)
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
@@ -591,8 +610,6 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
// Similarly, spoofing will only be checked if spoofing is true.
func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
- id := NetworkEndpointID{address}
-
n.mu.RLock()
var spoofingOrPromiscuous bool
@@ -601,24 +618,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
spoofingOrPromiscuous = n.mu.spoofing
case promiscuous:
spoofingOrPromiscuous = n.mu.promiscuous
- case forceSpoofing:
- spoofingOrPromiscuous = true
}
- if ref, ok := n.mu.endpoints[id]; ok {
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// An endpoint with this id exists, check if it can be used and return it.
- switch ref.getKind() {
- case permanentExpired:
- if !spoofingOrPromiscuous {
- n.mu.RUnlock()
- return nil
- }
- fallthrough
- case temporary, permanent:
- if ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
- }
+ if !ref.isAssignedRLocked(spoofingOrPromiscuous) {
+ n.mu.RUnlock()
+ return nil
+ }
+
+ if ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
}
}
@@ -654,11 +665,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- if ref, ok := n.mu.endpoints[id]; ok {
+ ref := n.getRefOrCreateTempLocked(protocol, address, peb)
+ n.mu.Unlock()
+ return ref
+}
+
+// getRefOrCreateTempLocked returns an existing endpoint for address or creates
+// and returns a temporary endpoint.
+//
+// n.mu must be write locked.
+func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// No need to check the type as we are ok with expired endpoints at this
// point.
if ref.tryIncRef() {
- n.mu.Unlock()
return ref
}
// tryIncRef failing means the endpoint is scheduled to be removed once the
@@ -670,7 +688,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
- n.mu.Unlock()
return nil
}
ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
@@ -680,8 +697,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
PrefixLen: netProto.DefaultPrefixLen(),
},
}, peb, temporary, static, false)
-
- n.mu.Unlock()
return ref
}
@@ -1153,7 +1168,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
return joins != 0
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) {
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
@@ -1167,7 +1182,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
n.mu.RLock()
enabled := n.mu.enabled
// If the NIC is not yet enabled, don't receive any packets.
@@ -1197,27 +1212,34 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// Are any packet sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Check whether there are packet sockets listening for every protocol.
- // If we received a packet with protocol EthernetProtocolAll, then the
- // previous for loop will have handled it.
- if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
- }
+	// Add any other packet sockets that may be listening for all protocols.
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketHost
+ ep.HandlePacket(n.id, local, protocol, p)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
n.stack.stats.IP.PacketsReceived.Increment()
}
- netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize())
+ // Parse headers.
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
if !ok {
+ // The packet is too small to contain a network header.
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- src, dst := netProto.ParseAddresses(netHeader)
+ if hasTransportHdr {
+ // Parse the transport header if present.
+ if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
+ state.proto.Parse(pkt)
+ }
+ }
+
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader)
if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
// The source address is one of our own, so we never should have gotten a
@@ -1229,18 +1251,19 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
// TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
- if protocol == header.IPv4ProtocolNumber {
+ // Loopback traffic skips the prerouting chain.
+ if protocol == header.IPv4ProtocolNumber && !n.isLoopback() {
// iptables filtering.
ipt := n.stack.IPTables()
address := n.primaryAddress(protocol)
- if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok {
+ if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
// iptables is telling us to drop the packet.
return
}
}
if ref := n.getRef(protocol, dst); ref != nil {
- handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt)
+ handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt)
return
}
@@ -1298,24 +1321,55 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
}
-func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
+func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+	// We do not deliver to protocol-specific packet endpoints since, as on
+	// Linux, only ETH_P_ALL endpoints get outbound packets. Fetch only the
+	// packet sockets listening for all protocols.
+ packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketOutgoing
+ // Add the link layer header as outgoing packets are intercepted
+ // before the link layer header is created.
+ n.linkEP.AddHeader(local, remote, protocol, p)
+ ep.HandlePacket(n.id, local, protocol, p)
+ }
+}
+
+func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
- if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 {
- pkt.Header = buffer.NewPrependable(linkHeaderLen)
+ // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
+	// by having lower layers explicitly write each header instead of just
+ // pkt.Header.
+
+ // pkt may have set its NetworkHeader and TransportHeader. If we're
+ // forwarding, we'll have to copy them into pkt.Header.
+ pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader))
+ if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader)))
+ }
+ if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader)))
}
+	// WritePacket takes ownership of pkt, so calculate numBytes first.
+ numBytes := pkt.Header.UsedLength() + pkt.Data.Size()
+
if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return
}
n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
+ n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
}
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -1329,13 +1383,34 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// validly formed.
n.stack.demux.deliverRawPacket(r, protocol, pkt)
- transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
- if !ok {
+ // TransportHeader is nil only when pkt is an ICMP packet or was reassembled
+ // from fragments.
+ if pkt.TransportHeader == nil {
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
+			// ICMP packets may be longer, but until icmp.Parse is implemented, we
+			// parse them here using the minimum packet size.
+ transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
+ if !ok {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+ pkt.TransportHeader = transHeader
+ pkt.Data.TrimFront(len(pkt.TransportHeader))
+ } else {
+ // This is either a bad packet or was re-assembled from fragments.
+ transProto.Parse(pkt)
+ }
+ }
+
+ if len(pkt.TransportHeader) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- srcPort, dstPort, err := transProto.ParsePorts(transHeader)
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader)
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
@@ -1362,7 +1437,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// DeliverTransportControlPacket delivers control packets to the appropriate
// transport protocol endpoint.
-func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) {
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) {
state, ok := n.stack.transportProtocols[trans]
if !ok {
return
@@ -1477,6 +1552,27 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) {
n.mu.Unlock()
}
+// NUDConfigs gets the NUD configurations for n.
+func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) {
+ if n.neigh == nil {
+ return NUDConfigurations{}, tcpip.ErrNotSupported
+ }
+ return n.neigh.config(), nil
+}
+
+// setNUDConfigs sets the NUD configurations for n.
+//
+// Note, if c contains invalid NUD configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
+ }
+ c.resetInvalidFields()
+ n.neigh.setConfig(c)
+ return nil
+}
+
// handleNDPRA handles an NDP Router Advertisement message that arrived on n.
func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
n.mu.Lock()
@@ -1611,8 +1707,8 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
-// packet. It requires the endpoint to not be marked expired (i.e., its address
-// has been removed), or the NIC to be in spoofing mode.
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), unless the NIC is in spoofing mode or the endpoint is
+// temporary.
func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
r.nic.mu.RLock()
defer r.nic.mu.RUnlock()
@@ -1620,13 +1716,28 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
return r.isValidForOutgoingRLocked()
}
-// isValidForOutgoingRLocked returns true if the endpoint can be used to send
-// out a packet. It requires the endpoint to not be marked expired (i.e., its
-// address has been removed), or the NIC to be in spoofing mode.
-//
-// r's NIC must be read locked.
+// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires
+// r.nic.mu to be read locked.
func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
- return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing)
+ if !r.nic.mu.enabled {
+ return false
+ }
+
+ return r.isAssignedRLocked(r.nic.mu.spoofing)
+}
+
+// isAssignedRLocked returns true if r is considered to be assigned to the NIC.
+//
+// r.nic.mu must be read locked.
+func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool {
+ switch r.getKind() {
+ case permanentTentative:
+ return false
+ case permanentExpired:
+ return spoofingOrPromiscuous
+ default:
+ return true
+ }
}
// expireLocked decrements the reference count and marks the permanent endpoint
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index d672fc157..a70792b50 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -15,11 +15,278 @@
package stack
import (
+ "math"
"testing"
+ "time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
+var _ LinkEndpoint = (*testLinkEndpoint)(nil)
+
+// A LinkEndpoint that throws away outgoing packets.
+//
+// We use this instead of the channel endpoint as the channel package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testLinkEndpoint struct {
+ dispatcher NetworkDispatcher
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (e *testLinkEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (*testLinkEndpoint) MTU() uint32 {
+ return math.MaxUint16
+}
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities {
+ return CapabilityResolutionRequired
+}
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*testLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// Wait implements LinkEndpoint.Wait.
+func (*testLinkEndpoint) Wait() {}
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
+var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
+
+// An IPv6 NetworkEndpoint that throws away outgoing packets.
+//
+// We use this instead of ipv6.endpoint because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Endpoint struct {
+ nicID tcpip.NICID
+ id NetworkEndpointID
+ prefixLen int
+ linkEP LinkEndpoint
+ protocol *testIPv6Protocol
+}
+
+// DefaultTTL implements NetworkEndpoint.DefaultTTL.
+func (*testIPv6Endpoint) DefaultTTL() uint8 {
+ return 0
+}
+
+// MTU implements NetworkEndpoint.MTU.
+func (e *testIPv6Endpoint) MTU() uint32 {
+ return e.linkEP.MTU() - header.IPv6MinimumSize
+}
+
+// Capabilities implements NetworkEndpoint.Capabilities.
+func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
+func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+}
+
+// WritePacket implements NetworkEndpoint.WritePacket.
+func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements NetworkEndpoint.WritePackets.
+func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteHeaderIncludedPacket implements
+// NetworkEndpoint.WriteHeaderIncludedPacket.
+func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+// ID implements NetworkEndpoint.ID.
+func (e *testIPv6Endpoint) ID() *NetworkEndpointID {
+ return &e.id
+}
+
+// PrefixLen implements NetworkEndpoint.PrefixLen.
+func (e *testIPv6Endpoint) PrefixLen() int {
+ return e.prefixLen
+}
+
+// NICID implements NetworkEndpoint.NICID.
+func (e *testIPv6Endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+// HandlePacket implements NetworkEndpoint.HandlePacket.
+func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
+}
+
+// Close implements NetworkEndpoint.Close.
+func (*testIPv6Endpoint) Close() {}
+
+// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber.
+func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+var _ NetworkProtocol = (*testIPv6Protocol)(nil)
+
+// An IPv6 NetworkProtocol that supports the bare minimum to make a stack
+// believe it supports IPv6.
+//
+// We use this instead of ipv6.protocol because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Protocol struct{}
+
+// Number implements NetworkProtocol.Number.
+func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize.
+func (*testIPv6Protocol) MinimumPacketSize() int {
+ return header.IPv6MinimumSize
+}
+
+// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
+func (*testIPv6Protocol) DefaultPrefixLen() int {
+ return header.IPv6AddressSize * 8
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv6(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint implements NetworkProtocol.NewEndpoint.
+func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
+ return &testIPv6Endpoint{
+ nicID: nicID,
+ id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ linkEP: linkEP,
+ protocol: p,
+ }, nil
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Option implements NetworkProtocol.Option.
+func (*testIPv6Protocol) Option(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Close implements NetworkProtocol.Close.
+func (*testIPv6Protocol) Close() {}
+
+// Wait implements NetworkProtocol.Wait.
+func (*testIPv6Protocol) Wait() {}
+
+// Parse implements NetworkProtocol.Parse.
+func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ return 0, false, false
+}
+
+var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
+
+// LinkAddressProtocol implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// LinkAddressRequest implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
+ return nil
+}
+
+// ResolveStaticAddress implements LinkAddressResolver.
+func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if header.IsV6MulticastAddress(addr) {
+ return header.EthernetAddressFromMulticastIPv6Address(addr), true
+ }
+ return "", false
+}
+
+// Test the race condition where a NIC is removed and an RS timer fires at the
+// same time.
+func TestRemoveNICWhileHandlingRSTimer(t *testing.T) {
+ const (
+ nicID = 1
+
+ maxRtrSolicitations = 5
+ )
+
+ e := testLinkEndpoint{}
+ s := New(Options{
+ NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}},
+ NDPConfigs: NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: minimumRtrSolicitationInterval,
+ },
+ })
+
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ s.mu.Lock()
+ // Wait for the router solicitation timer to fire and block trying to obtain
+ // the stack lock when doing link address resolution.
+ time.Sleep(minimumRtrSolicitationInterval * 2)
+ if err := s.removeNICLocked(nicID); err != nil {
+ t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err)
+ }
+ s.mu.Unlock()
+}
+
func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
// When the NIC is disabled, the only field that matters is the stats field.
// This test is limited to stats counter checks.
@@ -44,7 +311,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
t.FailNow()
}
- nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+ nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
new file mode 100644
index 000000000..f848d50ad
--- /dev/null
+++ b/pkg/tcpip/stack/nud.go
@@ -0,0 +1,466 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "math"
+ "sync"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // defaultBaseReachableTime is the default base duration for computing the
+ // random reachable time.
+ //
+ // Reachable time is the duration for which a neighbor is considered
+ // reachable after a positive reachability confirmation is received. It is a
+ // function of a uniformly distributed random value between the minimum and
+ // maximum random factors, multiplied by the base reachable time. Using a
+ // random component eliminates the possibility that Neighbor Unreachability
+ // Detection messages will synchronize with each other.
+ //
+ // Default taken from REACHABLE_TIME of RFC 4861 section 10.
+ defaultBaseReachableTime = 30 * time.Second
+
+ // minimumBaseReachableTime is the minimum base duration for computing the
+ // random reachable time.
+ //
+ // Minimum = 1ms
+ minimumBaseReachableTime = time.Millisecond
+
+ // defaultMinRandomFactor is the default minimum value of the random factor
+ // used for computing reachable time.
+ //
+ // Default taken from MIN_RANDOM_FACTOR of RFC 4861 section 10.
+ defaultMinRandomFactor = 0.5
+
+ // defaultMaxRandomFactor is the default maximum value of the random factor
+ // used for computing reachable time.
+ //
+ // The default value depends on the value of MinRandomFactor.
+ // If MinRandomFactor is less than MAX_RANDOM_FACTOR of RFC 4861 section 10,
+ // the value from the RFC will be used; otherwise, the default is
+ // MinRandomFactor multiplied by three.
+ defaultMaxRandomFactor = 1.5
+
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending reachability probes.
+ //
+ // Default taken from RETRANS_TIMER of RFC 4861 section 10.
+ defaultRetransmitTimer = time.Second
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending reachability probes.
+ //
+ // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here
+ // to make sure the messages are not sent all at once. We also come to this
+ // value because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1. Note, the
+ // unit of the RetransmitTimer field in the Router Advertisement is
+ // milliseconds.
+ minimumRetransmitTimer = time.Millisecond
+
+ // defaultDelayFirstProbeTime is the default duration to wait for a
+ // non-Neighbor-Discovery related protocol to reconfirm reachability after
+ // entering the DELAY state. After this time, a reachability probe will be
+ // sent and the entry will transition to the PROBE state.
+ //
+ // Default taken from DELAY_FIRST_PROBE_TIME of RFC 4861 section 10.
+ defaultDelayFirstProbeTime = 5 * time.Second
+
+	// defaultMaxMulticastProbes is the default number of reachability probes
+ // to send before concluding negative reachability and deleting the neighbor
+ // entry from the INCOMPLETE state.
+ //
+ // Default taken from MAX_MULTICAST_SOLICIT of RFC 4861 section 10.
+ defaultMaxMulticastProbes = 3
+
+ // defaultMaxUnicastProbes is the default number of reachability probes to
+ // send before concluding retransmission from within the PROBE state should
+ // cease and the entry SHOULD be deleted.
+ //
+	// Default taken from MAX_UNICAST_SOLICIT of RFC 4861 section 10.
+ defaultMaxUnicastProbes = 3
+
+	// defaultMaxAnycastDelayTime is the default maximum duration the stack
+	// SHOULD delay a response by (a random time between 0 and this value) when
+	// the target address is an anycast address.
+ //
+ // Default taken from MAX_ANYCAST_DELAY_TIME of RFC 4861 section 10.
+ defaultMaxAnycastDelayTime = time.Second
+
+	// defaultMaxReachabilityConfirmations is the default number of unsolicited
+	// reachability confirmation messages a node MAY send to the all-nodes
+	// multicast address when it determines its link-layer address has changed.
+ //
+ // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10.
+	defaultMaxReachabilityConfirmations = 3
+
+ // defaultUnreachableTime is the default duration for how long an entry will
+ // remain in the FAILED state before being removed from the neighbor cache.
+ //
+ // Note, there is no equivalent protocol constant defined in RFC 4861. It
+ // leaves the specifics of any garbage collection mechanism up to the
+ // implementation.
+ defaultUnreachableTime = 5 * time.Second
+)
+
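+// As a rough, illustrative sketch (not part of the API below): with the
+// defaults above, the computed reachable time is drawn uniformly from
+// [15s, 45s), where rng stands for any implementation of the Rand interface
+// defined later in this file:
+//
+//	base := defaultBaseReachableTime                   // 30s
+//	f := 0.5 + rng.Float32()*(1.5-0.5)                 // uniform in [0.5, 1.5)
+//	reachableTime := time.Duration(float32(base) * f)  // 15s <= t < 45s
+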
+// NUDDispatcher is the interface integrators of netstack must implement to
+// receive and handle NUD related events.
+type NUDDispatcher interface {
+ // OnNeighborAdded will be called when a new entry is added to a NIC's (with
+ // ID nicID) neighbor table.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+
+ // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID)
+ // neighbor table changes state and/or link address.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+
+ // OnNeighborRemoved will be called when an entry is removed from a NIC's
+ // (with ID nicID) neighbor table.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+}
+
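+// A minimal dispatcher sketch (hypothetical, for illustration only; the tests
+// accompanying this change use a similar event-recording dispatcher), assuming
+// the standard log package:
+//
+//	type logDispatcher struct{}
+//
+//	func (logDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+//		log.Printf("NIC %d: neighbor %s (%s) added in state %s", nicID, addr, linkAddr, state)
+//	}
+//
+//	func (logDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+//		log.Printf("NIC %d: neighbor %s (%s) changed to state %s", nicID, addr, linkAddr, state)
+//	}
+//
+//	func (logDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+//		log.Printf("NIC %d: neighbor %s (%s) removed while in state %s", nicID, addr, linkAddr, state)
+//	}
+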
+// ReachabilityConfirmationFlags describes the flags used within a reachability
+// confirmation (e.g. ARP reply or Neighbor Advertisement for ARP or NDP,
+// respectively).
+type ReachabilityConfirmationFlags struct {
+ // Solicited indicates that the advertisement was sent in response to a
+ // reachability probe.
+ Solicited bool
+
+ // Override indicates that the reachability confirmation should override an
+ // existing neighbor cache entry and update the cached link-layer address.
+ // When Override is not set the confirmation will not update a cached
+ // link-layer address, but will update an existing neighbor cache entry for
+ // which no link-layer address is known.
+ Override bool
+
+ // IsRouter indicates that the sender is a router.
+ IsRouter bool
+}
+
+// NUDHandler communicates external events to the Neighbor Unreachability
+// Detection state machine, which is implemented per-interface. This is used by
+// network endpoints to inform the Neighbor Cache of probes and confirmations.
+type NUDHandler interface {
+ // HandleProbe processes an incoming neighbor probe (e.g. ARP request or
+ // Neighbor Solicitation for ARP or NDP, respectively). Validation of the
+ // probe needs to be performed before calling this function since the
+ // Neighbor Cache doesn't have access to view the NIC's assigned addresses.
+ HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress)
+
+ // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP
+ // reply or Neighbor Advertisement for ARP or NDP, respectively).
+ HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags)
+
+ // HandleUpperLevelConfirmation processes an incoming upper-level protocol
+ // (e.g. TCP acknowledgements) reachability confirmation.
+ HandleUpperLevelConfirmation(addr tcpip.Address)
+}
+
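+// A no-op implementation is the smallest thing that satisfies this interface;
+// the sketch below is hypothetical (the neighbor cache introduced by this
+// change provides the real behavior):
+//
+//	type noopNUDHandler struct{}
+//
+//	func (noopNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) {
+//	}
+//
+//	func (noopNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
+//	}
+//
+//	func (noopNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {}
+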
+// NUDConfigurations is the NUD configurations for the netstack. This is used
+// by the neighbor cache to operate the NUD state machine on each device in the
+// local network.
+type NUDConfigurations struct {
+ // BaseReachableTime is the base duration for computing the random reachable
+ // time.
+ //
+ // Reachable time is the duration for which a neighbor is considered
+ // reachable after a positive reachability confirmation is received. It is a
+	// function of a uniformly distributed random value between MinRandomFactor
+	// and MaxRandomFactor, multiplied by BaseReachableTime. Using a random component
+ // eliminates the possibility that Neighbor Unreachability Detection messages
+ // will synchronize with each other.
+ //
+ // After this time, a neighbor entry will transition from REACHABLE to STALE
+ // state.
+ //
+ // Must be greater than 0.
+ BaseReachableTime time.Duration
+
+ // LearnBaseReachableTime enables learning BaseReachableTime during runtime
+ // from the neighbor discovery protocol, if supported.
+ //
+ // TODO(gvisor.dev/issue/2240): Implement this NUD configuration option.
+ LearnBaseReachableTime bool
+
+ // MinRandomFactor is the minimum value of the random factor used for
+ // computing reachable time.
+ //
+	// See BaseReachableTime for more information on computing the reachable time.
+ //
+ // Must be greater than 0.
+ MinRandomFactor float32
+
+	// MaxRandomFactor is the maximum value of the random factor used for
+	// computing reachable time.
+	//
+	// See BaseReachableTime for more information on computing the reachable time.
+	//
+	// Must be greater than or equal to MinRandomFactor.
+ MaxRandomFactor float32
+
+ // RetransmitTimer is the duration between retransmission of reachability
+ // probes in the PROBE state.
+ RetransmitTimer time.Duration
+
+ // LearnRetransmitTimer enables learning RetransmitTimer during runtime from
+ // the neighbor discovery protocol, if supported.
+ //
+ // TODO(gvisor.dev/issue/2241): Implement this NUD configuration option.
+ LearnRetransmitTimer bool
+
+ // DelayFirstProbeTime is the duration to wait for a non-Neighbor-Discovery
+ // related protocol to reconfirm reachability after entering the DELAY state.
+ // After this time, a reachability probe will be sent and the entry will
+ // transition to the PROBE state.
+ //
+ // Must be greater than 0.
+ DelayFirstProbeTime time.Duration
+
+ // MaxMulticastProbes is the number of reachability probes to send before
+ // concluding negative reachability and deleting the neighbor entry from the
+ // INCOMPLETE state.
+ //
+ // Must be greater than 0.
+ MaxMulticastProbes uint32
+
+ // MaxUnicastProbes is the number of reachability probes to send before
+ // concluding retransmission from within the PROBE state should cease and
+	// the entry SHOULD be deleted.
+ //
+ // Must be greater than 0.
+ MaxUnicastProbes uint32
+
+	// MaxAnycastDelayTime is the maximum duration the stack SHOULD delay a
+	// response by (a random time between 0 and this value) when the target
+	// address is an anycast address.
+ //
+ // TODO(gvisor.dev/issue/2242): Use this option when sending solicited
+ // neighbor confirmations to anycast addresses and proxying neighbor
+ // confirmations.
+ MaxAnycastDelayTime time.Duration
+
+	// MaxReachabilityConfirmations is the number of unsolicited reachability
+	// confirmation messages a node MAY send to the all-nodes multicast address
+	// when it determines its link-layer address has changed.
+ //
+ // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD
+ // configuration option is necessary.
+ MaxReachabilityConfirmations uint32
+
+ // UnreachableTime describes how long an entry will remain in the FAILED
+ // state before being removed from the neighbor cache.
+ UnreachableTime time.Duration
+}
+
+// DefaultNUDConfigurations returns a NUDConfigurations populated with default
+// values defined by RFC 4861 section 10.
+func DefaultNUDConfigurations() NUDConfigurations {
+ return NUDConfigurations{
+ BaseReachableTime: defaultBaseReachableTime,
+ LearnBaseReachableTime: true,
+ MinRandomFactor: defaultMinRandomFactor,
+ MaxRandomFactor: defaultMaxRandomFactor,
+ RetransmitTimer: defaultRetransmitTimer,
+ LearnRetransmitTimer: true,
+ DelayFirstProbeTime: defaultDelayFirstProbeTime,
+ MaxMulticastProbes: defaultMaxMulticastProbes,
+ MaxUnicastProbes: defaultMaxUnicastProbes,
+ MaxAnycastDelayTime: defaultMaxAnycastDelayTime,
+		MaxReachabilityConfirmations: defaultMaxReachabilityConfirmations,
+ UnreachableTime: defaultUnreachableTime,
+ }
+}
+
+// resetInvalidFields modifies an invalid NUDConfigurations with valid values.
+// If invalid values are present in c, the corresponding default values will be
+// used instead. This is needed to check, and conditionally fix, user-specified
+// NUDConfigurations.
+func (c *NUDConfigurations) resetInvalidFields() {
+ if c.BaseReachableTime < minimumBaseReachableTime {
+ c.BaseReachableTime = defaultBaseReachableTime
+ }
+ if c.MinRandomFactor <= 0 {
+ c.MinRandomFactor = defaultMinRandomFactor
+ }
+ if c.MaxRandomFactor < c.MinRandomFactor {
+ c.MaxRandomFactor = calcMaxRandomFactor(c.MinRandomFactor)
+ }
+ if c.RetransmitTimer < minimumRetransmitTimer {
+ c.RetransmitTimer = defaultRetransmitTimer
+ }
+ if c.DelayFirstProbeTime == 0 {
+ c.DelayFirstProbeTime = defaultDelayFirstProbeTime
+ }
+ if c.MaxMulticastProbes == 0 {
+ c.MaxMulticastProbes = defaultMaxMulticastProbes
+ }
+ if c.MaxUnicastProbes == 0 {
+ c.MaxUnicastProbes = defaultMaxUnicastProbes
+ }
+ if c.UnreachableTime == 0 {
+ c.UnreachableTime = defaultUnreachableTime
+ }
+}
+
+// calcMaxRandomFactor calculates the maximum value of the random factor used
+// for computing reachable time. This function is necessary for when the
+// default specified in RFC 4861 section 10 is less than the current
+// MinRandomFactor.
+//
+// Assumes minRandomFactor is positive since validation of the minimum value
+// should come before the validation of the maximum.
+func calcMaxRandomFactor(minRandomFactor float32) float32 {
+ if minRandomFactor > defaultMaxRandomFactor {
+ return minRandomFactor * 3
+ }
+ return defaultMaxRandomFactor
+}
+
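+// Worked examples (sketch):
+//
+//	calcMaxRandomFactor(0.5) // == defaultMaxRandomFactor (1.5)
+//	calcMaxRandomFactor(3.0) // == 9, since 3.0 > defaultMaxRandomFactor
+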
+// A Rand is a source of random numbers.
+type Rand interface {
+ // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0).
+ Float32() float32
+}
+
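+// *rand.Rand from the standard math/rand package satisfies Rand; for example,
+// the NIC constructor in this change seeds one from the stack's clock:
+//
+//	rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds()))
+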
+// NUDState stores states needed for calculating reachable time.
+type NUDState struct {
+ rng Rand
+
+ // mu protects the fields below.
+ //
+ // It is necessary for NUDState to handle its own locking since neighbor
+ // entries may access the NUD state from within the goroutine spawned by
+ // time.AfterFunc(). This goroutine may run concurrently with the main
+ // process for controlling the neighbor cache and would otherwise introduce
+ // race conditions if NUDState was not locked properly.
+ mu sync.RWMutex
+
+ config NUDConfigurations
+
+ // reachableTime is the duration to wait for a REACHABLE entry to
+ // transition into STALE after inactivity. This value is calculated with
+ // the algorithm defined in RFC 4861 section 6.3.2.
+ reachableTime time.Duration
+
+ expiration time.Time
+ prevBaseReachableTime time.Duration
+ prevMinRandomFactor float32
+ prevMaxRandomFactor float32
+}
+
+// NewNUDState returns new NUDState using c as configuration and the specified
+// random number generator for use in recomputing ReachableTime.
+func NewNUDState(c NUDConfigurations, rng Rand) *NUDState {
+ return &NUDState{
+ rng: rng,
+ config: c,
+ }
+}
+
+// Config returns the NUD configuration.
+func (s *NUDState) Config() NUDConfigurations {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.config
+}
+
+// SetConfig replaces the existing NUD configurations with c.
+func (s *NUDState) SetConfig(c NUDConfigurations) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.config = c
+}
+
+// ReachableTime returns the duration to wait for a REACHABLE entry to
+// transition into STALE after inactivity. This value is recalculated for new
+// values of BaseReachableTime, MinRandomFactor, and MaxRandomFactor using the
+// algorithm defined in RFC 4861 section 6.3.2.
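+//
+// For example, two consecutive calls return the same cached value unless the
+// configuration changed or the two hour expiration window elapsed in between.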
+func (s *NUDState) ReachableTime() time.Duration {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if time.Now().After(s.expiration) ||
+ s.config.BaseReachableTime != s.prevBaseReachableTime ||
+ s.config.MinRandomFactor != s.prevMinRandomFactor ||
+ s.config.MaxRandomFactor != s.prevMaxRandomFactor {
+ return s.recomputeReachableTimeLocked()
+ }
+ return s.reachableTime
+}
+
+// recomputeReachableTimeLocked forces a recalculation of ReachableTime using
+// the algorithm defined in RFC 4861 section 6.3.2.
+//
+// This SHOULD automatically be invoked during certain situations, as per
+// RFC 4861 section 6.3.4:
+//
+// If the received Reachable Time value is non-zero, the host SHOULD set its
+// BaseReachableTime variable to the received value. If the new value
+// differs from the previous value, the host SHOULD re-compute a new random
+// ReachableTime value. ReachableTime is computed as a uniformly
+// distributed random value between MIN_RANDOM_FACTOR and MAX_RANDOM_FACTOR
+// times the BaseReachableTime. Using a random component eliminates the
+// possibility that Neighbor Unreachability Detection messages will
+// synchronize with each other.
+//
+// In most cases, the advertised Reachable Time value will be the same in
+// consecutive Router Advertisements, and a host's BaseReachableTime rarely
+// changes. In such cases, an implementation SHOULD ensure that a new
+// random value gets re-computed at least once every few hours.
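+//
+// For example, with BaseReachableTime = 30s, MinRandomFactor = 0.5, and
+// MaxRandomFactor = 1.5, ReachableTime is uniformly distributed in the range
+// [15s, 45s).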
+//
+// s.mu MUST be locked for writing.
+func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
+ s.prevBaseReachableTime = s.config.BaseReachableTime
+ s.prevMinRandomFactor = s.config.MinRandomFactor
+ s.prevMaxRandomFactor = s.config.MaxRandomFactor
+
+ randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor)
+
+ // Check for overflow, given that a validated configuration guarantees
+ // randomFactor is positive.
+ if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) {
+ s.reachableTime = time.Duration(math.MaxInt64)
+ } else if randomFactor == 1 {
+ // Avoid loss of precision when a large base reachable time is used.
+ s.reachableTime = s.config.BaseReachableTime
+ } else {
+ reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor)
+ s.reachableTime = time.Duration(reachableTime)
+ }
+
+ s.expiration = time.Now().Add(2 * time.Hour)
+ return s.reachableTime
+}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
new file mode 100644
index 000000000..2494ee610
--- /dev/null
+++ b/pkg/tcpip/stack/nud_test.go
@@ -0,0 +1,795 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ defaultBaseReachableTime = 30 * time.Second
+ minimumBaseReachableTime = time.Millisecond
+ defaultMinRandomFactor = 0.5
+ defaultMaxRandomFactor = 1.5
+ defaultRetransmitTimer = time.Second
+ minimumRetransmitTimer = time.Millisecond
+ defaultDelayFirstProbeTime = 5 * time.Second
+ defaultMaxMulticastProbes = 3
+ defaultMaxUnicastProbes = 3
+ defaultMaxAnycastDelayTime = time.Second
+ defaultMaxReachabilityConfirmations = 3
+ defaultUnreachableTime = 5 * time.Second
+
+ defaultFakeRandomNum = 0.5
+)
+
+// fakeRand is a deterministic random number generator.
+type fakeRand struct {
+ num float32
+}
+
+var _ stack.Rand = (*fakeRand)(nil)
+
+func (f *fakeRand) Float32() float32 {
+ return f.num
+}
+
+// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if
+// we attempt to update NUD configurations using an invalid NICID.
+func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) {
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The networking
+ // stack will only allocate neighbor caches if a protocol providing link
+ // address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+
+ // No NIC with ID 1 yet.
+ config := stack.NUDConfigurations{}
+ if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID)
+ }
+}
+
+// TestNUDConfigurationFailsForNotSupported tests to make sure we get a
+// NotSupported error if we attempt to retrieve NUD configurations when the
+// stack doesn't support NUD.
+//
+// The stack reports that it does not support NUD if a neighbor cache for a
+// given NIC is not allocated. The networking stack only allocates neighbor
+// caches if a protocol providing link address resolution is specified (e.g.
+// ARP or IPv6).
+func TestNUDConfigurationFailsForNotSupported(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported {
+ t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported)
+ }
+}
+
+// TestSetNUDConfigurationFailsForNotSupported tests to make sure we get a
+// NotSupported error if we attempt to set NUD configurations when the stack
+// doesn't support NUD.
+//
+// The stack reports that it does not support NUD if a neighbor cache for a
+// given NIC is not allocated. The networking stack only allocates neighbor
+// caches if a protocol providing link address resolution is specified (e.g.
+// ARP or IPv6).
+func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ config := stack.NUDConfigurations{}
+ if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported {
+ t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported)
+ }
+}
+
+// TestDefaultNUDConfigurations verifies that DefaultNUDConfigurations()
+// returns a valid NUDConfigurations: a stack created with the defaults
+// reports the same values back, i.e. resetInvalidFields() changed nothing.
+func TestDefaultNUDConfigurations(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The networking
+ // stack will only allocate neighbor caches if a protocol providing link
+ // address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ c, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got, want := c, stack.DefaultNUDConfigurations(); got != want {
+ t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want)
+ }
+}
+
+func TestNUDConfigurationsBaseReachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ baseReachableTime: 0,
+ want: defaultBaseReachableTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ baseReachableTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ {
+ name: "MoreThanDefaultBaseReachableTime",
+ baseReachableTime: 2 * defaultBaseReachableTime,
+ want: 2 * defaultBaseReachableTime,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.BaseReachableTime = test.baseReachableTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.BaseReachableTime; got != test.want {
+ t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMinRandomFactor(t *testing.T) {
+ tests := []struct {
+ name string
+ minRandomFactor float32
+ want float32
+ }{
+ // Invalid cases
+ {
+ name: "LessThanZero",
+ minRandomFactor: -1,
+ want: defaultMinRandomFactor,
+ },
+ {
+ name: "EqualToZero",
+ minRandomFactor: 0,
+ want: defaultMinRandomFactor,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ minRandomFactor: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MinRandomFactor = test.minRandomFactor
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MinRandomFactor; got != test.want {
+ t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxRandomFactor(t *testing.T) {
+ tests := []struct {
+ name string
+ minRandomFactor float32
+ maxRandomFactor float32
+ want float32
+ }{
+ // Invalid cases
+ {
+ name: "LessThanZero",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: -1,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "EqualToZero",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: 0,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "LessThanMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor * 0.99,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault",
+ minRandomFactor: defaultMaxRandomFactor * 2,
+ maxRandomFactor: defaultMaxRandomFactor,
+ want: defaultMaxRandomFactor * 6,
+ },
+ // Valid cases
+ {
+ name: "EqualToMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor,
+ want: defaultMinRandomFactor,
+ },
+ {
+ name: "MoreThanMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor * 1.1,
+ want: defaultMinRandomFactor * 1.1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MinRandomFactor = test.minRandomFactor
+ c.MaxRandomFactor = test.maxRandomFactor
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxRandomFactor; got != test.want {
+ t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsRetransmitTimer(t *testing.T) {
+ tests := []struct {
+ name string
+ retransmitTimer time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ retransmitTimer: 0,
+ want: defaultRetransmitTimer,
+ },
+ {
+ name: "LessThanMinimumRetransmitTimer",
+ retransmitTimer: minimumRetransmitTimer - time.Nanosecond,
+ want: defaultRetransmitTimer,
+ },
+ // Valid cases
+ {
+ name: "EqualToMinimumRetransmitTimer",
+ retransmitTimer: minimumRetransmitTimer,
+ want: minimumRetransmitTimer,
+ },
+ {
+ name: "LargerThanMinimumRetransmitTimer",
+ retransmitTimer: 2 * minimumRetransmitTimer,
+ want: 2 * minimumRetransmitTimer,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.RetransmitTimer = test.retransmitTimer
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.RetransmitTimer; got != test.want {
+ t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) {
+ tests := []struct {
+ name string
+ delayFirstProbeTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ delayFirstProbeTime: 0,
+ want: defaultDelayFirstProbeTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ delayFirstProbeTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.DelayFirstProbeTime = test.delayFirstProbeTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.DelayFirstProbeTime; got != test.want {
+ t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) {
+ tests := []struct {
+ name string
+ maxMulticastProbes uint32
+ want uint32
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ maxMulticastProbes: 0,
+ want: defaultMaxMulticastProbes,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ maxMulticastProbes: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MaxMulticastProbes = test.maxMulticastProbes
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxMulticastProbes; got != test.want {
+ t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
+ tests := []struct {
+ name string
+ maxUnicastProbes uint32
+ want uint32
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ maxUnicastProbes: 0,
+ want: defaultMaxUnicastProbes,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ maxUnicastProbes: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MaxUnicastProbes = test.maxUnicastProbes
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxUnicastProbes; got != test.want {
+ t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsUnreachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ unreachableTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ unreachableTime: 0,
+ want: defaultUnreachableTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ unreachableTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.UnreachableTime = test.unreachableTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.UnreachableTime; got != test.want {
+ t.Errorf("got UnreachableTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+// TestNUDStateReachableTime verifies the correctness of the ReachableTime
+// computation.
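+//
+// For example, with baseReachableTime = 1s, minRandomFactor = 1,
+// maxRandomFactor = 2, and the fake RNG fixed at 0.5, the expected result is
+// (1 + (2-1)*0.5) * 1s = 1.5s.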
+func TestNUDStateReachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ minRandomFactor float32
+ maxRandomFactor float32
+ want time.Duration
+ }{
+ {
+ name: "AllZeros",
+ baseReachableTime: 0,
+ minRandomFactor: 0,
+ maxRandomFactor: 0,
+ want: 0,
+ },
+ {
+ name: "ZeroMaxRandomFactor",
+ baseReachableTime: time.Second,
+ minRandomFactor: 0,
+ maxRandomFactor: 0,
+ want: 0,
+ },
+ {
+ name: "ZeroMinRandomFactor",
+ baseReachableTime: time.Second,
+ minRandomFactor: 0,
+ maxRandomFactor: 1,
+ want: time.Duration(defaultFakeRandomNum * float32(time.Second)),
+ },
+ {
+ name: "FractionalRandomFactor",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 0.001,
+ maxRandomFactor: 0.002,
+ want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)),
+ },
+ {
+ name: "MinAndMaxRandomFactorsEqual",
+ baseReachableTime: time.Second,
+ minRandomFactor: 1,
+ maxRandomFactor: 1,
+ want: time.Second,
+ },
+ {
+ name: "MinAndMaxRandomFactorsDifferent",
+ baseReachableTime: time.Second,
+ minRandomFactor: 1,
+ maxRandomFactor: 2,
+ want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)),
+ },
+ {
+ name: "MaxInt64",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 1,
+ maxRandomFactor: 1,
+ want: time.Duration(math.MaxInt64),
+ },
+ {
+ name: "Overflow",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 1.5,
+ maxRandomFactor: 1.5,
+ want: time.Duration(math.MaxInt64),
+ },
+ {
+ name: "DoubleOverflow",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 2.5,
+ maxRandomFactor: 2.5,
+ want: time.Duration(math.MaxInt64),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := stack.NUDConfigurations{
+ BaseReachableTime: test.baseReachableTime,
+ MinRandomFactor: test.minRandomFactor,
+ MaxRandomFactor: test.maxRandomFactor,
+ }
+ // A fake random number generator is used to ensure deterministic
+ // results.
+ rng := fakeRand{
+ num: defaultFakeRandomNum,
+ }
+ s := stack.NewNUDState(c, &rng)
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+ })
+ }
+}
+
+// TestNUDStateRecomputeReachableTime exercises the ReachableTime function
+// twice to verify recomputation of reachable time when the min random factor,
+// max random factor, or base reachable time changes.
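+//
+// Expected values follow the RFC 4861 section 6.3.2 formula with the fake
+// RNG fixed at defaultFakeRandomNum:
+//
+//	want = (min + (max-min)*defaultFakeRandomNum) * base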
+func TestNUDStateRecomputeReachableTime(t *testing.T) {
+ const defaultBase = time.Second
+ const defaultMin = 2.0 * defaultMaxRandomFactor
+ const defaultMax = 3.0 * defaultMaxRandomFactor
+
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ minRandomFactor float32
+ maxRandomFactor float32
+ want time.Duration
+ }{
+ {
+ name: "BaseReachableTime",
+ baseReachableTime: 2 * defaultBase,
+ minRandomFactor: defaultMin,
+ maxRandomFactor: defaultMax,
+ want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)),
+ },
+ {
+ name: "MinRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: defaultMax,
+ maxRandomFactor: defaultMax,
+ want: time.Duration(defaultMax * float32(defaultBase)),
+ },
+ {
+ name: "MaxRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: defaultMin,
+ maxRandomFactor: defaultMin,
+ want: time.Duration(defaultMin * float32(defaultBase)),
+ },
+ {
+ name: "BothRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: 2 * defaultMin,
+ maxRandomFactor: 2 * defaultMax,
+ want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)),
+ },
+ {
+ name: "BaseReachableTimeAndBothRandomFactors",
+ baseReachableTime: 2 * defaultBase,
+ minRandomFactor: 2 * defaultMin,
+ maxRandomFactor: 2 * defaultMax,
+ want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := stack.DefaultNUDConfigurations()
+ c.BaseReachableTime = defaultBase
+ c.MinRandomFactor = defaultMin
+ c.MaxRandomFactor = defaultMax
+
+ // A fake random number generator is used to ensure deterministic
+ // results.
+ rng := fakeRand{
+ num: defaultFakeRandomNum,
+ }
+ s := stack.NewNUDState(c, &rng)
+ old := s.ReachableTime()
+
+ if got, want := s.ReachableTime(), old; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+
+ // Check for recomputation when changing the min random factor, the max
+ // random factor, the base reachable time, or any permutation of those
+ // three options.
+ c.BaseReachableTime = test.baseReachableTime
+ c.MinRandomFactor = test.minRandomFactor
+ c.MaxRandomFactor = test.maxRandomFactor
+ s.SetConfig(c)
+
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+
+ // Verify that ReachableTime isn't recomputed when none of the
+ // configuration options change. The random factor is changed so that if
+ // a recomputation were to occur, ReachableTime would change.
+ rng.num = defaultFakeRandomNum / 2.0
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 926df4d7b..5d6865e35 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -14,6 +14,7 @@
package stack
import (
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -24,6 +25,8 @@ import (
// multiple endpoints. Clone() should be called in such cases so that
// modifications to the Data field do not affect other copies.
type PacketBuffer struct {
+ _ sync.NoCopy
+
// PacketBufferEntry is used to build an intrusive list of
// PacketBuffers.
PacketBufferEntry
@@ -76,13 +79,31 @@ type PacketBuffer struct {
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
NatDone bool
+
+ // PktType indicates the SockAddrLink.PacketType of the packet as defined in
+ // https://www.man7.org/linux/man-pages/man7/packet.7.html.
+ PktType tcpip.PacketType
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
// VectorisedView but does not deep copy the underlying bytes.
//
// Clone also does not deep copy any of its other fields.
-func (pk PacketBuffer) Clone() PacketBuffer {
- pk.Data = pk.Data.Clone(nil)
- return pk
+//
+// FIXME(b/153685824): Data gets copied but not other header references.
+func (pk *PacketBuffer) Clone() *PacketBuffer {
+ return &PacketBuffer{
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ Header: pk.Header,
+ LinkHeader: pk.LinkHeader,
+ NetworkHeader: pk.NetworkHeader,
+ TransportHeader: pk.TransportHeader,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ EgressRoute: pk.EgressRoute,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ PktType: pk.PktType,
+ }
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index b331427c6..8604c4259 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -51,8 +52,11 @@ type TransportEndpointID struct {
type ControlType int
// The following are the allowed values for ControlType values.
+// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages.
const (
- ControlPacketTooBig ControlType = iota
+ ControlNetworkUnreachable ControlType = iota
+ ControlNoRoute
+ ControlPacketTooBig
ControlPortUnreachable
ControlUnknown
)
@@ -67,12 +71,12 @@ type TransportEndpoint interface {
// this transport endpoint. It sets pkt.TransportHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer)
+ HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer)
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
// HandleControlPacket takes ownership of pkt.
- HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer)
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer)
// Abort initiates an expedited endpoint teardown. It puts the endpoint
// in a closed state and frees all resources associated with it. This
@@ -100,7 +104,7 @@ type RawTransportEndpoint interface {
// layer up.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt PacketBuffer)
+ HandlePacket(r *Route, pkt *PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -118,7 +122,7 @@ type PacketEndpoint interface {
// should construct its own ethernet header for applications.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer)
+ HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// TransportProtocol is the interface that needs to be implemented by transport
@@ -150,7 +154,7 @@ type TransportProtocol interface {
// stats purposes only).
//
// HandleUnknownDestinationPacket takes ownership of pkt.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -168,6 +172,11 @@ type TransportProtocol interface {
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
+
+ // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. If
+ // pkt.Data is too small, i.e. pkt.Data.Size() < MinimumPacketSize(), it
+ // does neither and returns false.
+ Parse(pkt *PacketBuffer) (ok bool)
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -180,7 +189,7 @@ type TransportDispatcher interface {
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
// DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer)
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -189,7 +198,7 @@ type TransportDispatcher interface {
// DeliverTransportControlPacket.
//
// DeliverTransportControlPacket takes ownership of pkt.
- DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer)
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
@@ -240,17 +249,18 @@ type NetworkEndpoint interface {
MaxHeaderLength() uint16
// WritePacket writes a packet to the given destination address and
- // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have
- // already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error
+ // protocol. It takes ownership of pkt. pkt.TransportHeader must have already
+ // been set.
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
- // protocol. pkts must not be zero length.
+ // protocol. pkts must not be zero length. It takes ownership of pkts and
+ // underlying packets.
WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
- // header to the given destination address.
- WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error
+ // header to the given destination address. It takes ownership of pkt.
+ WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
@@ -265,7 +275,7 @@ type NetworkEndpoint interface {
// this network endpoint. It sets pkt.NetworkHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt PacketBuffer)
+ HandlePacket(r *Route, pkt *PacketBuffer)
// Close is called when the endpoint is removed from a stack.
Close()
@@ -312,11 +322,18 @@ type NetworkProtocol interface {
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
+
+ // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It
+ // returns:
+ // - The encapsulated protocol, if present.
+ // - Whether there is an encapsulated transport protocol payload (e.g. ARP
+ // does not encapsulate anything).
+ // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader.
+ Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool)
}
// NetworkDispatcher contains the methods used by the network stack to deliver
-// packets to the appropriate network endpoint after it has been handled by
-// the data link layer.
+// inbound and outbound packets to the appropriate network endpoints and to
+// any registered packet endpoints.
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
// and hands the packet over for further processing.
@@ -326,7 +343,17 @@ type NetworkDispatcher interface {
// packets sent via loopback), and won't have the field set.
//
// DeliverNetworkPacket takes ownership of pkt.
- DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
+ DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+
+ // DeliverOutboundPacket is called by link layer when a packet is being
+ // sent out.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverOutboundPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverOutboundPacket takes ownership of pkt.
+ DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -382,17 +409,17 @@ type LinkEndpoint interface {
LinkAddress() tcpip.LinkAddress
// WritePacket writes a packet with the given protocol through the
- // given route. It sets pkt.LinkHeader if a link layer header exists.
- // pkt.NetworkHeader and pkt.TransportHeader must have already been
- // set.
+ // given route. It takes ownership of pkt. pkt.NetworkHeader and
+ // pkt.TransportHeader must have already been set.
//
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error
// WritePackets writes packets with the given protocol through the
- // given route. pkts must not be zero length.
+ // given route. pkts must not be zero length. It takes ownership of pkts and
+ // underlying packets.
//
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, it may
@@ -400,7 +427,7 @@ type LinkEndpoint interface {
WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
// WriteRawPacket writes a packet directly to the link. The packet
- // should already have an ethernet header.
+ // should already have an ethernet header. It takes ownership of vv.
WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
// Attach attaches the data link layer endpoint to the network-layer
@@ -422,6 +449,15 @@ type LinkEndpoint interface {
// Wait will not block if the endpoint hasn't started any goroutines
// yet, even if it might later.
Wait()
+
+ // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint.
+ //
+ // See:
+ // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30
+ ARPHardwareType() header.ARPHardwareType
+
+ // AddHeader adds a link layer header to pkt if required.
+ AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -430,7 +466,7 @@ type InjectableLinkEndpoint interface {
LinkEndpoint
// InjectInbound injects an inbound packet.
- InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
// InjectOutbound writes a fully formed outbound packet directly to the
// link.
@@ -442,12 +478,13 @@ type InjectableLinkEndpoint interface {
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
- // LinkAddressRequest sends a request for the LinkAddress of addr.
- // The request is sent on linkEP with localAddr as the source.
+ // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
+ // the request on the local network if remoteLinkAddr is the zero value. The
+ // request is sent on linkEP with localAddr as the source.
//
// A valid response will cause the discovery protocol's network
// endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+ LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 150297ab9..91e0110f1 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -48,6 +48,10 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
+
+ // directedBroadcast indicates whether this route is sending a directed
+ // broadcast packet.
+ directedBroadcast bool
}
// makeRoute initializes a new route. It takes ownership of the provided
@@ -113,6 +117,8 @@ func (r *Route) GSOMaxSize() uint32 {
// If address resolution is required, ErrNoLinkAddress and a notification channel is
// returned for the top level caller to block. Channel is closed once address resolution
// is complete (success or not).
+//
+// The NIC r uses must not be locked.
func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if !r.IsResolutionRequired() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
@@ -148,22 +154,27 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before this route can be written to.
+//
+// The NIC r uses must not be locked.
func (r *Route) IsResolutionRequired() bool {
return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
+ // WritePacket takes ownership of pkt, so calculate numBytes first.
+ numBytes := pkt.Header.UsedLength() + pkt.Data.Size()
+
err := r.ref.ep.WritePacket(r, gso, params, pkt)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
}
return err
}
@@ -175,9 +186,12 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
return 0, tcpip.ErrInvalidEndpointState
}
+ // WritePackets takes ownership of pkts, so calculate numPkts first.
+ numPkts := pkts.Len()
+
n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n))
}
r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
@@ -193,17 +207,20 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error {
+func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
+ // WriteHeaderIncludedPacket takes ownership of pkt, so calculate numBytes
+ // first.
+ numBytes := pkt.Data.Size()
+
if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
return nil
}
@@ -262,6 +279,12 @@ func (r *Route) Stack() *Stack {
return r.ref.stack()
}
+// IsBroadcast returns true if the route is to send a broadcast packet.
+func (r *Route) IsBroadcast() bool {
+ // Only IPv4 has a notion of broadcast.
+ return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast
+}
+
// ReverseRoute returns new route with given source and destination address.
func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
return Route{
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 0ab4c3e19..3f07e4159 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -52,7 +52,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
+ defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -424,12 +424,9 @@ type Stack struct {
// handleLocal allows non-loopback interfaces to loop packets.
handleLocal bool
- // tablesMu protects iptables.
- tablesMu sync.RWMutex
-
- // tables are the iptables packet filtering and manipulation rules. The are
- // protected by tablesMu.`
- tables IPTables
+ // tables are the iptables packet filtering and manipulation rules.
+ // TODO(gvisor.dev/issue/170): S/R this field.
+ tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
// stack is being restored.
@@ -448,6 +445,9 @@ type Stack struct {
// ndpConfigs is the default NDP configurations used by interfaces.
ndpConfigs NDPConfigurations
+ // nudConfigs is the default NUD configurations used by interfaces.
+ nudConfigs NUDConfigurations
+
// autoGenIPv6LinkLocal determines whether or not the stack will attempt
// to auto-generate an IPv6 link-local address for newly enabled non-loopback
// NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
@@ -457,6 +457,10 @@ type Stack struct {
// integrator NDP related events.
ndpDisp NDPDispatcher
+ // nudDisp is the NUD event dispatcher that is used to send the netstack
+ // integrator NUD related events.
+ nudDisp NUDDispatcher
+
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
@@ -475,6 +479,14 @@ type Stack struct {
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required.
randomGenerator *mathrand.Rand
+
+ // sendBufferSize holds the min/default/max send buffer sizes for
+ // endpoints other than TCP.
+ sendBufferSize SendBufferSizeOption
+
+ // receiveBufferSize holds the min/default/max receive buffer sizes for
+ // endpoints other than TCP.
+ receiveBufferSize ReceiveBufferSizeOption
}
// UniqueID is an abstract generator of unique identifiers.
@@ -513,6 +525,9 @@ type Options struct {
// before assigning an address to a NIC.
NDPConfigs NDPConfigurations
+ // NUDConfigs is the default NUD configurations used by interfaces.
+ NUDConfigs NUDConfigurations
+
// AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
// auto-generate an IPv6 link-local address for newly enabled non-loopback
// NICs.
@@ -531,6 +546,10 @@ type Options struct {
// receive NDP related events.
NDPDisp NDPDispatcher
+ // NUDDisp is the NUD event dispatcher that an integrator can provide to
+ // receive NUD related events.
+ NUDDisp NUDDispatcher
+
// RawFactory produces raw endpoints. Raw endpoints are enabled only if
// this is non-nil.
RawFactory RawFactory
@@ -665,6 +684,8 @@ func New(opts Options) *Stack {
// Make sure opts.NDPConfigs contains valid values only.
opts.NDPConfigs.validate()
+ opts.NUDConfigs.resetInvalidFields()
+
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
@@ -676,16 +697,29 @@ func New(opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
+ tables: DefaultTables(),
icmpRateLimiter: NewICMPRateLimiter(),
seed: generateRandUint32(),
ndpConfigs: opts.NDPConfigs,
+ nudConfigs: opts.NUDConfigs,
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
+ nudDisp: opts.NUDDisp,
opaqueIIDOpts: opts.OpaqueIIDOpts,
tempIIDSeed: opts.TempIIDSeed,
forwarder: newForwardQueue(),
randomGenerator: mathrand.New(randSrc),
+ sendBufferSize: SendBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultBufferSize,
+ Max: DefaultMaxBufferSize,
+ },
+ receiveBufferSize: ReceiveBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultBufferSize,
+ Max: DefaultMaxBufferSize,
+ },
}
// Add specified network protocols.
@@ -712,6 +746,11 @@ func New(opts Options) *Stack {
return s
}
+// newJob returns a tcpip.Job using the Stack clock.
+func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
+
// UniqueID returns a unique identifier.
func (s *Stack) UniqueID() uint64 {
return s.uniqueIDGenerator.UniqueID()
@@ -778,16 +817,17 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
}
}
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (s *Stack) NowNanoseconds() int64 {
- return s.clock.NowNanoseconds()
+// Clock returns the Stack's clock for retrieving the current time and
+// scheduling work.
+func (s *Stack) Clock() tcpip.Clock {
+ return s.clock
}
// Stats returns a mutable copy of the current stats.
@@ -1020,6 +1060,13 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
+ return s.removeNICLocked(id)
+}
+
+// removeNICLocked removes NIC and all related routes from the network stack.
+//
+// s.mu must be locked.
+func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error {
nic, ok := s.nics[id]
if !ok {
return tcpip.ErrUnknownNICID
@@ -1029,14 +1076,14 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
// Remove routes in-place. n tracks the number of routes written.
n := 0
for i, r := range s.routeTable {
+ s.routeTable[i] = tcpip.Route{}
if r.NIC != id {
// Keep this route.
- if i > n {
- s.routeTable[n] = r
- }
+ s.routeTable[n] = r
n++
}
}
+
s.routeTable = s.routeTable[:n]
return nic.remove()
@@ -1072,6 +1119,11 @@ type NICInfo struct {
// Context is user-supplied data optionally supplied in CreateNICWithOptions.
// See type NICOptions for more details.
Context NICContext
+
+ // ARPHardwareType holds the ARP hardware type of the NIC. This is the
+ // value sent in the haType field of an ARP request sent by this NIC and
+ // the value expected in the haType field of an ARP response.
+ ARPHardwareType header.ARPHardwareType
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -1103,6 +1155,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
Context: nic.context,
+ ARPHardwareType: nic.linkEP.ARPHardwareType(),
}
}
return nics
@@ -1249,9 +1302,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
s.mu.RLock()
defer s.mu.RUnlock()
- isBroadcast := remoteAddr == header.IPv4Broadcast
+ isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
- needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
@@ -1272,9 +1325,16 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
- if needRoute {
- r.NextHop = route.Gateway
+ r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr)
+
+ if len(route.Gateway) > 0 {
+ if needRoute {
+ r.NextHop = route.Gateway
+ }
+ } else if r.directedBroadcast {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
+
return r, nil
}
}
@@ -1400,25 +1460,31 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
+}
+
+// CheckRegisterTransportEndpoint checks if an endpoint can be registered with
+// the stack transport dispatcher.
+func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// StartTransportEndpointCleanup removes the endpoint with the given id from
// the stack transport dispatcher. It also transitions it to the cleanup stage.
-func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
s.mu.Lock()
defer s.mu.Unlock()
s.cleanupEndpoints[ep] = struct{}{}
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
@@ -1741,18 +1807,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool,
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() IPTables {
- s.tablesMu.RLock()
- t := s.tables
- s.tablesMu.RUnlock()
- return t
-}
-
-// SetIPTables sets the stack's iptables.
-func (s *Stack) SetIPTables(ipt IPTables) {
- s.tablesMu.Lock()
- s.tables = ipt
- s.tablesMu.Unlock()
+func (s *Stack) IPTables() *IPTables {
+ return s.tables
}
// ICMPLimit returns the maximum number of ICMP messages that can be sent
@@ -1831,10 +1887,38 @@ func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip
}
nic.setNDPConfigs(c)
-
return nil
}
+// NUDConfigurations gets the per-interface NUD configurations.
+func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) {
+ s.mu.RLock()
+ nic, ok := s.nics[id]
+ s.mu.RUnlock()
+
+ if !ok {
+ return NUDConfigurations{}, tcpip.ErrUnknownNICID
+ }
+
+ return nic.NUDConfigs()
+}
+
+// SetNUDConfigurations sets the per-interface NUD configurations.
+//
+// Note, if c contains invalid NUD configuration values, the erroneous values
+// are replaced with their defaults.
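+//
+// A usage sketch (illustrative only; s is a *Stack and nicID names an
+// existing NIC):
+//
+//	c := stack.DefaultNUDConfigurations()
+//	c.BaseReachableTime = time.Minute
+//	err := s.SetNUDConfigurations(nicID, c)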
+func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[id]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.setNUDConfigs(c)
+}
+
// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement
// message that it needs to handle.
func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error {
diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go
new file mode 100644
index 000000000..0b093e6c5
--- /dev/null
+++ b/pkg/tcpip/stack/stack_options.go
@@ -0,0 +1,106 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+const (
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4 KiB
+
+ // DefaultBufferSize is the default size of the send/recv buffer for a
+ // transport endpoint.
+ DefaultBufferSize = 212 << 10 // 212 KiB
+
+ // DefaultMaxBufferSize is the default maximum permitted size of a
+ // send/receive buffer.
+ DefaultMaxBufferSize = 4 << 20 // 4 MiB
+)
+
+// SendBufferSizeOption is used by stack.(*Stack).Option/SetOption to
+// get/set the default, min and max send buffer sizes.
+type SendBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// ReceiveBufferSizeOption is used by stack.(*Stack).Option/SetOption to
+// get/set the default, min and max receive buffer sizes.
+type ReceiveBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// SetOption allows setting stack-wide options.
+func (s *Stack) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case SendBufferSizeOption:
+ // Make sure we don't allow lowering the buffer below the minimum
+ // required for the stack to work.
+ if v.Min < MinBufferSize {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ if v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ s.mu.Lock()
+ s.sendBufferSize = v
+ s.mu.Unlock()
+ return nil
+
+ case ReceiveBufferSizeOption:
+ // Make sure we don't allow lowering the buffer below the minimum
+ // required for the stack to work.
+ if v.Min < MinBufferSize {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ if v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ s.mu.Lock()
+ s.receiveBufferSize = v
+ s.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option allows retrieving stack-wide options.
+func (s *Stack) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *SendBufferSizeOption:
+ s.mu.RLock()
+ *v = s.sendBufferSize
+ s.mu.RUnlock()
+ return nil
+
+ case *ReceiveBufferSizeOption:
+ s.mu.RLock()
+ *v = s.receiveBufferSize
+ s.mu.RUnlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
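A usage sketch for the new options: values go in by value through SetOption and come back by pointer through Option; the buffer sizes chosen here are arbitrary:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func main() {
	s := stack.New(stack.Options{})
	// Default must lie in [Min, Max] and Min may not drop below
	// MinBufferSize, or SetOption returns ErrInvalidOptionValue.
	opt := stack.SendBufferSizeOption{
		Min:     stack.MinBufferSize,
		Default: 64 << 10,
		Max:     1 << 20,
	}
	if err := s.SetOption(opt); err != nil {
		fmt.Println("SetOption failed:", err)
		return
	}
	var got stack.SendBufferSizeOption
	if err := s.Option(&got); err != nil {
		fmt.Println("Option failed:", err)
		return
	}
	fmt.Printf("send buffer sizes: %+v\n", got)
}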
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 1a2cf007c..f22062889 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -27,6 +27,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -52,6 +53,10 @@ const (
// where another value is explicitly used. It is chosen to match the MTU
// of loopback interfaces on Linux systems.
defaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
)
// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
@@ -90,30 +95,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
- // Consume the network header.
- b, ok := pkt.Data.PullUp(fakeNetHeaderLen)
- if !ok {
- return
- }
- pkt.Data.TrimFront(fakeNetHeaderLen)
-
// Handle control packets.
- if b[2] == uint8(fakeControlProtocol) {
+ if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
}
pkt.Data.TrimFront(fakeNetHeaderLen)
- f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt)
+ f.dispatcher.DeliverTransportControlPacket(
+ tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ fakeNetNumber,
+ tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ stack.ControlPortUnreachable, 0, pkt)
return
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -132,24 +135,19 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe
return f.proto.Number()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
// Add the protocol's header to the packet and send it to the link
// endpoint.
- b := pkt.Header.Prepend(fakeNetHeaderLen)
- b[0] = r.RemoteAddress[0]
- b[1] = f.id.LocalAddress[0]
- b[2] = byte(params.Protocol)
+ pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen)
+ pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0]
+ pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0]
+ pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
- f.HandlePacket(r, stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
- })
+ f.HandlePacket(r, pkt)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -163,7 +161,7 @@ func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -205,7 +203,7 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
}
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
@@ -247,6 +245,17 @@ func (*fakeNetworkProtocol) Close() {}
// Wait implements TransportProtocol.Wait.
func (*fakeNetworkProtocol) Wait() {}
+// Parse implements TransportProtocol.Parse.
+func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(fakeNetHeaderLen)
+ return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
+}
+
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
@@ -292,8 +301,8 @@ func TestNetworkReceive(t *testing.T) {
buf := buffer.NewView(30)
// Make sure packet with wrong address is not delivered.
- buf[0] = 3
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 0 {
@@ -304,8 +313,8 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to first endpoint.
- buf[0] = 1
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = 1
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -316,8 +325,8 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to second endpoint.
- buf[0] = 2
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = 2
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -328,7 +337,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -340,7 +349,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -362,7 +371,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
})
@@ -420,7 +429,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b
func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got := fakeNet.PacketCount(localAddrByte); got != want {
@@ -859,9 +868,9 @@ func TestRouteWithDownNIC(t *testing.T) {
// Writes with Routes that use NIC1 after being brought up should
// succeed.
//
- // TODO(b/147015577): Should we instead completely invalidate all
- // Routes that were bound to a NIC that was brought down at some
- // point?
+ // TODO(gvisor.dev/issue/1491): Should we instead completely
+ // invalidate all Routes that were bound to a NIC that was brought
+ // down at some point?
if err := upFn(s, nicID1); err != nil {
t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
}
@@ -982,7 +991,7 @@ func TestAddressRemoval(t *testing.T) {
buf := buffer.NewView(30)
// Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -1032,7 +1041,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
}
// Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSend(t, r, ep, nil)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -1114,7 +1123,7 @@ func TestEndpointExpiration(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
if promiscuous {
if err := s.SetPromiscuousMode(nicID, true); err != nil {
@@ -1277,7 +1286,7 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
// Set promiscuous mode, then check that packet is delivered.
@@ -1658,7 +1667,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -1766,7 +1775,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -2263,7 +2272,7 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
@@ -2344,8 +2353,8 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to dstAddr.
buf := buffer.NewView(30)
- buf[0] = dstAddr[0]
- ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = dstAddr[0]
+ ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -3297,7 +3306,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
// Wait for DAD to resolve.
select {
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" {
@@ -3330,3 +3339,305 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
}
}
+
+func TestStackReceiveBufferSizeOption(t *testing.T) {
+ const sMin = stack.MinBufferSize
+ testCases := []struct {
+ name string
+ rs stack.ReceiveBufferSizeOption
+ err *tcpip.Error
+ }{
+ // Invalid configurations.
+ {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+
+ // Valid configurations.
+ {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ defer s.Close()
+ if err := s.SetOption(tc.rs); err != tc.err {
+ t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err)
+ }
+ var rs stack.ReceiveBufferSizeOption
+ if tc.err == nil {
+ if err := s.Option(&rs); err != nil {
+ t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err)
+ }
+ if got, want := rs, tc.rs; got != want {
+ t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestStackSendBufferSizeOption(t *testing.T) {
+ const sMin = stack.MinBufferSize
+ testCases := []struct {
+ name string
+ ss stack.SendBufferSizeOption
+ err *tcpip.Error
+ }{
+ // Invalid configurations.
+ {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+
+ // Valid configurations.
+ {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ defer s.Close()
+ if err := s.SetOption(tc.ss); err != tc.err {
+ t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err)
+ }
+ var ss stack.SendBufferSizeOption
+ if tc.err == nil {
+ if err := s.Option(&ss); err != nil {
+ t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err)
+ }
+ if got, want := ss, tc.ss; got != want {
+ t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestOutgoingSubnetBroadcast(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID1 = 1
+ )
+
+ defaultAddr := tcpip.AddressWithPrefix{
+ Address: header.IPv4Any,
+ PrefixLen: 0,
+ }
+ defaultSubnet := defaultAddr.Subnet()
+ ipv4Addr := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 24,
+ }
+ ipv4Subnet := ipv4Addr.Subnet()
+ ipv4SubnetBcast := ipv4Subnet.Broadcast()
+ ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 31,
+ }
+ ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
+ ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
+ ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 32,
+ }
+ ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
+ ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
+ ipv6Addr := tcpip.AddressWithPrefix{
+ Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ PrefixLen: 64,
+ }
+ ipv6Subnet := ipv6Addr.Subnet()
+ ipv6SubnetBcast := ipv6Subnet.Broadcast()
+ remNetAddr := tcpip.AddressWithPrefix{
+ Address: "\x64\x0a\x7b\x18",
+ PrefixLen: 24,
+ }
+ remNetSubnet := remNetAddr.Subnet()
+ remNetSubnetBcast := remNetSubnet.Broadcast()
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ expectedRoute stack.Route
+ }{
+ // Broadcast to a locally attached subnet populates the broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4SubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: ipv4SubnetBcast,
+ RemoteLinkAddress: header.EthernetBroadcastAddress,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a locally attached /31 subnet does not populate the
+ // broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local /31 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix31,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet31,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet31Bcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4AddrPrefix31.Address,
+ RemoteAddress: ipv4Subnet31Bcast,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a locally attached /32 subnet does not populate the
+ // broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local /32 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix32,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet32,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet32Bcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4AddrPrefix32.Address,
+ RemoteAddress: ipv4Subnet32Bcast,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 'Broadcast' to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv6Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv6SubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv6Addr.Address,
+ RemoteAddress: ipv6SubnetBcast,
+ NetProto: header.IPv6ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a remote subnet in the route table is sent to the next-hop
+ // gateway.
+ {
+ name: "IPv4 Broadcast to remote subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: remNetSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: remNetSubnetBcast,
+ NextHop: ipv4Gateway,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to an unknown subnet follows the default route. Note that this
+ // is essentially just routing an unknown destination IP, because without any
+ // subnet prefix information a subnet broadcast address is just a normal IP.
+ {
+ name: "IPv4 Broadcast to unknown subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: defaultSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: remNetSubnetBcast,
+ NextHop: ipv4Gateway,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ })
+ ep := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ }
+
+ s.SetRouteTable(test.routes)
+
+ var netProto tcpip.NetworkProtocolNumber
+ switch l := len(test.remoteAddr); l {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ netProto = header.IPv6ProtocolNumber
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
+ } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
+ t.Errorf("route mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 9a33ed375..b902c6ca9 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -15,7 +15,6 @@
package stack
import (
- "container/heap"
"fmt"
"math/rand"
@@ -23,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
)
type protocolIDs struct {
@@ -43,14 +43,14 @@ type transportEndpoints struct {
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
epsByNIC, ok := eps.endpoints[id]
if !ok {
return
}
- if !epsByNIC.unregisterEndpoint(bindToDevice, ep) {
+ if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) {
return
}
delete(eps.endpoints, id)
@@ -152,7 +152,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive at this transport
// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
@@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -204,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
epsByNIC.mu.Lock()
defer epsByNIC.mu.Unlock()
@@ -214,23 +214,34 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
demux: d,
netProto: netProto,
transProto: transProto,
- reuse: reusePort,
}
epsByNIC.endpoints[bindToDevice] = multiPortEp
}
- return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ return multiPortEp.singleRegisterEndpoint(t, flags)
+}
+
+func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
+
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
+ if !ok {
+ return nil
+ }
+
+ return multiPortEp.singleCheckEndpoint(flags)
}
// unregisterEndpoint returns true if endpointsByNIC has to be unregistered.
-func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool {
epsByNIC.mu.Lock()
defer epsByNIC.mu.Unlock()
multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
if !ok {
return false
}
- if multiPortEp.unregisterEndpoint(t) {
+ if multiPortEp.unregisterEndpoint(t, flags) {
delete(epsByNIC.endpoints, bindToDevice)
}
return len(epsByNIC.endpoints) == 0
@@ -251,7 +262,7 @@ type transportDemuxer struct {
// the dispatcher to deliver packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer)
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
@@ -279,10 +290,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
return err
}
}
@@ -290,33 +301,15 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
return nil
}
-type transportEndpointHeap []TransportEndpoint
-
-var _ heap.Interface = (*transportEndpointHeap)(nil)
-
-func (h *transportEndpointHeap) Len() int {
- return len(*h)
-}
-
-func (h *transportEndpointHeap) Less(i, j int) bool {
- return (*h)[i].UniqueID() < (*h)[j].UniqueID()
-}
-
-func (h *transportEndpointHeap) Swap(i, j int) {
- (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
-}
-
-func (h *transportEndpointHeap) Push(x interface{}) {
- *h = append(*h, x.(TransportEndpoint))
-}
+// checkEndpoint checks if an endpoint can be registered with the dispatcher.
+func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ for _, n := range netProtos {
+ if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
+ return err
+ }
+ }
-func (h *transportEndpointHeap) Pop() interface{} {
- old := *h
- n := len(old)
- x := old[n-1]
- old[n-1] = nil
- *h = old[:n-1]
- return x
+ return nil
}
// multiPortEndpoint is a container for TransportEndpoints which are bound to
@@ -334,9 +327,10 @@ type multiPortEndpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
- endpoints transportEndpointHeap
- // reuse indicates if more than one endpoint is allowed.
- reuse bool
+ // endpoints stores the transport endpoints in the order in which they
+ // were bound. This is required for UDP SO_REUSEADDR.
+ endpoints []TransportEndpoint
+ flags ports.FlagCounter
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
@@ -362,6 +356,10 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[0]
}
+ if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent {
+ return mpep.endpoints[len(mpep.endpoints)-1]
+ }
+
payload := []byte{
byte(id.LocalPort),
byte(id.LocalPort >> 8),
@@ -379,7 +377,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[idx]
}
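For the load-balanced case, selectEndpoint (context above) hashes the connection 4-tuple with a per-stack seed and indexes into the endpoint slice, so a given flow is sticky to one endpoint. A standalone model of that selection, substituting stdlib FNV for the jenkins hash the stack actually uses:

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// pick models selectEndpoint's hashing branch: the same 4-tuple always
// maps to the same slot, and the seed varies the mapping per stack.
func pick(localAddr, remoteAddr string, localPort, remotePort uint16, seed uint32, n int) int {
	h := fnv.New32a() // stand-in for the jenkins hash used by netstack
	var ports [4]byte
	binary.LittleEndian.PutUint16(ports[0:], localPort)
	binary.LittleEndian.PutUint16(ports[2:], remotePort)
	h.Write(ports[:])
	h.Write([]byte(localAddr))
	h.Write([]byte(remoteAddr))
	var s [4]byte
	binary.LittleEndian.PutUint32(s[:], seed)
	h.Write(s[:])
	return int(h.Sum32() % uint32(n))
}

func main() {
	// Two different flows spread across 4 endpoints; each flow is sticky.
	fmt.Println(pick("\x01", "\x02", 80, 1234, 42, 4))
	fmt.Println(pick("\x01", "\x02", 80, 5678, 42, 4))
}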
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
// HandlePacket takes ownership of pkt, so each endpoint needs
@@ -401,40 +399,63 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
+ bits := flags.Bits() & ports.MultiBindFlagMask
+
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
- if !ep.reuse || !reusePort {
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
return tcpip.ErrPortInUse
}
}
- heap.Push(&ep.endpoints, t)
+ ep.endpoints = append(ep.endpoints, t)
+ ep.flags.AddRef(bits)
+
+ return nil
+}
+
+func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ bits := flags.Bits() & ports.MultiBindFlagMask
+
+ if len(ep.endpoints) != 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
+ return tcpip.ErrPortInUse
+ }
+ }
return nil
}
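The rule encoded by IntersectionRefs in the two functions above: a new bind on an occupied (id, device) key succeeds only if its reuse flags intersect the flags of every endpoint already bound there. A self-contained model (the real ports.FlagCounter keeps per-bit refcounts so the intersection can widen again on unregister; this sketch tracks only the running intersection):

package main

import "fmt"

type flagBits uint8

const (
	mostRecent flagBits = 1 << iota
	loadBalanced // SO_REUSEPORT-style sharing
)

type counter struct {
	refs         int
	intersection flagBits
}

// add admits a bind only if it shares at least one flag with all
// previous binds, mirroring singleRegisterEndpoint's check.
func (c *counter) add(f flagBits) bool {
	if c.refs > 0 && c.intersection&f == 0 {
		return false // tcpip.ErrPortInUse in the real code
	}
	if c.refs == 0 {
		c.intersection = f
	} else {
		c.intersection &= f
	}
	c.refs++
	return true
}

func main() {
	var c counter
	fmt.Println(c.add(loadBalanced))              // true: first bind
	fmt.Println(c.add(loadBalanced | mostRecent)) // true: shares loadBalanced
	fmt.Println(c.add(mostRecent))                // false: intersection is loadBalanced only
}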
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
-func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
+func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool {
ep.mu.Lock()
defer ep.mu.Unlock()
for i, endpoint := range ep.endpoints {
if endpoint == t {
- heap.Remove(&ep.endpoints, i)
+ copy(ep.endpoints[i:], ep.endpoints[i+1:])
+ ep.endpoints[len(ep.endpoints)-1] = nil
+ ep.endpoints = ep.endpoints[:len(ep.endpoints)-1]
+
+ ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask)
break
}
}
return len(ep.endpoints) == 0
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
if id.RemotePort != 0 {
- // TODO(eyalsoha): Why?
- reusePort = false
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
@@ -454,15 +475,42 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.endpoints[id] = epsByNIC
}
- return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
+ return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice)
+}
+
+func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ if id.RemotePort != 0 {
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
+ }
+
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+
+ eps.mu.RLock()
+ defer eps.mu.RUnlock()
+
+ epsByNIC, ok := eps.endpoints[id]
+ if !ok {
+ return nil
+ }
+
+ return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ if id.RemotePort != 0 {
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
+ }
+
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.unregisterEndpoint(id, ep, bindToDevice)
+ eps.unregisterEndpoint(id, ep, flags, bindToDevice)
}
}
}
@@ -470,7 +518,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -520,7 +568,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool {
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -544,7 +592,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 2474a7db3..73dada928 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -127,7 +128,7 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
@@ -165,7 +166,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
@@ -195,7 +196,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
if !ok {
t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
}
- if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want {
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
}
})
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index a611e44ab..7e8b84867 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -83,12 +84,13 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return 0, nil, tcpip.ErrNoRoute
}
- hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
+ hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen)
+ hdr.Prepend(fakeTransHeaderLen)
v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: buffer.View(v).ToVectorisedView(),
}); err != nil {
@@ -153,7 +155,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */)
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -198,8 +200,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
f,
- false, /* reuse */
- 0, /* bindtoDevice */
+ ports.Flags{},
+ 0, /* bindToDevice */
); err != nil {
return err
}
@@ -215,7 +217,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
@@ -232,7 +234,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
}
}
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
// Increment the number of received control packets.
f.proto.controlCount++
}
@@ -289,7 +291,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
return true
}
@@ -324,6 +326,17 @@ func (*fakeTransportProtocol) Close() {}
// Wait implements TransportProtocol.Wait.
func (*fakeTransportProtocol) Wait() {}
+// Parse implements TransportProtocol.Parse.
+func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen)
+ if !ok {
+ return false
+ }
+ pkt.TransportHeader = hdr
+ pkt.Data.TrimFront(fakeTransHeaderLen)
+ return true
+}
+
func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{}
}
@@ -369,7 +382,7 @@ func TestTransportReceive(t *testing.T) {
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -380,7 +393,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -391,7 +404,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 1 {
@@ -446,7 +459,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -457,7 +470,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -468,7 +481,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 1 {
@@ -623,7 +636,7 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: req.ToVectorisedView(),
})
@@ -642,11 +655,10 @@ func TestTransportForwarding(t *testing.T) {
t.Fatal("Response packet not forwarded")
}
- hdrs := p.Pkt.Data.ToView()
- if dst := hdrs[0]; dst != 3 {
+ if dst := p.Pkt.NetworkHeader[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
- if src := hdrs[1]; src != 1 {
+ if src := p.Pkt.NetworkHeader[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index b7b227328..45f59b60f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -43,6 +43,9 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// Using header.IPv4AddressSize would cause an import cycle.
+const ipv4AddressSize = 4
+
// Error represents an error in the netstack error space. Using a special type
// ensures that errors outside of this space are not accidentally introduced.
//
@@ -192,7 +195,7 @@ func (e ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
-// A Clock provides the current time.
+// A Clock provides the current time and schedules work for execution.
//
// Times returned by a Clock should always be used for application-visible
// time. Only monotonic times should be used for netstack internal timekeeping.
@@ -203,6 +206,31 @@ type Clock interface {
// NowMonotonic returns a monotonic time value.
NowMonotonic() int64
+
+ // AfterFunc waits for the duration to elapse and then calls f in its own
+ // goroutine. It returns a Timer that can be used to cancel the call using
+ // its Stop method.
+ AfterFunc(d time.Duration, f func()) Timer
+}
+
+// Timer represents a single event. A Timer must be created with
+// Clock.AfterFunc.
+type Timer interface {
+ // Stop prevents the Timer from firing. It returns true if the call stops the
+ // timer, false if the timer has already expired or been stopped.
+ //
+ // If Stop returns false, then the timer has already expired and the function
+ // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop
+ // does not wait for f to complete before returning. If the caller needs to
+ // know whether f is completed, it must coordinate with f explicitly.
+ Stop() bool
+
+ // Reset changes the timer to expire after duration d.
+ //
+ // Reset should be invoked only on stopped or expired timers. If the timer is
+ // known to have expired, Reset can be used directly. Otherwise, the caller
+ // must coordinate with the function f of Clock.AfterFunc(d, f).
+ Reset(d time.Duration)
}
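A minimal Clock/Timer sketch backed by the Go runtime, assuming only the two interfaces above; netstack supplies its own implementations:

package main

import (
	"fmt"
	"time"
)

// stdTimer adapts time.Timer to the Timer contract; time.Timer's Stop and
// Reset semantics match the documentation above.
type stdTimer struct{ t *time.Timer }

func (s *stdTimer) Stop() bool            { return s.t.Stop() }
func (s *stdTimer) Reset(d time.Duration) { s.t.Reset(d) }

// afterFunc mirrors Clock.AfterFunc using the runtime clock.
func afterFunc(d time.Duration, f func()) *stdTimer {
	return &stdTimer{t: time.AfterFunc(d, f)}
}

func main() {
	done := make(chan struct{})
	tmr := afterFunc(10*time.Millisecond, func() { close(done) })
	<-done
	// Stop after expiry returns false, exactly as documented above.
	fmt.Println(tmr.Stop())
}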
// Address is a byte slice cast as a string that represents the address of a
@@ -295,6 +323,29 @@ func (s *Subnet) Broadcast() Address {
return Address(addr)
}
+// IsBroadcast returns true if the address is considered a broadcast address.
+func (s *Subnet) IsBroadcast(address Address) bool {
+ // Only IPv4 supports the notion of a broadcast address.
+ if len(address) != ipv4AddressSize {
+ return false
+ }
+
+ // Normally, we would just compare address with the subnet's broadcast
+ // address but there is an exception where a simple comparison is not
+ // correct. This exception is for /31 and /32 IPv4 subnets where all
+ // addresses are considered valid host addresses.
+ //
+ // For /31 subnets, the case is easy. RFC 3021 Section 2.1 states that
+ // both addresses in a /31 subnet "MUST be interpreted as host addresses."
+ //
+ // For /32, the case is a bit more vague. RFC 3021 makes no mention of /32
+ // subnets. However, the same reasoning applies - if an exception is not
+ // made, then there do not exist any host addresses in a /32 subnet. RFC
+ // 4632 Section 3.1 also vaguely implies this interpretation by referring
+ // to addresses in /32 subnets as "host routes."
+ return s.Prefix() <= 30 && s.Broadcast() == address
+}
+
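A quick illustration of the /31 carve-out: the same all-ones address is a broadcast in a /24 but a plain host address in a /31 (addresses are illustrative):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
)

func main() {
	bcast := tcpip.Address("\xc0\xa8\x01\xff") // 192.168.1.255

	// 192.168.1.0/24: .255 is the subnet broadcast address.
	s24, _ := tcpip.NewSubnet(tcpip.Address("\xc0\xa8\x01\x00"), tcpip.AddressMask("\xff\xff\xff\x00"))
	fmt.Println(s24.IsBroadcast(bcast)) // true

	// 192.168.1.254/31: both addresses are hosts per RFC 3021.
	s31, _ := tcpip.NewSubnet(tcpip.Address("\xc0\xa8\x01\xfe"), tcpip.AddressMask("\xff\xff\xff\xfe"))
	fmt.Println(s31.IsBroadcast(bcast)) // false
}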
// Equal returns true if s equals o.
//
// Needed to use cmp.Equal on Subnet as its fields are unexported.
@@ -316,6 +367,28 @@ const (
ShutdownWrite
)
+// PacketType is used to indicate the destination of the packet.
+type PacketType uint8
+
+const (
+ // PacketHost indicates a packet addressed to the local host.
+ PacketHost PacketType = iota
+
+ // PacketOtherHost indicates an outgoing packet addressed to
+ // another host caught by a NIC in promiscuous mode.
+ PacketOtherHost
+
+ // PacketOutgoing for a packet originating from the local host
+ // that is looped back to a packet socket.
+ PacketOutgoing
+
+ // PacketBroadcast indicates a link layer broadcast packet.
+ PacketBroadcast
+
+ // PacketMulticast indicates a link layer multicast packet.
+ PacketMulticast
+)
+
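These categories mirror Linux's PACKET_* packet types; a sketch of the mapping a packet-socket layer might apply, with the numeric values taken from include/uapi/linux/if_packet.h:

package example

import "gvisor.dev/gvisor/pkg/tcpip"

// Linux sll_pkttype values from include/uapi/linux/if_packet.h.
const (
	packetHost      = 0 // PACKET_HOST
	packetBroadcast = 1 // PACKET_BROADCAST
	packetMulticast = 2 // PACKET_MULTICAST
	packetOtherHost = 3 // PACKET_OTHERHOST
	packetOutgoing  = 4 // PACKET_OUTGOING
)

// toLinuxPktType maps a tcpip.PacketType onto the value a hypothetical
// AF_PACKET layer would report in sockaddr_ll.sll_pkttype.
func toLinuxPktType(t tcpip.PacketType) uint8 {
	switch t {
	case tcpip.PacketOtherHost:
		return packetOtherHost
	case tcpip.PacketOutgoing:
		return packetOutgoing
	case tcpip.PacketBroadcast:
		return packetBroadcast
	case tcpip.PacketMulticast:
		return packetMulticast
	default:
		return packetHost
	}
}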
// FullAddress represents a full transport node address, as required by the
// Connect() and Bind() methods.
//
@@ -549,6 +622,28 @@ type Endpoint interface {
SetOwner(owner PacketOwner)
}
+// LinkPacketInfo holds Link layer information for a received packet.
+//
+// +stateify savable
+type LinkPacketInfo struct {
+ // Protocol is the NetworkProtocolNumber for the packet.
+ Protocol NetworkProtocolNumber
+
+ // PktType is used to indicate the destination of the packet.
+ PktType PacketType
+}
+
+// PacketEndpoint are additional methods that are only implemented by Packet
+// endpoints.
+type PacketEndpoint interface {
+ // ReadPacket reads a datagram/packet from the endpoint and optionally
+ // returns the sender and additional LinkPacketInfo.
+ //
+ // This method does not block if there is no data pending. It will also
+ // either return an error or data, never both.
+ ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error)
+}
+
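A hypothetical single read against this interface; the waiter-based retry on ErrWouldBlock is the caller's responsibility:

package example

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
)

// readOne performs one non-blocking read from a packet endpoint and
// reports the link-layer metadata delivered alongside the payload.
func readOne(ep tcpip.PacketEndpoint) {
	var sender tcpip.FullAddress
	var info tcpip.LinkPacketInfo
	v, _, err := ep.ReadPacket(&sender, &info)
	switch err {
	case nil:
		fmt.Printf("%d bytes from %v, proto=%d, type=%d\n",
			len(v), sender.Addr, info.Protocol, info.PktType)
	case tcpip.ErrWouldBlock:
		// No data pending; callers typically register with a
		// waiter.Queue and retry once the endpoint is readable.
	default:
		fmt.Println("read failed:", err)
	}
}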
// EndpointInfo is the interface implemented by each endpoint info struct.
type EndpointInfo interface {
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
@@ -585,85 +680,108 @@ type WriteOptions struct {
type SockOptBool int
const (
- // BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether
- // datagram sockets are allowed to send packets to a broadcast address.
+ // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether datagram sockets are allowed to send packets to a broadcast
+ // address.
BroadcastOption SockOptBool = iota
- // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
- // held until segments are full by the TCP transport protocol.
+ // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if
+ // data should be held until segments are full by the TCP transport
+ // protocol.
CorkOption
- // DelayOption is used by SetSockOpt/GetSockOpt to specify if data
- // should be sent out immediately by the transport protocol. For TCP,
- // it determines if the Nagle algorithm is on or off.
+ // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if
+ // data should be sent out immediately by the transport protocol. For
+ // TCP, it determines if the Nagle algorithm is on or off.
DelayOption
- // KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
- // TCP keepalive is enabled for this socket.
+ // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether TCP keepalive is enabled for this socket.
KeepaliveEnabledOption
- // MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether
- // multicast packets sent over a non-loopback interface will be looped back.
+ // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether multicast packets sent over a non-loopback interface
+ // will be looped back.
MulticastLoopOption
- // PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
- // SCM_CREDENTIALS socket control messages are enabled.
+ // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether UDP checksum is disabled for this socket.
+ NoChecksumOption
+
+ // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether SCM_CREDENTIALS socket control messages are enabled.
//
// Only supported on Unix sockets.
PasscredOption
- // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
+ // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool.
QuickAckOption
- // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the
- // IPV6_TCLASS ancillary message is passed with incoming packets.
+ // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to
+ // specify if the IPV6_TCLASS ancillary message is passed with incoming
+ // packets.
ReceiveTClassOption
- // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS
- // ancillary message is passed with incoming packets.
+ // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify
+ // if the TOS ancillary message is passed with incoming packets.
ReceiveTOSOption
- // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify
- // if more inforamtion is provided with incoming packets such
- // as interface index and address.
+ // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to
+ // specify if more information is provided with incoming packets such as
+ // interface index and address.
ReceiveIPPacketInfoOption
- // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
- // should allow reuse of local address.
+ // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether Bind() should allow reuse of local address.
ReuseAddressOption
- // ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets
- // to be bound to an identical socket address.
+ // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit
+ // multiple sockets to be bound to an identical socket address.
ReusePortOption
- // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
- // socket is to be restricted to sending and receiving IPv6 packets only.
+ // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether an IPv6 socket is to be restricted to sending and receiving
+ // IPv6 packets only.
V6OnlyOption
+
+ // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw
+ // endpoint that all packets being written have an IP header and the
+ // endpoint should not attach an IP header.
+ IPHdrIncludedOption
)
// SockOptInt represents socket options which values have the int type.
type SockOptInt int
const (
- // KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number
- // of un-ACKed TCP keepalives that will be sent before the connection is
- // closed.
+ // KeepaliveCountOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the number of un-ACKed TCP keepalives that will be sent
+ // before the connection is closed.
KeepaliveCountOption SockOptInt = iota
- // IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS
+ // IPv4TOSOption is used by SetSockOptInt/GetSockOptInt to specify TOS
// for all subsequent outgoing IPv4 packets from the endpoint.
IPv4TOSOption
- // IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS
- // for all subsequent outgoing IPv6 packets from the endpoint.
+ // IPv6TrafficClassOption is used by SetSockOptInt/GetSockOptInt to
+ // specify TOS for all subsequent outgoing IPv6 packets from the
+ // endpoint.
IPv6TrafficClassOption
- // MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current
- // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option.
+ // MaxSegOption is used by SetSockOptInt/GetSockOptInt to set/get the
+ // current Maximum Segment Size (MSS) value as specified using the
+ // TCP_MAXSEG option.
MaxSegOption
- // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
- // TTL value for multicast messages. The default is 1.
+ // MTUDiscoverOption is used to set/get the path MTU discovery setting.
+ //
+ // NOTE: Setting this option to any value other than PMTUDiscoveryDont
+ // is not supported and will fail; getting this option always returns
+ // PMTUDiscoveryDont.
+ MTUDiscoverOption
+
+ // MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control
+ // the default TTL value for multicast messages. The default is 1.
MulticastTTLOption
// ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
@@ -682,26 +800,45 @@ const (
// number of unread bytes in the output buffer should be returned.
SendQueueSizeOption
- // TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop
- // limit value for unicast messages. The default is protocol specific.
+ // TTLOption is used by SetSockOptInt/GetSockOptInt to control the
+ // default TTL/hop limit value for unicast messages. The default is
+ // protocol specific.
//
// A zero value indicates the default.
TTLOption
- // TCPSynCountOption is used by SetSockOpt/GetSockOpt to specify the number of
- // SYN retransmits that TCP should send before aborting the attempt to
- // connect. It cannot exceed 255.
+ // TCPSynCountOption is used by SetSockOptInt/GetSockOptInt to specify
+ // the number of SYN retransmits that TCP should send before aborting
+ // the attempt to connect. It cannot exceed 255.
//
// NOTE: This option is currently only stubbed out and is a no-op.
TCPSynCountOption
- // TCPWindowClampOption is used by SetSockOpt/GetSockOpt to bound the size
- // of the advertised window to this value.
+ // TCPWindowClampOption is used by SetSockOptInt/GetSockOptInt to bound
+ // the size of the advertised window to this value.
//
// NOTE: This option is currently only stubbed out and is a no-op.
TCPWindowClampOption
)
+const (
+ // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use
+ // per-route settings.
+ PMTUDiscoveryWant int = iota
+
+ // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable
+ // path MTU discovery.
+ PMTUDiscoveryDont
+
+ // PMTUDiscoveryDo is a setting of the MTUDiscoverOption to always do
+ // path MTU discovery.
+ PMTUDiscoveryDo
+
+ // PMTUDiscoveryProbe is a setting of the MTUDiscoverOption to set DF
+ // but ignore path MTU.
+ PMTUDiscoveryProbe
+)
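
Per the NOTE on MTUDiscoverOption above, only PMTUDiscoveryDont is honored for now. A sketch of the documented behavior (ep is an assumed tcpip.Endpoint, not part of this diff):

	// Requesting anything but PMTUDiscoveryDont fails.
	if err := ep.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDo); err != nil {
		// Expected: path MTU discovery cannot be enabled yet.
	}
	v, _ := ep.GetSockOptInt(tcpip.MTUDiscoverOption)
	// v == tcpip.PMTUDiscoveryDont, regardless of prior sets.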
+
// ErrorOption is used in GetSockOpt to specify that the last error reported by
// the endpoint should be cleared and returned.
type ErrorOption struct{}
@@ -740,7 +877,7 @@ type CongestionControlOption string
// control algorithms.
type AvailableCongestionControlOption string
-// buffer moderation.
+// ModerateReceiveBufferOption is used to enable/disable receive buffer moderation.
type ModerateReceiveBufferOption bool
// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
@@ -813,7 +950,15 @@ type OutOfBandInlineOption int
// a default TTL.
type DefaultTTLOption uint8
-// IPPacketInfo is the message struture for IP_PKTINFO.
+// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached
+// classic BPF filter on a given endpoint.
+type SocketDetachFilterOption int
+
+// OriginalDestinationOption is used to get the original destination address
+// and port of a redirected packet.
+type OriginalDestinationOption FullAddress
+
+// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
type IPPacketInfo struct {
@@ -1198,6 +1343,12 @@ type UDPStats struct {
// PacketSendErrors is the number of datagrams failed to be sent.
PacketSendErrors *StatCounter
+
+ // ChecksumErrors is the number of datagrams dropped due to bad checksums.
+ ChecksumErrors *StatCounter
+
+ // InvalidSourceAddress is the number of datagrams dropped due to an invalid source address.
+ InvalidSourceAddress *StatCounter
}
// Stats holds statistics about the networking stack.
@@ -1241,6 +1392,9 @@ type ReceiveErrors struct {
// ClosedReceiver is the number of received packets dropped because
// of receiving endpoint state being closed.
ClosedReceiver StatCounter
+
+ // ChecksumErrors is the number of packets dropped due to bad checksums.
+ ChecksumErrors StatCounter
}
// SendErrors collects packet send errors within the transport layer for
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index 2f98a996f..f32d58091 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -13,14 +13,14 @@
// limitations under the License.
// +build go1.9
-// +build !go1.15
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
package tcpip
import (
- _ "time" // Used with go:linkname.
+ "time" // Used with go:linkname.
_ "unsafe" // Required for go:linkname.
)
@@ -45,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 {
_, _, mono := now()
return mono
}
+
+// AfterFunc implements Clock.AfterFunc.
+func (*StdClock) AfterFunc(d time.Duration, f func()) Timer {
+ return &stdTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+type stdTimer struct {
+ t *time.Timer
+}
+
+var _ Timer = (*stdTimer)(nil)
+
+// Stop implements Timer.Stop.
+func (st *stdTimer) Stop() bool {
+ return st.t.Stop()
+}
+
+// Reset implements Timer.Reset.
+func (st *stdTimer) Reset(d time.Duration) {
+ st.t.Reset(d)
+}
+
+// NewStdTimer returns a Timer implemented with the time package.
+func NewStdTimer(t *time.Timer) Timer {
+ return &stdTimer{t: t}
+}
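
The additions above give StdClock a time.AfterFunc-backed Timer, so code written against the Clock interface can schedule callbacks without touching the time package directly. A minimal sketch of the new surface:

	var clock tcpip.StdClock
	timer := clock.AfterFunc(time.Second, func() {
		// Runs in its own goroutine one second later.
	})
	if timer.Stop() {
		// Stopped before firing; rearm for later.
		timer.Reset(2 * time.Second)
	}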
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
index 59f3b391f..f1dd7c310 100644
--- a/pkg/tcpip/timer.go
+++ b/pkg/tcpip/timer.go
@@ -15,54 +15,54 @@
package tcpip
import (
- "sync"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
-// cancellableTimerInstance is a specific instance of CancellableTimer.
+// jobInstance is a specific instance of Job.
//
-// Different instances are created each time CancellableTimer is Reset so each
-// timer has its own earlyReturn signal. This is to address a bug when a
-// CancellableTimer is stopped and reset in quick succession resulting in a
-// timer instance's earlyReturn signal being affected or seen by another timer
-// instance.
+// Different instances are created each time Job is scheduled so each timer has
+// its own earlyReturn signal. This is to address a bug when a Job is stopped
+// and reset in quick succession resulting in a timer instance's earlyReturn
+// signal being affected or seen by another timer instance.
//
// Consider the following scenario where timer instances share a common
// earlyReturn signal (T1 creates, cancels and schedules a Job under a
// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
// (B), third (C), and fourth (D) instance of the timer firing, respectively):
// T1: Obtain L
-// T1: Create a new CancellableTimer w/ lock L (create instance A)
+// T1: Create a new Job w/ lock L (create instance A)
// T2: instance A fires, blocked trying to obtain L.
// T1: Attempt to stop instance A (set earlyReturn = true)
-// T1: Reset timer (create instance B)
+// T1: Schedule timer (create instance B)
// T3: instance B fires, blocked trying to obtain L.
// T1: Attempt to stop instance B (set earlyReturn = true)
-// T1: Reset timer (create instance C)
+// T1: Schedule timer (create instance C)
// T4: instance C fires, blocked trying to obtain L.
// T1: Attempt to stop instance C (set earlyReturn = true)
-// T1: Reset timer (create instance D)
+// T1: Schedule timer (create instance D)
// T5: instance D fires, blocked trying to obtain L.
// T1: Release L
//
-// Now that T1 has released L, any of the 4 timer instances can take L and check
-// earlyReturn. If the timers simply check earlyReturn and then do nothing
-// further, then instance D will never early return even though it was not
-// requested to stop. If the timers reset earlyReturn before early returning,
-// then all but one of the timers will do work when only one was expected to.
-// If CancellableTimer resets earlyReturn when resetting, then all the timers
+// Now that T1 has released L, any of the 4 timer instances can take L and
+// check earlyReturn. If the timers simply check earlyReturn and then do
+// nothing further, then instance D will never early return even though it was
+// not requested to stop. If the timers reset earlyReturn before early
+// returning, then all but one of the timers will do work when only one was
+// expected to. If Job resets earlyReturn when resetting, then all the timers
// will fire (again, when only one was expected to).
//
// To address the above concerns the simplest solution was to give each timer
// its own earlyReturn signal.
-type cancellableTimerInstance struct {
- timer *time.Timer
+type jobInstance struct {
+ timer Timer
// Used to inform the timer to early return when it gets stopped while the
// lock the timer tries to obtain when fired is held (T1 is a goroutine that
// tries to cancel the timer and T2 is the goroutine that handles the timer
// firing):
- // T1: Obtain the lock, then call StopLocked()
+ // T1: Obtain the lock, then call Cancel()
// T2: timer fires, and gets blocked on obtaining the lock
// T1: Releases lock
// T2: Obtains lock does unintended work
@@ -73,27 +73,33 @@ type cancellableTimerInstance struct {
earlyReturn *bool
}
-// stop stops the timer instance t from firing if it hasn't fired already. If it
+// stop stops the job instance j from firing if it hasn't fired already. If it
// has fired and is blocked at obtaining the lock, earlyReturn will be set to
// true so that it will early return when it obtains the lock.
-func (t *cancellableTimerInstance) stop() {
- if t.timer != nil {
- t.timer.Stop()
- *t.earlyReturn = true
+func (j *jobInstance) stop() {
+ if j.timer != nil {
+ j.timer.Stop()
+ *j.earlyReturn = true
}
}
-// CancellableTimer is a timer that does some work and can be safely cancelled
-// when it fires at the same time some "related work" is being done.
+// Job represents some work that can be scheduled for execution. The work can
+// be safely cancelled when it fires at the same time some "related work" is
+// being done.
//
// The term "related work" is defined as some work that needs to be done while
// holding some lock that the timer must also hold while doing some work.
//
-// Note, it is not safe to copy a CancellableTimer as its timer instance creates
-// a closure over the address of the CancellableTimer.
-type CancellableTimer struct {
+// Note, it is not safe to copy a Job as its timer instance creates
+// a closure over the address of the Job.
+type Job struct {
+ _ sync.NoCopy
+
+ // The clock used to schedule the backing timer
+ clock Clock
+
// The active instance of a cancellable timer.
- instance cancellableTimerInstance
+ instance jobInstance
// locker is the lock taken by the timer immediately after it fires and must
// be held when attempting to stop the timer.
@@ -110,75 +116,91 @@ type CancellableTimer struct {
fn func()
}
-// StopLocked prevents the Timer from firing if it has not fired already.
+// Cancel prevents the Job from executing if it has not executed already.
//
-// If the timer is blocked on obtaining the t.locker lock when StopLocked is
-// called, it will early return instead of calling t.fn.
+// Cancel requires appropriate locking to be in place for any resources managed
+// by the Job. If the Job is blocked on obtaining the lock when Cancel is
+// called, it will early return.
//
// Note, j will be modified.
//
-// t.locker MUST be locked.
-func (t *CancellableTimer) StopLocked() {
- t.instance.stop()
+// j.locker MUST be locked.
+func (j *Job) Cancel() {
+ j.instance.stop()
// Nothing to do with the stopped instance anymore.
- t.instance = cancellableTimerInstance{}
+ j.instance = jobInstance{}
}
-// Reset changes the timer to expire after duration d.
+// Schedule schedules the Job for execution after duration d. This can be
+// called on cancelled or completed Jobs to schedule them again.
//
-// Note, t will be modified.
+// Schedule should be invoked only on unscheduled, cancelled, or completed
+// Jobs. To be safe, callers should always call Cancel before calling Schedule.
//
-// Reset should only be called on stopped or expired timers. To be safe, callers
-// should always call StopLocked before calling Reset.
-func (t *CancellableTimer) Reset(d time.Duration) {
+// Note, j will be modified.
+func (j *Job) Schedule(d time.Duration) {
// Create a new instance.
earlyReturn := false
// Capture the locker so that updating the timer does not cause a data race
// when a timer fires and tries to obtain the lock (read the timer's locker).
- locker := t.locker
- t.instance = cancellableTimerInstance{
- timer: time.AfterFunc(d, func() {
+ locker := j.locker
+ j.instance = jobInstance{
+ timer: j.clock.AfterFunc(d, func() {
locker.Lock()
defer locker.Unlock()
if earlyReturn {
// If we reach this point, it means that the timer fired while another
- // goroutine called StopLocked while it had the lock. Simply return
- // here and do nothing further.
+ // goroutine called Cancel while it had the lock. Simply return here
+ // and do nothing further.
earlyReturn = false
return
}
- t.fn()
+ j.fn()
}),
earlyReturn: &earlyReturn,
}
}
-// Lock is a no-op used by the copylocks checker from go vet.
-//
-// See CancellableTimer for details about why it shouldn't be copied.
-//
-// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more
-// details about the copylocks checker.
-func (*CancellableTimer) Lock() {}
-
-// Unlock is a no-op used by the copylocks checker from go vet.
-//
-// See CancellableTimer for details about why it shouldn't be copied.
-//
-// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more
-// details about the copylocks checker.
-func (*CancellableTimer) Unlock() {}
-
-// NewCancellableTimer returns an unscheduled CancellableTimer with the given
-// locker and fn.
-//
-// fn MUST NOT attempt to lock locker.
-//
-// Callers must call Reset to schedule the timer to fire.
-func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer {
- return &CancellableTimer{locker: locker, fn: fn}
+// NewJob returns a new Job that can be used to schedule f to run in its own
+// goroutine. l will be locked before calling f, then unlocked after f returns.
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// message = "bar"
+// mu.Unlock()
+//
+// // Output: bar
+//
+// f MUST NOT attempt to lock l.
+//
+// l MUST be locked prior to calling the returned job's Cancel().
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// job.Cancel()
+// mu.Unlock()
+func NewJob(c Clock, l sync.Locker, f func()) *Job {
+ return &Job{
+ clock: c,
+ locker: l,
+ fn: f,
+ }
}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index b4940e397..a82384c49 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -28,8 +28,8 @@ const (
longDuration = 1 * time.Second
)
-func TestCancellableTimerReassignment(t *testing.T) {
- var timer tcpip.CancellableTimer
+func TestJobReschedule(t *testing.T) {
+ var clock tcpip.StdClock
var wg sync.WaitGroup
var lock sync.Mutex
@@ -43,26 +43,27 @@ func TestCancellableTimerReassignment(t *testing.T) {
// that has an active timer (even if it has been stopped as a stopped
// timer may be blocked on a lock before it can check if it has been
// stopped while another goroutine holds the same lock).
- timer = *tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
wg.Done()
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
lock.Unlock()
}()
}
wg.Wait()
}
-func TestCancellableTimerFire(t *testing.T) {
+func TestJobExecution(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
ch <- struct{}{}
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -82,17 +83,18 @@ func TestCancellableTimerFire(t *testing.T) {
func TestCancellableTimerResetFromLongDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(middleDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(middleDuration)
lock.Lock()
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -109,16 +111,17 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
}
}
-func TestCancellableTimerResetFromShortDuration(t *testing.T) {
+func TestJobRescheduleFromShortDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
// Wait for timer to fire if it wasn't correctly stopped.
@@ -128,7 +131,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
case <-time.After(middleDuration):
}
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -145,17 +148,18 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
}
}
-func TestCancellableTimerImmediatelyStop(t *testing.T) {
+func TestJobImmediatelyCancel(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
for i := 0; i < 1000; i++ {
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
}
@@ -167,25 +171,26 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) {
}
}
-func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
+func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
for i := 0; i < 10; i++ {
- timer.Reset(middleDuration)
+ job.Schedule(middleDuration)
lock.Lock()
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration * 2)
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
}
@@ -201,17 +206,18 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration)
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
@@ -230,18 +236,19 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
}
}
-func TestManyCancellableTimerResetUnderLock(t *testing.T) {
+func TestManyJobReschedulesUnderLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 9ce625c17..7e5c79776 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index b1d820372..4612be4e7 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -110,7 +111,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -140,11 +141,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -348,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+ }
return nil
}
@@ -450,7 +450,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: data.ToVectorisedView(),
TransportHeader: buffer.View(icmpv4),
@@ -481,7 +481,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: dataVV,
TransportHeader: buffer.View(icmpv6),
@@ -511,6 +511,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicID := addr.NIC
localPort := uint16(0)
switch e.state {
+ case stateInitial:
case stateBound, stateConnected:
localPort = e.ID.LocalPort
if e.BindNICID == 0 {
@@ -611,14 +612,14 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
@@ -743,19 +744,19 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
- if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply {
+ h := header.ICMPv4(pkt.TransportHeader)
+ if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
- if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply {
+ h := header.ICMPv6(pkt.TransportHeader)
+ if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
@@ -789,12 +790,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
},
}
- packet.data = pkt.Data
+ // An ICMP socket's data includes the ICMP header.
+ packet.data = pkt.TransportHeader.ToVectorisedView()
+ packet.data.Append(pkt.Data)
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
@@ -805,7 +808,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
}
// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 3c47692b2..74ef6541e 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
return true
}
@@ -124,6 +124,16 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ // TODO(gvisor.dev/issue/170): Implement parsing of ICMP.
+ //
+ // Right now, the Parse() method is tied to enabled protocols passed into
+ // stack.New. This works for UDP and TCP, but we handle ICMP traffic even
+ // when netstack users don't pass ICMP as a supported protocol.
+ return false
+}
+
// NewProtocol4 returns an ICMPv4 transport protocol.
func NewProtocol4() stack.TransportProtocol {
return &protocol{ProtocolNumber4}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 23158173d..df478115d 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,6 +25,8 @@
package packet
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -43,6 +45,9 @@ type packet struct {
timestampNS int64
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ // packetInfo holds additional information like the protocol
+ // of the packet etc.
+ packetInfo tcpip.LinkPacketInfo
}
// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
@@ -71,11 +76,17 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
+ boundNIC tcpip.NICID
+
+ // lastErrorMu protects lastError.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
}
// NewEndpoint returns a new packet endpoint.
@@ -92,6 +103,17 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
sndBufSize: 32 * 1024,
}
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ ep.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ ep.rcvBufSizeMax = rs.Default
+ }
+
if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
return nil, err
}
@@ -132,13 +154,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (ep *endpoint) IPTables() (stack.IPTables, error) {
- return ep.stack.IPTables(), nil
-}
-
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.PacketEndpoint.ReadPacket.
+func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -163,11 +180,20 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
*addr = packet.senderAddr
}
+ if info != nil {
+ *info = packet.packetInfo
+ }
+
return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
}
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ return ep.ReadPacket(addr, nil)
+}
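
ReadPacket extends the plain Read with per-packet link-layer metadata. A sketch of draining one packet (the endpoint ep and the surrounding setup are assumptions):

	var addr tcpip.FullAddress
	var info tcpip.LinkPacketInfo
	view, _, err := ep.ReadPacket(&addr, &info)
	if err != nil {
		// tcpip.ErrWouldBlock when the receive list is empty.
		return
	}
	// info.Protocol carries the network protocol; info.PktType
	// distinguishes tcpip.PacketHost from tcpip.PacketOutgoing.
	_ = view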
+
func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- // TODO(b/129292371): Implement.
+ // TODO(gvisor.dev/issue/173): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
}
@@ -220,12 +246,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound {
- return tcpip.ErrAlreadyBound
+ if ep.bound && ep.boundNIC == addr.NIC {
+ // If the NIC being bound is the same then just return success.
+ return nil
}
// Unregister endpoint with all the nics.
ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.bound = false
// Bind endpoint to receive packets from specific interface.
if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
@@ -233,6 +261,7 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
ep.bound = true
+ ep.boundNIC = addr.NIC
return nil
}
@@ -269,7 +298,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
@@ -279,11 +314,63 @@ func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := ep.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ ep.mu.Lock()
+ ep.sndBufSizeMax = v
+ ep.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := ep.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ ep.rcvMu.Lock()
+ ep.rcvBufSizeMax = v
+ ep.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
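
Note the clamping above: buffer sizes outside the stack-wide [Min, Max] range are silently pinned rather than rejected. Sketch (ep and the configured limits are assumptions):

	// With stack limits of, say, Min = 4 KiB and Max = 4 MiB:
	_ = ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1<<40) // clamped to Max
	v, _ := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
	// v == 4 MiB, not 1<<40.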
+
+func (ep *endpoint) takeLastError() *tcpip.Error {
+ ep.lastErrorMu.Lock()
+ defer ep.lastErrorMu.Unlock()
+
+ err := ep.lastError
+ ep.lastError = nil
+ return err
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return ep.takeLastError()
+ }
return tcpip.ErrNotSupported
}
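
The new ErrorOption path gives packet endpoints the usual SO_ERROR fetch-and-clear contract: a pending error is reported once, then reset. Sketch (ep is an assumed endpoint):

	// First call returns and clears any asynchronously recorded error.
	if lastErr := ep.GetSockOpt(tcpip.ErrorOption{}); lastErr != nil {
		// Handle the deferred socket error.
	}
	// A second call reports nothing:
	// ep.GetSockOpt(tcpip.ErrorOption{}) == nil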
@@ -294,11 +381,36 @@ func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() {
+ p := ep.rcvList.Front()
+ v = p.data.Size()
+ }
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSizeMax
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -320,7 +432,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
// Push new packet into receive list and increment the buffer size.
var packet packet
- // TODO(b/129292371): Return network protocol.
+ // TODO(gvisor.dev/issue/173): Return network protocol.
if len(pkt.LinkHeader) > 0 {
// Get info directly from the ethernet header.
hdr := header.Ethernet(pkt.LinkHeader)
@@ -328,40 +440,66 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
NIC: nicID,
Addr: tcpip.Address(hdr.SourceAddress()),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(localAddr),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = pkt.Data
+ switch pkt.PktType {
+ case tcpip.PacketHost:
+ packet.data = pkt.Data
+ case tcpip.PacketOutgoing:
+ // Strip the link header from pkt.Header.
+ pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):])
+ combinedVV := pkt.Header.View().ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ default:
+ panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ }
+
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
var linkHeader buffer.View
- if len(pkt.LinkHeader) == 0 {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
+ var combinedVV buffer.VectorisedView
+ if pkt.PktType != tcpip.PacketOutgoing {
+ if len(pkt.LinkHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
}
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
- } else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ combinedVV = linkHeader.ToVectorisedView()
+ }
+ if pkt.PktType == tcpip.PacketOutgoing {
+ // For outgoing packets, the link, network, and transport
+ // headers are normally in the pkt.Header field, unless a
+ // raw socket is in use, in which case pkt.Header could be
+ // nil.
+ combinedVV.AppendView(pkt.Header.View())
}
- combinedVV := linkHeader.ToVectorisedView()
combinedVV.Append(pkt.Data)
packet.data = combinedVV
}
- packet.timestampNS = ep.stack.NowNanoseconds()
+ packet.timestampNS = ep.stack.Clock().NowNanoseconds()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 9b88f17e4..e2fa96d17 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() {
panic(*err)
}
}
+
+// saveLastError is invoked by stateify.
+func (ep *endpoint) saveLastError() string {
+ if ep.lastError == nil {
+ return ""
+ }
+
+ return ep.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (ep *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ ep.lastError = tcpip.StringToError(s)
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index eee754a5a..f85a68554 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,6 +26,8 @@
package raw
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -61,21 +63,23 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
+ hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
rcvList rawPacketList
- rcvBufSizeMax int `state:".(int)"`
rcvBufSize int
+ rcvBufSizeMax int `state:".(int)"`
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- connected bool
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ connected bool
+ bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
route stack.Route `state:"manual"`
@@ -91,7 +95,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber {
+ if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
return nil, tcpip.ErrUnknownProtocol
}
@@ -103,8 +107,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
},
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
associated: associated,
+ hdrIncluded: !associated,
+ }
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
}
// Unassociated endpoints are write-only and users call Write() with IP
@@ -166,17 +182,8 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read implements tcpip.Endpoint.Read.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if !e.associated {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
- }
-
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -206,6 +213,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // We can create, but not write to, unassociated IPv6 endpoints.
+ if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -249,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If this endpoint includes IP headers and the caller provided a nonzero
// destination address, route using that address.
- if !e.associated {
+ if e.hdrIncluded {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
@@ -310,12 +322,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrNoRoute
}
- // We don't support IPv6 yet, so this has to be an IPv4 address.
- if len(opts.To.Addr) != header.IPv4AddressSize {
- e.mu.RUnlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
// Find the route to the destination. If BindAddress is 0,
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
@@ -345,28 +351,21 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- switch e.NetProto {
- case header.IPv4ProtocolNumber:
- if !e.associated {
- if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- }); err != nil {
- return 0, nil, err
- }
- break
+ if e.hdrIncluded {
+ if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ }); err != nil {
+ return 0, nil, err
}
-
+ } else {
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: buffer.View(payloadBytes).ToVectorisedView(),
Owner: e.owner,
}); err != nil {
return 0, nil, err
}
-
- default:
- return 0, nil, tcpip.ErrUnknownProtocol
}
return int64(len(payloadBytes)), nil, nil
@@ -391,11 +390,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // We don't support IPv6 yet.
- if len(addr.Addr) != header.IPv4AddressSize {
- return tcpip.ErrInvalidEndpointState
- }
-
nic := addr.NIC
if e.bound {
if e.BindNICID == 0 {
@@ -461,14 +455,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // Callers must provide an IPv4 address or no network address (for
- // binding to a NIC, but not an address).
- if len(addr.Addr) != 0 && len(addr.Addr) != 4 {
- return tcpip.ErrInvalidEndpointState
- }
-
// If a local address was specified, verify that it's valid.
- if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -518,17 +506,69 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ e.hdrIncluded = v
+ e.mu.Unlock()
+ return nil
+ }
return tcpip.ErrUnknownProtocolOption
}
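
IPHdrIncludedOption gives associated raw sockets the IP_HDRINCL behavior: once enabled, each Write payload must begin with a complete IP header that the stack forwards as-is. Sketch (ep is an assumed raw tcpip.Endpoint):

	// Equivalent of setsockopt(fd, IPPROTO_IP, IP_HDRINCL, &one, len).
	if err := ep.SetSockOptBool(tcpip.IPHdrIncludedOption, true); err != nil {
		// Handle the *tcpip.Error.
	}
	// Writes must now supply a full, valid IPv4 header; the endpoint
	// will not prepend one (see the e.hdrIncluded checks in the write
	// path above).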
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ e.rcvMu.Lock()
+ e.rcvBufSizeMax = v
+ e.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -548,6 +588,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.KeepaliveEnabledOption:
return false, nil
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ v := e.hdrIncluded
+ e.mu.Unlock()
+ return v, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -568,7 +614,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -584,11 +630,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
- // Drop the packet if our buffer is currently full.
- if e.rcvClosed {
+ // Drop the packet if our buffer is currently full or if this is an
+ // unassociated endpoint (i.e., an endpoint created with IPPROTO_RAW).
+ // Such endpoints are send-only.
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
@@ -632,15 +685,25 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) {
},
}
- networkHeader := append(buffer.View(nil), pkt.NetworkHeader...)
- combinedVV := networkHeader.ToVectorisedView()
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ var combinedVV buffer.VectorisedView
+ if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
+ headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader))
+ headers = append(headers, pkt.NetworkHeader...)
+ headers = append(headers, pkt.TransportHeader...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView()
+ }
combinedVV.Append(pkt.Data)
packet.data = combinedVV
- packet.timestampNS = e.stack.NowNanoseconds()
+ packet.timestampNS = e.stack.Clock().NowNanoseconds()
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
-
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index f38eb6833..e860ee484 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -49,6 +49,7 @@ go_library(
"segment_heap.go",
"segment_queue.go",
"segment_state.go",
+ "segment_unsafe.go",
"snd.go",
"snd_state.go",
"tcp_endpoint_list.go",
@@ -76,7 +77,7 @@ go_library(
)
go_test(
- name = "tcp_test",
+ name = "tcp_x_test",
size = "medium",
srcs = [
"dual_stack_test.go",
@@ -86,10 +87,7 @@ go_test(
"tcp_test.go",
"tcp_timestamp_test.go",
],
- # FIXME(b/68809571)
- tags = [
- "flaky",
- ],
+ shard_count = 10,
deps = [
":tcp",
"//pkg/sync",
@@ -119,3 +117,11 @@ go_test(
"//pkg/tcpip/seqnum",
],
)
+
+go_test(
+ name = "tcp_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ library = ":tcp",
+ deps = ["//pkg/sleep"],
+)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index e6a23c978..6e00e5526 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -198,9 +198,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
}
// createConnectingEndpoint creates a new endpoint in a connecting state, with
-// the connection parameters given by the arguments. The endpoint is returned
-// with n.mu held.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
+// the connection parameters given by the arguments.
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -221,32 +220,12 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
- // Create sender and receiver.
- //
- // The receiver at least temporarily has a zero receive window scale,
- // but the caller may change it (before starting the protocol loop).
- n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
- n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize()))
// Bootstrap the auto tuning algorithm. Starting at zero will result in
// a large step function on the first window adjustment causing the
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
- // Lock the endpoint before registering to ensure that no out of
- // band changes are possible due to incoming packets etc till
- // the endpoint is done initializing.
- n.mu.Lock()
-
- // Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
- n.mu.Unlock()
- n.Close()
- return nil, err
- }
-
- n.isRegistered = true
-
- return n, nil
+ return n
}
// createEndpointAndPerformHandshake creates a new endpoint in connected state
@@ -257,10 +236,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
- if err != nil {
- return nil, err
- }
+ ep := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+
+ // Lock the endpoint before registering to ensure that no out of
+ // band changes are possible due to incoming packets etc till
+ // the endpoint is done initializing.
+ ep.mu.Lock()
ep.owner = owner
// listenEP is nil when listenContext is used by tcp.Forwarder.
@@ -268,18 +249,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if l.listenEP != nil {
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
+
l.listenEP.mu.Unlock()
// Ensure we release any registrations done by the newly
// created endpoint.
ep.mu.Unlock()
ep.Close()
- // Wake up any waiters. This is strictly not required normally
- // as a socket that was never accepted can't really have any
- // registered waiters except when stack.Wait() is called which
- // waits for all registered endpoints to stop and expects an
- // EventHUp.
- ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
@@ -288,21 +264,44 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// to the newly created endpoint.
l.listenEP.propagateInheritableOptionsLocked(ep)
+ if !ep.reserveTupleLocked() {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ l.listenEP.mu.Unlock()
+ }
+
+ return nil, tcpip.ErrConnectionAborted
+ }
+
deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
+ // Register new endpoint so that packets are routed to it.
+ if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ }
+
+ ep.drainClosingSegmentQueue()
+
+ return nil, err
+ }
+
+ ep.isRegistered = true
+
// Perform the 3-way handshake.
- h := newPassiveHandshake(ep, ep.rcv.rcvWnd, isn, irs, opts, deferAccept)
+ h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.mu.Unlock()
ep.Close()
- // Wake up any waiters. This is strictly not required normally
- // as a socket that was never accepted can't really have any
- // registered waiters except when stack.Wait() is called which
- // waits for all registered endpoints to stop and expects an
- // EventHUp.
- ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ ep.notifyAborted()
if l.listenEP != nil {
l.removePendingEndpoint(ep)
@@ -378,6 +377,43 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
// Precondition: e.mu and n.mu must be held.
func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
+ n.portFlags = e.portFlags
+ n.boundBindToDevice = e.boundBindToDevice
+ n.boundPortFlags = e.boundPortFlags
+}
+
+// reserveTupleLocked reserves an accepted endpoint's tuple.
+//
+// Preconditions:
+// * propagateInheritableOptionsLocked has been called.
+// * e.mu is held.
+func (e *endpoint) reserveTupleLocked() bool {
+ dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort}
+ if !e.stack.ReserveTuple(
+ e.effectiveNetProtos,
+ ProtocolNumber,
+ e.ID.LocalAddress,
+ e.ID.LocalPort,
+ e.boundPortFlags,
+ e.boundBindToDevice,
+ dest,
+ ) {
+ return false
+ }
+
+ e.isPortReserved = true
+ e.boundDest = dest
+ return true
+}
+
+// notifyAborted wakes up any waiters on registered, but not accepted
+// endpoints.
+//
+// This is strictly not required normally as a socket that was never accepted
+// can't really have any registered waiters except when stack.Wait() is called
+// which waits for all registered endpoints to stop and expects an EventHUp.
+func (e *endpoint) notifyAborted() {
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
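
The notifyAborted helper above collapses the old inline wakeups into one place: it ORs several readiness bits into a single mask so one Notify call wakes readers, writers, and closers alike. A minimal, self-contained sketch of the same bitmask technique (hypothetical names, not gVisor's waiter package):

package main

import "fmt"

type EventMask uint64

const (
	EventIn  EventMask = 1 << iota // readable
	EventOut                       // writable
	EventErr                       // error condition
	EventHUp                       // peer hang-up
)

// notify wakes every watcher whose subscription overlaps mask.
func notify(subs map[string]EventMask, mask EventMask) {
	for name, want := range subs {
		if want&mask != 0 {
			fmt.Println("waking", name)
		}
	}
}

func main() {
	subs := map[string]EventMask{
		"reader": EventIn,
		"closer": EventHUp | EventErr,
	}
	// An aborted endpoint reports every event at once, as in the diff.
	notify(subs, EventHUp|EventErr|EventIn|EventOut)
}
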
// handleSynSegment is called in its own goroutine once the listening endpoint
@@ -534,6 +570,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
+ iss := s.ackNumber - 1
+ irs := s.sequenceNumber - 1
+
// Since SYN cookies are in use this is potentially an ACK to a
// SYN-ACK we sent but don't have a half open connection state
// as cookies are being used to protect against a potential SYN
@@ -544,7 +583,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// when under a potential syn flood attack.
//
// Validate the cookie.
- data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
+ data, ok := ctx.isCookieValid(s.id, iss, irs)
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -569,16 +608,34 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{})
- if err != nil {
+ n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+
+ n.mu.Lock()
+
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ e.propagateInheritableOptionsLocked(n)
+
+ if !n.reserveTupleLocked() {
+ n.mu.Unlock()
+ n.Close()
+
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
return
}
- // Propagate any inheritable options from the listening endpoint
- // to the newly created endpoint.
- e.propagateInheritableOptionsLocked(n)
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+ n.mu.Unlock()
+ n.Close()
+
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ n.isRegistered = true
// clear the tsOffset for the newly created
// endpoint as the Timestamp was already
@@ -587,10 +644,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.tsOffset = 0
// Switch state to connected.
- // We do not use transitionToStateEstablishedLocked here as there is
- // no handshake state available when doing a SYN cookie based accept.
n.isConnectNotified = true
- n.setEndpointState(StateEstablished)
+ n.transitionToStateEstablishedLocked(&handshake{
+ ep: n,
+ iss: iss,
+ ackNum: irs + 1,
+ rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+ sndWnd: s.window,
+ rcvWndScale: e.rcvWndScaleForHandshake(),
+ sndWndScale: rcvdSynOptions.WS,
+ mss: rcvdSynOptions.MSS,
+ })
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
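
Under SYN cookies there is no half-open state to consult, so the handshake parameters are rebuilt from the completing ACK: the ACK acknowledges our SYN-ACK, making our ISS equal to s.ackNumber - 1, and the peer's SYN consumed one sequence number, making its IRS equal to s.sequenceNumber - 1. A sketch of that arithmetic on plain uint32 sequence numbers (hypothetical helper):

package main

import "fmt"

// recoverISNs derives both initial sequence numbers from the ACK that
// completes a cookie-validated handshake. Sequence space is mod 2^32,
// which uint32 subtraction provides for free.
func recoverISNs(ackNum, seqNum uint32) (iss, irs uint32) {
	return ackNum - 1, seqNum - 1
}

func main() {
	iss, irs := recoverISNs(1000001, 42)
	fmt.Println(iss, irs) // 1000000 41
}
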
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index a7e088d4e..6e5e55b6f 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
}
// Wait for notification.
@@ -509,9 +512,7 @@ func (h *handshake) execute() *tcpip.Error {
// Initialize the resend timer.
resendWaker := sleep.Waker{}
timeOut := time.Duration(time.Second)
- rt := time.AfterFunc(timeOut, func() {
- resendWaker.Assert()
- })
+ rt := time.AfterFunc(timeOut, resendWaker.Assert)
defer rt.Stop()
// Set up the wakers.
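
The timer change just above relies on time.AfterFunc accepting any func(), so the bound method value resendWaker.Assert can be passed directly instead of wrapping it in a closure. The idiom in isolation:

package main

import (
	"fmt"
	"sync"
	"time"
)

type waker struct{ once sync.Once }

func (w *waker) Assert() { w.once.Do(func() { fmt.Println("asserted") }) }

func main() {
	var w waker
	// Method value: behaves like func() { w.Assert() } without the
	// explicit closure.
	t := time.AfterFunc(10*time.Millisecond, w.Assert)
	defer t.Stop()
	time.Sleep(20 * time.Millisecond)
}
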
@@ -618,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
@@ -833,13 +837,13 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac
return sendTCPBatch(r, tf, data, gso, owner)
}
- pkt := stack.PacketBuffer{
+ pkt := &stack.PacketBuffer{
Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
Data: data,
Hash: tf.txHash,
Owner: owner,
}
- buildTCPHdr(r, tf, &pkt, gso)
+ buildTCPHdr(r, tf, pkt, gso)
if tf.ttl == 0 {
tf.ttl = r.DefaultTTL()
@@ -995,24 +999,22 @@ func (e *endpoint) completeWorkerLocked() {
// transitionToStateEstablishedLocked transitions a given endpoint
// to an established state using the handshake parameters provided.
-// It also initializes sender/receiver if required.
+// It also initializes sender/receiver.
func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
- if e.snd == nil {
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
- }
- if e.rcv == nil {
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
- e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
- // Bootstrap the auto tuning algorithm. Starting at zero will
- // result in a really large receive window after the first auto
- // tuning adjustment.
- e.rcvAutoParams.prevCopied = int(h.rcvWnd)
- e.rcvListMu.Unlock()
- }
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ rcvBufSize := seqnum.Size(e.receiveBufferSize())
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ e.rcvAutoParams.prevCopied = int(h.rcvWnd)
+ e.rcvListMu.Unlock()
+
e.setEndpointState(StateEstablished)
}
@@ -1022,14 +1024,19 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// delivered to this endpoint from the demuxer when the endpoint
// is transitioned to StateClose.
func (e *endpoint) transitionToStateCloseLocked() {
- if e.EndpointState() == StateClose {
+ s := e.EndpointState()
+ if s == StateClose {
return
}
+
+ if s.connected() {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
+ e.stack.Stats().TCP.EstablishedClosed.Increment()
+ }
+
// Mark the endpoint as fully closed for reads/writes.
e.cleanupLocked()
e.setEndpointState(StateClose)
- e.stack.Stats().TCP.CurrentConnected.Decrement()
- e.stack.Stats().TCP.EstablishedClosed.Increment()
}
// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
@@ -1052,8 +1059,8 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
panic("current endpoint not removed from demuxer, enqueing segments to itself")
}
- if ep.(*endpoint).enqueueSegment(s) {
- ep.(*endpoint).newSegmentWaker.Assert()
+ if ep := ep.(*endpoint); ep.enqueueSegment(s) {
+ ep.newSegmentWaker.Assert()
}
}
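
The rewrite above performs the type assertion once and scopes the concrete value to the if statement, rather than repeating ep.(*endpoint) on every use. The same idiom standalone:

package main

import "fmt"

type transportEndpoint interface{ enqueue(string) bool }

type tcpEndpoint struct{ q []string }

func (e *tcpEndpoint) enqueue(s string) bool { e.q = append(e.q, s); return true }
func (e *tcpEndpoint) wake()                 { fmt.Println("woken:", e.q) }

func main() {
	var ep transportEndpoint = &tcpEndpoint{}
	// One assertion; the shadowed ep is the concrete *tcpEndpoint inside
	// the block, so wake (not part of the interface) is visible.
	if ep := ep.(*tcpEndpoint); ep.enqueue("segment") {
		ep.wake()
	}
}
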
@@ -1122,7 +1129,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
- if e.EndpointState() == StateClose || e.EndpointState() == StateError {
+ if e.EndpointState().closed() {
return nil
}
s := e.segmentQueue.dequeue()
@@ -1347,6 +1354,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.setEndpointState(StateError)
e.HardError = err
+ e.workerCleanup = true
// Lock released below.
epilogue()
return err
@@ -1441,9 +1449,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
if e.EndpointState() == StateFinWait2 && e.closed {
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
- closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
- closeWaker.Assert()
- })
+ closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
}
@@ -1461,7 +1467,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
return err
}
}
- if e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ if !e.EndpointState().closed() {
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
@@ -1517,6 +1523,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
cleanupOnError := func(err *tcpip.Error) {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
e.workerCleanup = true
if err != nil {
e.resetConnectionLocked(err)
@@ -1526,7 +1533,12 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
loop:
- for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ for {
+ switch e.EndpointState() {
+ case StateTimeWait, StateClose, StateError:
+ break loop
+ }
+
e.mu.Unlock()
v, _ := s.Fetch(true)
e.mu.Lock()
@@ -1569,11 +1581,14 @@ loop:
reuseTW = e.doTimeWait()
}
- // Mark endpoint as closed.
- if e.EndpointState() != StateError {
- e.transitionToStateCloseLocked()
+ // Handle any StateError transition from StateTimeWait.
+ if e.EndpointState() == StateError {
+ cleanupOnError(nil)
+ return nil
}
+ e.transitionToStateCloseLocked()
+
// Lock released below.
epilogue()
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 6062ca916..98aecab9e 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -15,6 +15,8 @@
package tcp
import (
+ "encoding/binary"
+
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
@@ -66,89 +68,68 @@ func (q *epQueue) empty() bool {
// processor is responsible for processing packets queued to a tcp endpoint.
type processor struct {
epQ epQueue
+ sleeper sleep.Sleeper
newEndpointWaker sleep.Waker
closeWaker sleep.Waker
- id int
- wg sync.WaitGroup
-}
-
-func newProcessor(id int) *processor {
- p := &processor{
- id: id,
- }
- p.wg.Add(1)
- go p.handleSegments()
- return p
}
func (p *processor) close() {
p.closeWaker.Assert()
}
-func (p *processor) wait() {
- p.wg.Wait()
-}
-
func (p *processor) queueEndpoint(ep *endpoint) {
// Queue an endpoint for processing by the processor goroutine.
p.epQ.enqueue(ep)
p.newEndpointWaker.Assert()
}
-func (p *processor) handleSegments() {
- const newEndpointWaker = 1
- const closeWaker = 2
- s := sleep.Sleeper{}
- s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
- s.AddWaker(&p.closeWaker, closeWaker)
- defer s.Done()
+const (
+ newEndpointWaker = 1
+ closeWaker = 2
+)
+
+func (p *processor) start(wg *sync.WaitGroup) {
+ defer wg.Done()
+ defer p.sleeper.Done()
+
for {
- id, ok := s.Fetch(true)
- if ok && id == closeWaker {
- p.wg.Done()
- return
+ if id, _ := p.sleeper.Fetch(true); id == closeWaker {
+ break
}
- for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
+ for {
+ ep := p.epQ.dequeue()
+ if ep == nil {
+ break
+ }
if ep.segmentQueue.empty() {
continue
}
- // If socket has transitioned out of connected state
- // then just let the worker handle the packet.
+			// If the socket has transitioned out of connected state then just
+			// let the worker handle the packet.
//
- // NOTE: We read this outside of e.mu lock which means
- // that by the time we get to handleSegments the
- // endpoint may not be in ESTABLISHED. But this should
- // be fine as all normal shutdown states are handled by
- // handleSegments and if the endpoint moves to a
- // CLOSED/ERROR state then handleSegments is a noop.
- if ep.EndpointState() != StateEstablished {
- ep.newSegmentWaker.Assert()
- continue
- }
-
- if !ep.mu.TryLock() {
- ep.newSegmentWaker.Assert()
- continue
- }
- // If the endpoint is in a connected state then we do
- // direct delivery to ensure low latency and avoid
- // scheduler interactions.
- if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
- // Send any active resets if required.
- if err != nil {
+ // NOTE: We read this outside of e.mu lock which means that by the time
+ // we get to handleSegments the endpoint may not be in ESTABLISHED. But
+ // this should be fine as all normal shutdown states are handled by
+ // handleSegments and if the endpoint moves to a CLOSED/ERROR state
+ // then handleSegments is a noop.
+ if ep.EndpointState() == StateEstablished && ep.mu.TryLock() {
+ // If the endpoint is in a connected state then we do direct delivery
+ // to ensure low latency and avoid scheduler interactions.
+ switch err := ep.handleSegments(true /* fastPath */); {
+ case err != nil:
+ // Send any active resets if required.
ep.resetConnectionLocked(err)
+ fallthrough
+ case ep.EndpointState() == StateClose:
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ case !ep.segmentQueue.empty():
+ p.epQ.enqueue(ep)
}
- ep.notifyProtocolGoroutine(notifyTickleWorker)
ep.mu.Unlock()
- continue
- }
-
- if !ep.segmentQueue.empty() {
- p.epQ.enqueue(ep)
+ } else {
+ ep.newSegmentWaker.Assert()
}
-
- ep.mu.Unlock()
}
}
}
@@ -159,34 +140,39 @@ func (p *processor) handleSegments() {
// hash of the endpoint id to ensure that delivery for the same endpoint happens
// in-order.
type dispatcher struct {
- processors []*processor
+ processors []processor
seed uint32
-}
-
-func newDispatcher(nProcessors int) *dispatcher {
- processors := []*processor{}
- for i := 0; i < nProcessors; i++ {
- processors = append(processors, newProcessor(i))
- }
- return &dispatcher{
- processors: processors,
- seed: generateRandUint32(),
+ wg sync.WaitGroup
+}
+
+func (d *dispatcher) init(nProcessors int) {
+ d.close()
+ d.wait()
+ d.processors = make([]processor, nProcessors)
+ d.seed = generateRandUint32()
+ for i := range d.processors {
+ p := &d.processors[i]
+ p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ p.sleeper.AddWaker(&p.closeWaker, closeWaker)
+ d.wg.Add(1)
+ // NB: sleeper-waker registration must happen synchronously to avoid races
+ // with `close`. It's possible to pull all this logic into `start`, but
+ // that results in a heap-allocated function literal.
+ go p.start(&d.wg)
}
}
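
init now stores processors by value in one slice and shares a single WaitGroup, registering each processor's wakers before its goroutine starts so close can never observe a half-initialized sleeper. A generic rendering of the pattern, with channels standing in for gVisor's sleep package (all names hypothetical):

package main

import (
	"fmt"
	"sync"
)

type worker struct {
	quit chan struct{}
}

type pool struct {
	workers []worker // by value: one allocation for the whole pool
	wg      sync.WaitGroup
}

func (p *pool) init(n int) {
	p.workers = make([]worker, n)
	for i := range p.workers {
		w := &p.workers[i]
		// Set up synchronously, before the goroutine starts, so that
		// close/wait never race with initialization.
		w.quit = make(chan struct{})
		p.wg.Add(1)
		go w.run(&p.wg)
	}
}

func (w *worker) run(wg *sync.WaitGroup) {
	defer wg.Done()
	<-w.quit // block until asked to stop
}

func (p *pool) close() {
	for i := range p.workers {
		close(p.workers[i].quit)
	}
}

func main() {
	var p pool
	p.init(4)
	p.close()
	p.wg.Wait()
	fmt.Println("all processors stopped")
}
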
func (d *dispatcher) close() {
- for _, p := range d.processors {
- p.close()
+ for i := range d.processors {
+ d.processors[i].close()
}
}
func (d *dispatcher) wait() {
- for _, p := range d.processors {
- p.wait()
- }
+ d.wg.Wait()
}
-func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt)
if !s.parse() {
@@ -231,20 +217,18 @@ func generateRandUint32() uint32 {
if _, err := rand.Read(b); err != nil {
panic(err)
}
- return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+ return binary.LittleEndian.Uint32(b)
}
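
The hand-rolled shift-and-OR above is replaced by encoding/binary; the two forms compute the same value, so this is a pure readability change. A quick demonstration:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	b := []byte{0x78, 0x56, 0x34, 0x12}
	manual := uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
	library := binary.LittleEndian.Uint32(b)
	fmt.Printf("%#x %#x equal=%t\n", manual, library, manual == library)
	// Output: 0x12345678 0x12345678 equal=true
}
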
func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
- payload := []byte{
- byte(id.LocalPort),
- byte(id.LocalPort >> 8),
- byte(id.RemotePort),
- byte(id.RemotePort >> 8)}
+ var payload [4]byte
+ binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.LittleEndian.PutUint16(payload[2:], id.RemotePort)
h := jenkins.Sum32(d.seed)
- h.Write(payload)
+ h.Write(payload[:])
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
- return d.processors[h.Sum32()%uint32(len(d.processors))]
+ return &d.processors[h.Sum32()%uint32(len(d.processors))]
}
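
selectProcessor hashes the seeded connection 4-tuple so that every segment of a connection lands on the same processor, preserving per-connection ordering while spreading distinct connections across goroutines. A sketch of the technique using the standard library's FNV hash in place of the internal jenkins package (hypothetical names):

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// pick returns a stable processor index for a connection 4-tuple.
func pick(seed uint32, laddr, raddr string, lport, rport uint16, n int) int {
	h := fnv.New32a()
	var b [8]byte
	binary.LittleEndian.PutUint32(b[0:], seed)
	binary.LittleEndian.PutUint16(b[4:], lport)
	binary.LittleEndian.PutUint16(b[6:], rport)
	h.Write(b[:])
	h.Write([]byte(laddr))
	h.Write([]byte(raddr))
	return int(h.Sum32() % uint32(n))
}

func main() {
	const procs = 8
	a := pick(0xdeadbeef, "10.0.0.1", "10.0.0.2", 80, 43210, procs)
	b := pick(0xdeadbeef, "10.0.0.1", "10.0.0.2", 80, 43210, procs)
	fmt.Println(a, b, a == b) // same tuple, same processor
}
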
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b5ba972f1..682687ebe 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -63,7 +63,8 @@ const (
StateClosing
)
-// connected is the set of states where an endpoint is connected to a peer.
+// connected returns true when s is one of the states representing an
+// endpoint connected to a peer.
func (s EndpointState) connected() bool {
switch s {
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
@@ -73,6 +74,40 @@ func (s EndpointState) connected() bool {
}
}
+// connecting returns true when s is one of the states representing a
+// connection in progress, but not yet fully established.
+func (s EndpointState) connecting() bool {
+ switch s {
+ case StateConnecting, StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// handshake returns true when s is one of the states representing an endpoint
+// in the middle of a TCP handshake.
+func (s EndpointState) handshake() bool {
+ switch s {
+ case StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// closed returns true when s is one of the states an endpoint transitions to
+// when closed or when it encounters an error. This is distinct from a newly
+// initialized endpoint that was never connected.
+func (s EndpointState) closed() bool {
+ switch s {
+ case StateClose, StateError:
+ return true
+ default:
+ return false
+ }
+}
+
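
Grouping states behind predicate methods keeps call sites such as e.EndpointState().closed() readable and removes duplicated state enumerations. The pattern in miniature:

package main

import "fmt"

type State int

const (
	StateInitial State = iota
	StateSynSent
	StateEstablished
	StateClose
	StateError
)

// closed reports whether s is a terminal state, as distinct from a
// freshly initialized endpoint that was never connected.
func (s State) closed() bool {
	switch s {
	case StateClose, StateError:
		return true
	default:
		return false
	}
}

func main() {
	for _, s := range []State{StateSynSent, StateError} {
		fmt.Printf("state %d closed=%t\n", s, s.closed())
	}
}
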
// String implements fmt.Stringer.String.
func (s EndpointState) String() string {
switch s {
@@ -361,7 +396,8 @@ type endpoint struct {
mu sync.Mutex `state:"nosave"`
ownedByUser uint32
- // state must be read/set using the EndpointState()/setEndpointState() methods.
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
state EndpointState `state:".(EndpointState)"`
// origEndpointState is only used during a restore phase to save the
@@ -370,8 +406,8 @@ type endpoint struct {
origEndpointState EndpointState `state:"nosave"`
isPortReserved bool `state:"manual"`
- isRegistered bool
- boundNICID tcpip.NICID `state:"manual"`
+ isRegistered bool `state:"manual"`
+ boundNICID tcpip.NICID
route stack.Route `state:"manual"`
ttl uint8
v6only bool
@@ -380,10 +416,14 @@ type endpoint struct {
// disabling SO_BROADCAST, albeit as a NOOP.
broadcast bool
+ // portFlags stores the current values of port related flags.
+ portFlags ports.Flags
+
// Values used to reserve a port or register a transport endpoint
// (which ever happens first).
boundBindToDevice tcpip.NICID
boundPortFlags ports.Flags
+ boundDest tcpip.FullAddress
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -391,7 +431,7 @@ type endpoint struct {
// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
- effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
+ effectiveNetProtos []tcpip.NetworkProtocolNumber
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
@@ -427,9 +467,6 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
- // reusePort is set to true if SO_REUSEPORT is enabled.
- reusePort bool
-
// bindToDevice is set to the NIC on which to bind or disabled if 0.
bindToDevice tcpip.NICID
@@ -449,7 +486,6 @@ type endpoint struct {
// The options below aren't implemented, but we remember the user
// settings because applications expect to be able to set/query these
// options.
- reuseAddr bool
// slowAck holds the negated state of quick ack. It is stubbed out and
// does nothing.
@@ -799,7 +835,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
rcvBufSize: DefaultReceiveBufferSize,
sndBufSize: DefaultSendBufferSize,
sndMTU: int(math.MaxInt32),
- reuseAddr: true,
keepalive: keepalive{
// Linux defaults.
idle: 2 * time.Hour,
@@ -986,14 +1021,15 @@ func (e *endpoint) closeNoShutdownLocked() {
// in Listen() when trying to register.
if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
e.isPortReserved = false
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
}
// Mark endpoint as closed.
@@ -1051,16 +1087,17 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
e.isPortReserved = false
}
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
e.route.Release()
e.stack.CompleteTransportEndpointCleanup(e)
@@ -1172,14 +1209,27 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
+func (e *endpoint) takeLastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+ err := e.lastError
+ e.lastError = nil
+ return err
}
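
takeLastError centralizes the read-and-reset that GetSockOpt used to open-code under lastErrorMu, so each stored error is consumed exactly once. The shape of the pattern:

package main

import (
	"errors"
	"fmt"
	"sync"
)

type errBox struct {
	mu  sync.Mutex
	err error
}

// take returns the stored error and clears it under the lock, so any
// given error is delivered to at most one caller.
func (b *errBox) take() error {
	b.mu.Lock()
	defer b.mu.Unlock()
	err := b.err
	b.err = nil
	return err
}

func main() {
	b := &errBox{err: errors.New("no route")}
	fmt.Println(b.take()) // no route
	fmt.Println(b.take()) // <nil>
}
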
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
+ defer e.UnlockUser()
+
+ // When in SYN-SENT state, let the caller block on the receive.
+ // An application can initiate a non-blocking connect and then block
+ // on a receive. It can expect to read any data after the handshake
+ // is complete. RFC793, section 3.9, p58.
+ if e.EndpointState() == StateSynSent {
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data. Also note that a RST being received
// would cause the state to become StateError so we should allow the
@@ -1189,7 +1239,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
- e.UnlockUser()
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
@@ -1199,7 +1248,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
v, err := e.readLocked()
e.rcvListMu.Unlock()
- e.UnlockUser()
if err == tcpip.ErrClosedForReceive {
e.stats.ReadErrors.ReadClosed.Increment()
@@ -1486,12 +1534,12 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
case tcpip.ReuseAddressOption:
e.LockUser()
- e.reuseAddr = v
+ e.portFlags.TupleOnly = v
e.UnlockUser()
case tcpip.ReusePortOption:
e.LockUser()
- e.reusePort = v
+ e.portFlags.LoadBalanced = v
e.UnlockUser()
case tcpip.V6OnlyOption:
@@ -1549,6 +1597,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.UnlockUser()
e.notifyProtocolGoroutine(notifyMSSChanged)
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if attempting to set this option to
+ // anything other than path MTU discovery disabled.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1745,6 +1800,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.deferAccept = time.Duration(v)
e.UnlockUser()
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
default:
return nil
}
@@ -1795,14 +1853,14 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.ReuseAddressOption:
e.LockUser()
- v := e.reuseAddr
+ v := e.portFlags.TupleOnly
e.UnlockUser()
return v, nil
case tcpip.ReusePortOption:
e.LockUser()
- v := e.reusePort
+ v := e.portFlags.LoadBalanced
e.UnlockUser()
return v, nil
@@ -1819,6 +1877,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
+ case tcpip.MulticastLoopOption:
+ return true, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -1853,6 +1914,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
v := header.TCPDefaultMSS
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // Always return the path MTU discovery disabled setting since
+ // it's the only one supported.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1886,6 +1952,9 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.UnlockUser()
return v, nil
+ case tcpip.MulticastTTLOption:
+ return 1, nil
+
default:
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -1895,11 +1964,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
- e.lastErrorMu.Lock()
- err := e.lastError
- e.lastError = nil
- e.lastErrorMu.Unlock()
- return err
+ return e.takeLastError()
case *tcpip.BindToDeviceOption:
e.LockUser()
@@ -1952,6 +2017,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = tcpip.TCPDeferAcceptOption(e.deferAccept)
e.UnlockUser()
+ case *tcpip.OriginalDestinationOption:
+ ipt := e.stack.IPTables()
+ addr, port, err := ipt.OriginalDst(e.ID)
+ if err != nil {
+ return err
+ }
+ *o = tcpip.OriginalDestinationOption{
+ Addr: addr,
+ Port: port,
+ }
+
default:
return tcpip.ErrUnknownProtocolOption
}
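
Callers fetch the pre-NAT address (the analogue of Linux's SO_ORIGINAL_DST) by passing a pointer to the new option type. A sketch of the call site, assuming the gvisor.dev/gvisor module is on the import path and ep is any connected tcpip.Endpoint:

package tcputil

import (
	"gvisor.dev/gvisor/pkg/tcpip"
)

// originalDst asks an endpoint for the destination it was addressed to
// before any iptables REDIRECT rewrote it.
func originalDst(ep tcpip.Endpoint) (tcpip.FullAddress, *tcpip.Error) {
	var odst tcpip.OriginalDestinationOption
	if err := ep.GetSockOpt(&odst); err != nil {
		// No NAT record for this tuple (or the lookup failed).
		return tcpip.FullAddress{}, err
	}
	return tcpip.FullAddress{Addr: odst.Addr, Port: odst.Port}, nil
}
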
@@ -2049,8 +2125,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
defer r.Release()
- origID := e.ID
-
netProtos := []tcpip.NetworkProtocolNumber{netProto}
e.ID.LocalAddress = r.LocalAddress
e.ID.RemoteAddress = r.RemoteAddress
@@ -2058,7 +2132,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
if err != nil {
return err
}
@@ -2085,39 +2159,33 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- // reusePort is false below because connect cannot reuse a port even if
- // reusePort was set.
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
return false, nil
}
id := e.ID
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
- case nil:
- // Port picking successful. Save the details of
- // the selected port.
- e.ID = id
- e.boundBindToDevice = e.bindToDevice
- return true, nil
- case tcpip.ErrPortInUse:
- return false, nil
- default:
+ if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr)
+ if err == tcpip.ErrPortInUse {
+ return false, nil
+ }
return false, err
}
+
+ // Port picking successful. Save the details of
+ // the selected port.
+ e.ID = id
+ e.isPortReserved = true
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = e.portFlags
+ e.boundDest = addr
+ return true, nil
}); err != nil {
return err
}
}
- // Remove the port reservation. This can happen when Bind is called
- // before Connect: in such a case we don't want to hold on to
- // reservations anymore.
- if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
- e.isPortReserved = false
- }
-
e.isRegistered = true
e.setEndpointState(StateConnecting)
e.route = r.Clone()
@@ -2296,7 +2364,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
return err
}
@@ -2388,16 +2456,13 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- flags := ports.Flags{
- LoadBalanced: e.reusePort,
- }
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
if err != nil {
return err
}
e.boundBindToDevice = e.bindToDevice
- e.boundPortFlags = flags
+ e.boundPortFlags = e.portFlags
e.isPortReserved = true
e.effectiveNetProtos = netProtos
e.ID.LocalPort = port
@@ -2405,7 +2470,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// Any failures beyond this point must remove the port registration.
defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{})
e.isPortReserved = false
e.effectiveNetProtos = nil
e.ID.LocalPort = 0
@@ -2428,6 +2493,10 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
e.ID.LocalAddress = addr.Addr
}
+ if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ return err
+ }
+
// Mark endpoint as bound.
e.setEndpointState(StateBound)
@@ -2462,7 +2531,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}, nil
}
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// TCP HandlePacket is no longer required as inbound packets first
// land at the Dispatcher, which can then either deliver them using the
// worker goroutine or directly invoke the TCP processing inline
@@ -2481,7 +2550,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
switch typ {
case stack.ControlPacketTooBig:
e.sndBufMu.Lock()
@@ -2492,6 +2561,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.sndBufMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
+
+ case stack.ControlNoRoute:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNoRoute
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
+
+ case stack.ControlNetworkUnreachable:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNetworkUnreachable
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
}
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index fc43c11e2..abf1ac5c9 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -49,11 +49,10 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
defer e.mu.Unlock()
- switch e.EndpointState() {
- case StateInitial, StateBound:
- // TODO(b/138137272): this enumeration duplicates
- // EndpointState.connected. remove it.
- case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ epState := e.EndpointState()
+ switch {
+ case epState == StateInitial || epState == StateBound:
+ case epState.connected() || epState.handshake():
if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
@@ -69,15 +68,16 @@ func (e *endpoint) beforeSave() {
break
}
fallthrough
- case StateListen, StateConnecting:
+ case epState == StateListen || epState == StateConnecting:
e.drainSegmentLocked()
- if e.EndpointState() != StateClose && e.EndpointState() != StateError {
+ // Refresh epState, since drainSegmentLocked may have changed it.
+ epState = e.EndpointState()
+ if !epState.closed() {
if !e.workerRunning {
panic("endpoint has no worker running in listen, connecting, or connected state")
}
- break
}
- case StateError, StateClose:
+ case epState.closed():
for e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
@@ -93,10 +93,6 @@ func (e *endpoint) beforeSave() {
if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
panic("endpoint still has waiters upon save")
}
-
- if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) {
- panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
- }
}
// saveAcceptedChan is invoked by stateify.
@@ -148,23 +144,23 @@ var connectingLoading sync.WaitGroup
// Bound endpoint loading happens last.
// loadState is invoked by stateify.
-func (e *endpoint) loadState(state EndpointState) {
+func (e *endpoint) loadState(epState EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
// For restore purposes we treat TimeWait like a connected endpoint.
- if state.connected() || state == StateTimeWait {
+ if epState.connected() || epState == StateTimeWait {
connectedLoading.Add(1)
}
- switch state {
- case StateListen:
+ switch {
+ case epState == StateListen:
listenLoading.Add(1)
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
connectingLoading.Add(1)
}
// Directly update the state here rather than using e.setEndpointState
// as the endpoint is still being loaded and the stack reference is not
// yet initialized.
- atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+ atomic.StoreUint32((*uint32)(&e.state), uint32(epState))
}
// afterLoad is invoked by stateify.
@@ -183,33 +179,40 @@ func (e *endpoint) afterLoad() {
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- state := e.origEndpointState
- switch state {
+ epState := e.origEndpointState
+ switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
}
- if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
- panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
+ }
+
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
}
}
}
bind := func() {
- if len(e.BindAddr) == 0 {
- e.BindAddr = e.ID.LocalAddress
+ addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort})
+ if err != nil {
+ panic("unable to parse BindAddr: " + err.String())
}
- addr := e.BindAddr
- port := e.ID.LocalPort
- if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil {
- panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err))
+ if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok {
+ panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
}
+ e.isPortReserved = true
+
+ // Mark endpoint as bound.
+ e.setEndpointState(StateBound)
}
- switch state {
- case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ switch {
+ case epState.connected():
bind()
if len(e.connectingAddress) == 0 {
e.connectingAddress = e.ID.RemoteAddress
@@ -232,13 +235,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
closed := e.closed
e.mu.Unlock()
e.notifyProtocolGoroutine(notifyTickleWorker)
- if state == StateFinWait2 && closed {
+ if epState == StateFinWait2 && closed {
// If the endpoint has been closed then make sure we notify so
// that the FIN_WAIT2 timer is started after a restore.
e.notifyProtocolGoroutine(notifyClose)
}
connectedLoading.Done()
- case StateListen:
+ case epState == StateListen:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -255,7 +258,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -267,7 +270,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectingLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateBound:
+ case epState == StateBound:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -276,27 +279,16 @@ func (e *endpoint) Resume(s *stack.Stack) {
bind()
tcpip.AsyncLoading.Done()
}()
- case StateClose:
- if e.isPortReserved {
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- listenLoading.Wait()
- connectingLoading.Wait()
- bind()
- e.setEndpointState(StateClose)
- tcpip.AsyncLoading.Done()
- }()
- }
+ case epState == StateClose:
+ e.isPortReserved = false
e.state = StateClose
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
- case StateError:
+ case epState == StateError:
e.state = StateError
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
-
}
// saveLastError is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 704d01c64..070b634b4 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
s := newSegment(r, id, pkt)
defer s.decRef()
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2a2a7ddeb..b34e47bbd 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -21,6 +21,7 @@
package tcp
import (
+ "fmt"
"runtime"
"strings"
"time"
@@ -70,34 +71,36 @@ const (
DefaultSynRetries = 6
)
-// SACKEnabled option can be used to enable SACK support in the TCP
-// protocol. See: https://tools.ietf.org/html/rfc2018.
+const (
+ ccReno = "reno"
+ ccCubic = "cubic"
+)
+
+// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to
+// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
-// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol.
+// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to
+// enable/disable Nagle's algorithm in TCP.
type DelayEnabled bool
-// SendBufferSizeOption allows the default, min and max send buffer sizes for
-// TCP endpoints to be queried or configured.
+// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption
+// to get/set the default, min and max TCP send buffer sizes.
type SendBufferSizeOption struct {
Min int
Default int
Max int
}
-// ReceiveBufferSizeOption allows the default, min and max receive buffer size
-// for TCP endpoints to be queried or configured.
+// ReceiveBufferSizeOption is used by
+// stack.(Stack*).TransportProtocolOption to get/set the default, min and max
+// TCP receive buffer sizes.
type ReceiveBufferSizeOption struct {
Min int
Default int
Max int
}
-const (
- ccReno = "reno"
- ccCubic = "cubic"
-)
-
// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
// value is protected by a mutex so that we can increment only when it's
// guaranteed not to go above a threshold.
@@ -171,7 +174,7 @@ type protocol struct {
maxRetries uint32
synRcvdCount synRcvdCounter
synRetries uint8
- dispatcher *dispatcher
+ dispatcher dispatcher
}
// Number returns the tcp protocol number.
@@ -206,7 +209,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// to a specific processing queue. Each queue is serviced by its own processor
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
-func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
p.dispatcher.queuePacket(r, ep, id, pkt)
}
@@ -217,7 +220,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
s := newSegment(r, id, pkt)
defer s.decRef()
@@ -490,20 +493,49 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter {
return &p.synRcvdCount
}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize)
+ if !ok {
+ return false
+ }
+
+ // If the header has options, pull those up as well.
+ if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
+ hdr, ok = pkt.Data.PullUp(offset)
+ if !ok {
+ panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset))
+ }
+ }
+
+ pkt.TransportHeader = hdr
+ pkt.Data.TrimFront(len(hdr))
+ return true
+}
+
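
Parse pulls up the fixed 20-byte header first, reads the data offset, then pulls up the options in a second step, leaving pkt.TransportHeader contiguous and pkt.Data holding only payload. The same two-step split on a plain byte slice (hypothetical helper, not the PacketBuffer API):

package main

import "fmt"

const tcpMinimumSize = 20

// splitTCP returns the full TCP header (including options) and the
// payload, or ok=false if the buffer is short or the offset is bogus.
func splitTCP(b []byte) (hdr, payload []byte, ok bool) {
	if len(b) < tcpMinimumSize {
		return nil, nil, false
	}
	// The data offset is the upper 4 bits of byte 12, in 32-bit words.
	offset := int(b[12]>>4) * 4
	if offset < tcpMinimumSize || offset > len(b) {
		return nil, nil, false
	}
	return b[:offset], b[offset:], true
}

func main() {
	pkt := make([]byte, 28) // 20-byte header plus 8 bytes of options
	pkt[12] = (28 / 4) << 4 // data offset = 7 words
	pkt = append(pkt, 'h', 'i')
	hdr, payload, ok := splitTCP(pkt)
	fmt.Println(ok, len(hdr), string(payload)) // true 28 hi
}
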
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ p := protocol{
+ sendBufferSize: SendBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultSendBufferSize,
+ Max: MaxBufferSize,
+ },
+ recvBufferSize: ReceiveBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultReceiveBufferSize,
+ Max: MaxBufferSize,
+ },
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
tcpLingerTimeout: DefaultTCPLingerTimeout,
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
- dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
synRetries: DefaultSynRetries,
minRTO: MinRTO,
maxRTO: MaxRTO,
maxRetries: MaxRetries,
}
+ p.dispatcher.init(runtime.GOMAXPROCS(0))
+ return &p
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index dd89a292a..5e0bfe585 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -372,7 +372,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
// We only store the segment if it's within our buffer
// size limit.
if r.pendingBufUsed < r.pendingBufSize {
- r.pendingBufUsed += s.logicalLen()
+ r.pendingBufUsed += seqnum.Size(s.segMemSize())
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
@@ -406,7 +406,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
}
heap.Pop(&r.pendingRcvdSegments)
- r.pendingBufUsed -= s.logicalLen()
+ r.pendingBufUsed -= seqnum.Size(s.segMemSize())
s.decRef()
}
return false, nil
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 074edded6..bb60dc29d 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -35,6 +35,7 @@ type segment struct {
id stack.TransportEndpointID `state:"manual"`
route stack.Route `state:"manual"`
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View `state:"nosave"`
@@ -60,13 +61,14 @@ type segment struct {
xmitCount uint32
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment {
+func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
s := &segment{
refCnt: 1,
id: id,
route: r.Clone(),
}
s.data = pkt.Data.Clone(s.views[:])
+ s.hdr = header.TCP(pkt.TransportHeader)
s.rcvdTime = time.Now()
return s
}
@@ -136,6 +138,12 @@ func (s *segment) logicalLen() seqnum.Size {
return l
}
+// segMemSize is the amount of memory used to hold the segment data and
+// the associated metadata.
+func (s *segment) segMemSize() int {
+ return segSize + s.data.Size()
+}
+
// parse populates the sequence & ack numbers, flags, and window fields of the
// segment from the TCP header stored in the data. It then updates the view to
// skip the header.
@@ -146,12 +154,6 @@ func (s *segment) logicalLen() seqnum.Size {
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
func (s *segment) parse() bool {
- h, ok := s.data.PullUp(header.TCPMinimumSize)
- if !ok {
- return false
- }
- hdr := header.TCP(h)
-
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
// 1. That it's at least the minimum header size; if we don't do this
@@ -162,16 +164,12 @@ func (s *segment) parse() bool {
// N.B. The segment has already been validated as having at least the
// minimum TCP size before reaching here, so it's safe to read the
// fields.
- offset := int(hdr.DataOffset())
- if offset < header.TCPMinimumSize {
- return false
- }
- hdrWithOpts, ok := s.data.PullUp(offset)
- if !ok {
+ offset := int(s.hdr.DataOffset())
+ if offset < header.TCPMinimumSize || offset > len(s.hdr) {
return false
}
- s.options = []byte(hdrWithOpts[header.TCPMinimumSize:])
+ s.options = []byte(s.hdr[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
// Query the link capabilities to decide if checksum validation is
@@ -180,22 +178,19 @@ func (s *segment) parse() bool {
if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
s.csumValid = true
verifyChecksum = false
- s.data.TrimFront(offset)
}
if verifyChecksum {
- hdr = header.TCP(hdrWithOpts)
- s.csum = hdr.Checksum()
- xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
- xsum = hdr.CalculateChecksum(xsum)
- s.data.TrimFront(offset)
+ s.csum = s.hdr.Checksum()
+ xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr)))
+ xsum = s.hdr.CalculateChecksum(xsum)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
}
- s.sequenceNumber = seqnum.Value(hdr.SequenceNumber())
- s.ackNumber = seqnum.Value(hdr.AckNumber())
- s.flags = hdr.Flags()
- s.window = seqnum.Size(hdr.WindowSize())
+ s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber())
+ s.ackNumber = seqnum.Value(s.hdr.AckNumber())
+ s.flags = s.hdr.Flags()
+ s.window = seqnum.Size(s.hdr.WindowSize())
return true
}
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
new file mode 100644
index 000000000..0ab7b8f56
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "unsafe"
+)
+
+const (
+ segSize = int(unsafe.Sizeof(segment{}))
+)
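
segSize is evaluated at compile time via unsafe.Sizeof, letting segMemSize charge each queued segment for its struct overhead as well as its payload (so pure ACKs are not free). The technique in miniature (hypothetical segment layout):

package main

import (
	"fmt"
	"unsafe"
)

type segment struct {
	seq, ack uint32
	flags    uint8
	data     []byte
}

// segSize is a true constant: unsafe.Sizeof is computed at compile time.
const segSize = int(unsafe.Sizeof(segment{}))

// memSize charges metadata plus payload, tracking real memory pressure
// better than payload length alone.
func memSize(s *segment) int { return segSize + len(s.data) }

func main() {
	s := &segment{data: make([]byte, 1460)}
	fmt.Println(segSize, memSize(s))
}
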
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 06dc9b7d7..5862c32f2 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -618,6 +618,20 @@ func (s *sender) splitSeg(seg *segment, size int) {
nSeg.data.TrimFront(size)
nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
s.writeList.InsertAfter(seg, nSeg)
+
+ // The segment being split does not carry PUSH flag because it is
+ // followed by the newly split segment.
+ // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered
+ // segment (i.e., when there is no more queued data to be sent).
+ // Linux removes PSH flag only when the segment is being split over MSS
+ // and retains it when we are splitting the segment over lack of sender
+ // window space.
+ // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test()
+ if seg.data.Size() > s.maxPayloadSize {
+ seg.flags ^= header.TCPFlagPsh
+ }
+
seg.data.CapLength(size)
}
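
Clearing PSH with XOR relies on the bit being set on the segment being split; Go's bit-clear operator &^ would clear it unconditionally. Both behaviors side by side (PSH really is 0x08 in the TCP flags byte):

package main

import "fmt"

const (
	flagPsh uint8 = 0x08
	flagAck uint8 = 0x10
)

func main() {
	flags := flagPsh | flagAck
	fmt.Printf("%#02x\n", flags^flagPsh)  // XOR toggles: clears PSH only if set
	fmt.Printf("%#02x\n", flags&^flagPsh) // AND NOT: clears PSH unconditionally
}
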
@@ -739,7 +753,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if !s.isAssignedSequenceNumber(seg) {
// Merge segments if allowed.
if seg.data.Size() != 0 {
- available := int(seg.sequenceNumber.Size(end))
+ available := int(s.sndNxt.Size(end))
if available > limit {
available = limit
}
@@ -782,8 +796,11 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// sent all at once.
return false
}
- if atomic.LoadUint32(&s.ep.cork) != 0 {
- // Hold back the segment until full.
+		// With TCP_CORK, hold back the segment until it is at least
+		// the minimum of the available send space and one MSS in size.
+ // TODO(gvisor.dev/issue/2833): Drain the held segments after a
+ // timeout.
+ if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 {
return false
}
}
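
With this change a corked endpoint holds a segment back only while it is still smaller than one MSS, rather than unconditionally. The predicate on its own (hypothetical helper):

package main

import "fmt"

// holdCorked reports whether a corked segment should keep buffering:
// sub-MSS segments wait, a full-sized segment goes out immediately.
func holdCorked(corked bool, segBytes, mss int) bool {
	return corked && segBytes < mss
}

func main() {
	fmt.Println(holdCorked(true, 512, 1460))  // true: keep buffering
	fmt.Println(holdCorked(true, 1460, 1460)) // false: full MSS, send
	fmt.Println(holdCorked(false, 512, 1460)) // false: not corked
}
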
@@ -824,10 +841,52 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if available == 0 {
return false
}
+
+	// If the whole segment or at least a 1 MSS-sized segment cannot
+	// be accommodated in the receiver's advertised window, skip
+	// splitting and sending of the segment. ref:
+ // net/ipv4/tcp_output.c::tcp_snd_wnd_test()
+ //
+ // Linux checks this for all segment transmits not triggered by
+ // a probe timer. On this condition, it defers the segment split
+ // and transmit to a short probe timer.
+ //
+ // ref: include/net/tcp.h::tcp_check_probe_timer()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup()
+ //
+ // Instead of defining a new transmit timer, we attempt to split
+ // the segment right here if there are no pending segments. If
+ // there are pending segments, segment transmits are deferred to
+ // the retransmit timer handler.
+ if s.sndUna != s.sndNxt {
+ switch {
+ case available >= seg.data.Size():
+			// OK to send, the whole segment fits in the
+			// receiver's advertised window.
+ case available >= s.maxPayloadSize:
+ // OK to send, at least 1 MSS sized segment fits
+ // in the receiver's advertised window.
+ default:
+ return false
+ }
+ }
+
+ // The segment size limit is computed as a function of sender
+ // congestion window and MSS. When sender congestion window is >
+ // 1, this limit can be larger than MSS. Ensure that the
+	// currently available send space is not greater than the
+	// minimum of this limit and the MSS.
if available > limit {
available = limit
}
+ // If GSO is not in use then cap available to
+ // maxPayloadSize. When GSO is in use the gVisor GSO logic or
+ // the host GSO logic will cap the segment to the correct size.
+ if s.ep.gso == nil && available > s.maxPayloadSize {
+ available = s.maxPayloadSize
+ }
+
if seg.data.Size() > available {
s.splitSeg(seg, available)
}
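
The new checks mirror Linux's tcp_snd_wnd_test: with earlier data still in flight, a segment is only split and sent if the advertised window can absorb either all of it or at least one MSS, and the split size is further capped at one MSS when GSO is absent. The window test as a standalone predicate (hypothetical helper):

package main

import "fmt"

// okToSend reports whether a segment should be (split and) transmitted
// given the space left in the receiver's advertised window.
func okToSend(pendingInFlight bool, available, segBytes, mss int) bool {
	if !pendingInFlight {
		return true // nothing outstanding; let the split proceed
	}
	return available >= segBytes || available >= mss
}

func main() {
	fmt.Println(okToSend(true, 3000, 2800, 1460)) // true: whole segment fits
	fmt.Println(okToSend(true, 1500, 2800, 1460)) // true: at least 1 MSS fits
	fmt.Println(okToSend(true, 700, 2800, 1460))  // false: defer to timer
}
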
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 5fe23113b..b9993ce1a 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -50,7 +50,7 @@ func TestFastRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -90,14 +90,14 @@ func TestFastRecovery(t *testing.T) {
// Wait before checking metrics.
metricPollFn := func() error {
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want)
}
return nil
}
@@ -128,10 +128,10 @@ func TestFastRecovery(t *testing.T) {
// Wait before checking metrics.
metricPollFn = func() error {
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
}
return nil
}
@@ -215,7 +215,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
expected := tcp.InitialCwnd
@@ -257,7 +257,7 @@ func TestCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -362,7 +362,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -471,11 +471,11 @@ func TestRetransmit(t *testing.T) {
// MTU size though.
half := data[:len(data)/2]
if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
half = data[len(data)/2:]
if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -508,23 +508,23 @@ func TestRetransmit(t *testing.T) {
metricPollFn := func() error {
if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want)
}
return nil
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index ace79b7b2..99521f0c1 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -47,7 +47,7 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
t.Helper()
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err)
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err)
}
}
@@ -400,7 +400,7 @@ func TestSACKRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -454,7 +454,7 @@ func TestSACKRecovery(t *testing.T) {
}
for _, s := range stats {
if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %v, want = %v", s.name, got, want)
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
}
}
return nil
@@ -529,19 +529,19 @@ func TestSACKRecovery(t *testing.T) {
// In SACK recovery only the first segment is fast retransmitted when
// entering recovery.
if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want)
}
if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
}
if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want)
+ return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want)
}
return nil
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 6ef32a1b3..fb25b86b9 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -57,7 +57,7 @@ func TestGiveUpConnect(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Register for notification, then start connection attempt.
@@ -66,7 +66,7 @@ func TestGiveUpConnect(t *testing.T) {
defer wq.EventUnregister(&waitEntry)
if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
// Close the connection, wait for completion.
@@ -75,21 +75,21 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
+ t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted)
}
// Call Connect again to retrieve the handshake failure status
// and stats updates.
if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted)
}
if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
}
@@ -102,7 +102,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
}
@@ -115,10 +115,10 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want)
+ t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want)
}
}
@@ -129,20 +129,38 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
want := stats.TCP.FailedConnectionAttempts.Value() + 1
if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute {
- t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute)
}
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want)
+ t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want)
+ }
+}
+
+func TestCloseWithoutConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ c.EP.Close()
+
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -156,10 +174,10 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
- t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
- t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want)
+ t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want)
}
}
@@ -170,16 +188,16 @@ func TestTCPResetsSentIncrement(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
want := stats.TCP.SegmentsSent.Value() + 1
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send a SYN request.
@@ -213,7 +231,7 @@ func TestTCPResetsSentIncrement(t *testing.T) {
metricPollFn := func() error {
if got := stats.TCP.ResetsSent.Value(); got != want {
- return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want)
+ return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want)
}
return nil
}
@@ -292,7 +310,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// are released instantly on Close.
tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
- t.Fatalf("e.stack.SetTransportProtocolOption(%d, %v) = %v", tcp.ProtocolNumber, tcpTW, err)
+ t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err)
}
c.EP.Close()
@@ -355,7 +373,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
})
if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
}
}
@@ -379,7 +397,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) {
})
if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
}
c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
}
@@ -403,7 +421,7 @@ func TestNonBlockingClose(t *testing.T) {
t0 := time.Now()
ep.Close()
if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
+ t.Fatalf("Took too long to close: %s", diff)
}
}
@@ -415,7 +433,7 @@ func TestConnectResetAfterClose(t *testing.T) {
// after 3 second in FIN_WAIT2 state.
tcpLingerTimeout := 3 * time.Second
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err)
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -497,11 +515,11 @@ func TestCurrentConnectedIncrement(t *testing.T) {
c.EP = nil
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got)
}
gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
if gotConnected != 1 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 1", gotConnected)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected)
}
ep.Close()
@@ -524,10 +542,10 @@ func TestCurrentConnectedIncrement(t *testing.T) {
})
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = %v", got, gotConnected)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected)
}
// Ack and send FIN as well.
@@ -556,10 +574,10 @@ func TestCurrentConnectedIncrement(t *testing.T) {
time.Sleep(1200 * time.Millisecond)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -575,7 +593,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
c.EP = nil
if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
}
// Send a FIN for ESTABLISHED --> CLOSED-WAIT
@@ -603,7 +621,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
time.Sleep(10 * time.Millisecond)
if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
}
// Close the application endpoint for CLOSE_WAIT --> LAST_ACK
@@ -620,7 +638,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
)
if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Pause the endpoint`s protocolMainLoop.
@@ -657,15 +675,15 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
// Expect the endpoint to be closed.
if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
// Check if the endpoint was moved to CLOSED and netstack sent a reset in
@@ -691,7 +709,7 @@ func TestSimpleReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -714,7 +732,7 @@ func TestSimpleReceive(t *testing.T) {
// Receive data.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -781,7 +799,7 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
// Start connection attempt to IPv4 address.
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet with our user supplied MSS.
@@ -842,7 +860,7 @@ func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
// Start connection attempt to IPv6 address.
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet with our user supplied MSS.
@@ -1239,7 +1257,7 @@ func TestConnectBindToDevice(t *testing.T) {
defer c.WQ.EventUnregister(&waitEntry)
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -1251,7 +1269,7 @@ func TestConnectBindToDevice(t *testing.T) {
),
)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
}
tcpHdr := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
@@ -1270,74 +1288,97 @@ func TestConnectBindToDevice(t *testing.T) {
c.GetPacket()
if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
}
})
}
}
-func TestRstOnSynSent(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
+func TestSynSent(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ reset bool
+ }{
+ {"RstOnSynSent", true},
+ {"CloseOnSynSent", false},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
- // Create an endpoint, don't handshake because we want to interfere with the
- // handshake process.
- c.Create(-1)
+ // Create an endpoint, don't handshake because we want to interfere with the
+ // handshake process.
+ c.Create(-1)
- // Start connection attempt.
- waitEntry, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventOut)
- defer c.WQ.EventUnregister(&waitEntry)
+ // Start connection attempt.
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
- addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
- t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted)
- }
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
+ }
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
- // Ensure that we've reached SynSent state
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
- // Send a packet with a proper ACK and a RST flag to cause the socket
- // to Error and close out
- iss := seqnum.Value(789)
- rcvWnd := seqnum.Size(30000)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagRst | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: nil,
- })
+ if test.reset {
+ // Send a packet with a proper ACK and a RST flag to cause the socket
+ // to error and close out.
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagRst | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+ } else {
+ c.EP.Close()
+ }
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatal("timed out waiting for packet to arrive")
- }
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for packet to arrive")
+ }
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused)
- }
+ if test.reset {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
+ }
+ } else {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted)
+ }
+ }
- // Due to the RST the endpoint should be in an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+
+ // In both cases (RST or close) the endpoint should end up
+ // in an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ })
}
}
@@ -1352,7 +1393,7 @@ func TestOutOfOrderReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send second half of data first, with seqnum 3 ahead of expected.
@@ -1379,7 +1420,7 @@ func TestOutOfOrderReceive(t *testing.T) {
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send the first 3 bytes now.
@@ -1406,7 +1447,7 @@ func TestOutOfOrderReceive(t *testing.T) {
}
continue
}
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
read = append(read, v...)
@@ -1436,7 +1477,7 @@ func TestOutOfOrderFlood(t *testing.T) {
c.CreateConnected(789, 30000, 10)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send 100 packets before the actual one that is expected.
@@ -1513,7 +1554,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -1556,7 +1597,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// This final ACK should be ignored because an ACK on a reset doesn't mean
@@ -1582,7 +1623,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -1624,7 +1665,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
))
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Cause a RST to be generated by closing the read end now since we have
@@ -1643,7 +1684,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// The ACK to the FIN should now be rejected since the connection has been
@@ -1665,19 +1706,19 @@ func TestShutdownRead(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
- t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want)
+ t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
}
}
@@ -1693,7 +1734,7 @@ func TestFullWindowReceive(t *testing.T) {
_, _, err := c.EP.Read(nil)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
// Fill up the window.
@@ -1728,7 +1769,7 @@ func TestFullWindowReceive(t *testing.T) {
// Receive data and check it.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -1737,7 +1778,7 @@ func TestFullWindowReceive(t *testing.T) {
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
- t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want)
+ t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want)
}
// Check that we get an ACK for the newly non-zero window.
@@ -1760,7 +1801,7 @@ func TestNoWindowShrinking(t *testing.T) {
c.CreateConnected(789, 30000, 10)
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -1768,7 +1809,7 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send 3 bytes, check that the peer acknowledges them.
@@ -1832,7 +1873,7 @@ func TestNoWindowShrinking(t *testing.T) {
for len(read) < len(data) {
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
read = append(read, v...)
@@ -1866,7 +1907,7 @@ func TestSimpleSend(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received.
@@ -1908,7 +1949,7 @@ func TestZeroWindowSend(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check if we got a zero-window probe.
@@ -1976,7 +2017,7 @@ func TestScaledWindowConnect(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xbfff,
@@ -2008,7 +2049,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -2036,21 +2077,21 @@ func TestScaledWindowAccept(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -2068,7 +2109,7 @@ func TestScaledWindowAccept(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2081,7 +2122,7 @@ func TestScaledWindowAccept(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xbfff,
@@ -2109,21 +2150,21 @@ func TestNonScaledWindowAccept(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
@@ -2142,7 +2183,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2155,7 +2196,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -2244,7 +2285,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
for sz < defaultMTU {
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
sz += len(v)
}
@@ -2311,7 +2352,7 @@ func TestSegmentMerging(t *testing.T) {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2381,7 +2422,7 @@ func TestDelay(t *testing.T) {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2428,7 +2469,7 @@ func TestUndelay(t *testing.T) {
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2512,7 +2553,7 @@ func TestMSSNotDelayed(t *testing.T) {
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2563,7 +2604,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received in chunks.
@@ -2631,7 +2672,7 @@ func TestSetTTL(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
@@ -2639,7 +2680,7 @@ func TestSetTTL(t *testing.T) {
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %s", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -2671,7 +2712,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
@@ -2683,11 +2724,11 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -2705,7 +2746,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2794,7 +2835,7 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) {
select {
case err := <-ch:
if err != nil {
- t.Fatalf("Error creating endpoint: %v", err)
+ t.Fatalf("Error creating endpoint: %s", err)
}
case <-time.After(2 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -2813,7 +2854,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Set the buffer size to a deterministic size so that we can check the
@@ -2830,7 +2871,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
// Receive SYN packet.
@@ -2884,7 +2925,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
select {
case <-ch:
if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -2899,22 +2940,22 @@ func TestCloseListener(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Close the listener and measure how long it takes.
t0 := time.Now()
ep.Close()
if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
+ t.Fatalf("Took too long to close: %s", diff)
}
}
@@ -2950,22 +2991,25 @@ loop:
case tcpip.ErrConnectionReset:
break loop
default:
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
// Expect the state to be StateError and subsequent Reads to fail with HardError.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
- t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got)
+ t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -2990,7 +3034,7 @@ func TestSendOnResetConnection(t *testing.T) {
// Try to write.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
@@ -3013,7 +3057,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Expect first transmit and MaxRetries retransmits.
@@ -3048,7 +3092,10 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
)
if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -3066,7 +3113,7 @@ func TestMaxRTO(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
checker.TCP(
@@ -3089,6 +3136,63 @@ func TestMaxRTO(t *testing.T) {
}
}
+// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
+// unique on retransmits.
+func TestRetransmitIPv4IDUniqueness(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ size int
+ }{
+ {"1Byte", 1},
+ {"512Bytes", 512},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ // Disabling PMTU discovery causes all packets sent from this socket to
+ // have DF=0. This needs to be done because the IPv4 ID uniqueness
+ // applies only to non-atomic IPv4 datagrams as defined in RFC 6864
+ // Section 4, and datagrams with DF=0 are non-atomic.
+ if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
+ t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}}
+ // Receive the retransmitted packets and check that all packets
+ // received have unique IPv4 ID values.
+ for i := 0; i <= 2; i++ {
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ id := header.IPv4(pkt).ID()
+ if _, exists := idSet[id]; exists {
+ t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
+ }
+ idSet[id] = struct{}{}
+ }
+ })
+ }
+}
+
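The MTUDiscoverOption step in the test above exists only to force DF=0:
RFC 6864 Section 4 requires unique IDs only for non-atomic datagrams. A
small illustrative helper, not part of the change, assuming the
Flags/FragmentOffset accessors and flag constants as they appear in
pkg/tcpip/header:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// isAtomic reports whether an IPv4 datagram is "atomic" per RFC 6864:
// DF set and not a fragment (MF clear, fragment offset zero). Atomic
// datagrams are exempt from the ID-uniqueness requirement, which is
// why the test clears DF before collecting IDs.
func isAtomic(h header.IPv4) bool {
	return h.Flags()&header.IPv4FlagDontFragment != 0 &&
		h.Flags()&header.IPv4FlagMoreFragments == 0 &&
		h.FragmentOffset() == 0
}

func main() {
	// A 20-byte IPv4 header with only DF set: the flags occupy the
	// top three bits of byte 6.
	b := make([]byte, header.IPv4MinimumSize)
	b[6] = byte(header.IPv4FlagDontFragment) << 5
	fmt.Println(isAtomic(header.IPv4(b))) // true
}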
func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -3097,7 +3201,7 @@ func TestFinImmediately(t *testing.T) {
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3140,7 +3244,7 @@ func TestFinRetransmit(t *testing.T) {
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3195,7 +3299,7 @@ func TestFinWithNoPendingData(t *testing.T) {
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3221,7 +3325,7 @@ func TestFinWithNoPendingData(t *testing.T) {
// Shutdown, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3268,7 +3372,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
view := buffer.NewView(10)
for i := tcp.InitialCwnd; i > 0; i-- {
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
}
@@ -3290,7 +3394,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
// because the congestion window doesn't allow it. Wait until a
// retransmit is received.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3354,7 +3458,7 @@ func TestFinWithPendingData(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3380,7 +3484,7 @@ func TestFinWithPendingData(t *testing.T) {
// Write new data, but don't acknowledge it.
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3396,7 +3500,7 @@ func TestFinWithPendingData(t *testing.T) {
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3441,7 +3545,7 @@ func TestFinWithPartialAck(t *testing.T) {
// FIN from the test side.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3478,7 +3582,7 @@ func TestFinWithPartialAck(t *testing.T) {
// Write new data, but don't acknowledge it.
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3494,7 +3598,7 @@ func TestFinWithPartialAck(t *testing.T) {
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -3540,20 +3644,20 @@ func TestUpdateListenBacklog(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Update the backlog with another Listen() on the same endpoint.
if err := ep.Listen(20); err != nil {
- t.Fatalf("Listen failed to update backlog: %v", err)
+ t.Fatalf("Listen failed to update backlog: %s", err)
}
ep.Close()
@@ -3585,7 +3689,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
// Send some data. Check that it's capped by the window size.
view := buffer.NewView(65535)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that only data that fits in the scaled window is sent.
@@ -3631,18 +3735,18 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
})
if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
- t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want)
+ t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want)
}
// Ensure there were no errors during handshake. If these stats have
// incremented, then the connection should not have been established.
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0)
+ t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0)
}
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0)
+ t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0)
}
}
@@ -3666,10 +3770,10 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c.SendSegment(vv)
if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
}
}
@@ -3770,7 +3874,7 @@ func TestReadAfterClosedState(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Shutdown immediately for write, check that we get a FIN.
@@ -3789,7 +3893,7 @@ func TestReadAfterClosedState(t *testing.T) {
)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Send some data and acknowledge the FIN.
@@ -3818,7 +3922,7 @@ func TestReadAfterClosedState(t *testing.T) {
time.Sleep(tcpTimeWaitTimeout * 2)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Wait for receive to be notified.
@@ -3853,11 +3957,11 @@ func TestReadAfterClosedState(t *testing.T) {
// Now that we drained the queue, check that functions fail with the
// right error code.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
}
@@ -3871,66 +3975,84 @@ func TestReusePort(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
// Second case, an endpoint that was bound and is connecting.
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
// Third case, an endpoint that was bound and is listening.
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
}
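Each of the four cases above repeats the same create/configure/bind sequence. A possible refactor, shown only as a hedged sketch against the APIs used in this test (the helper name is illustrative and assumes the imports already present in tcp_test.go):

	// newReusableEndpoint creates an endpoint, marks its address
	// reusable, and binds it to the stack port, failing the test on
	// any error.
	func newReusableEndpoint(t *testing.T, c *context.Context) tcpip.Endpoint {
		t.Helper()
		ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
		if err != nil {
			t.Fatalf("NewEndpoint failed: %s", err)
		}
		if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
			t.Fatalf("SetSockOptBool(ReuseAddressOption, true) failed: %s", err)
		}
		if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
			t.Fatalf("Bind failed: %s", err)
		}
		return ep
	}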
@@ -3939,11 +4061,11 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
if int(s) != v {
- t.Fatalf("got receive buffer size = %v, want = %v", s, v)
+ t.Fatalf("got receive buffer size = %d, want = %d", s, v)
}
}
@@ -3952,11 +4074,11 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
if int(s) != v {
- t.Fatalf("got send buffer size = %v, want = %v", s, v)
+ t.Fatalf("got send buffer size = %d, want = %d", s, v)
}
}
@@ -3969,7 +4091,7 @@ func TestDefaultBufferSizes(t *testing.T) {
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer func() {
if ep != nil {
@@ -3981,28 +4103,34 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default send buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultSendBufferSize * 2,
+ Max: tcp.DefaultSendBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
ep.Close()
ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default receive buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize * 3,
+ Max: tcp.DefaultReceiveBufferSize * 30}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
ep.Close()
ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
@@ -4018,17 +4146,17 @@ func TestMinMaxBufferSizes(t *testing.T) {
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
// Change the min/max values for send/receive
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Set values below the min.
@@ -4065,12 +4193,12 @@ func TestBindToDeviceOption(t *testing.T) {
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
if err := s.CreateNIC(321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
+ t.Errorf("CreateNIC failed: %s", err)
}
// nicIDPtr is used instead of taking the address of NICID literals, which is
@@ -4095,12 +4223,12 @@ func TestBindToDeviceOption(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
+ t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
bindToDevice := tcpip.BindToDeviceOption(88888)
if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt got %v, want %v", err, nil)
+ t.Errorf("GetSockOpt got %s, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
t.Errorf("bindToDevice got %d, want %d", got, want)
@@ -4166,12 +4294,12 @@ func TestSelfConnect(t *testing.T) {
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Register for notification, then start connection attempt.
@@ -4180,12 +4308,12 @@ func TestSelfConnect(t *testing.T) {
defer wq.EventUnregister(&waitEntry)
if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
<-notifyCh
if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("Connect failed: %v", err)
+ t.Fatalf("Connect failed: %s", err)
}
// Write something.
@@ -4193,7 +4321,7 @@ func TestSelfConnect(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Read back what was written.
@@ -4202,12 +4330,12 @@ func TestSelfConnect(t *testing.T) {
rd, _, err := ep.Read(nil)
if err != nil {
if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
<-notifyCh
rd, _, err = ep.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
}
@@ -4291,7 +4419,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
eps = append(eps, ep)
switch network {
@@ -4342,7 +4470,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
- t.Fatalf("Bind(%d) failed: %v", i, err)
+ t.Fatalf("Bind(%d) failed: %s", i, err)
}
}
want := tcpip.ErrConnectStarted
@@ -4350,7 +4478,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
want = tcpip.ErrNoPortAvailable
}
if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
- t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want)
+ t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
}
})
}
@@ -4384,7 +4512,7 @@ func TestPathMTUDiscovery(t *testing.T) {
}
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
@@ -4487,7 +4615,7 @@ func TestStackSetCongestionControl(t *testing.T) {
var oldCC tcpip.CongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err)
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err)
}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
@@ -4574,12 +4702,12 @@ func TestEndpointSetCongestionControl(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
var oldCC tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&oldCC); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err)
+ t.Fatalf("c.EP.SockOpt(%v) = %s", &oldCC, err)
}
if connected {
@@ -4587,12 +4715,12 @@ func TestEndpointSetCongestionControl(t *testing.T) {
}
if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
- t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err)
+ t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err)
}
var cc tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&cc); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err)
+ t.Fatalf("c.EP.SockOpt(%v) = %s", &cc, err)
}
got, want := cc, oldCC
@@ -4615,7 +4743,7 @@ func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
opt := tcpip.CongestionControlOption("cubic")
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err)
}
}
@@ -4657,14 +4785,14 @@ func TestKeepalive(t *testing.T) {
// Check that the connection is still alive.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
view := buffer.NewView(3)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -4744,15 +4872,18 @@ func TestKeepalive(t *testing.T) {
)
if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
}
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -4854,19 +4985,19 @@ func TestListenBacklogFull(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
// Start listening.
listenBacklog := 2
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
for i := 0; i < listenBacklog; i++ {
@@ -4899,7 +5030,7 @@ func TestListenBacklogFull(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4928,7 +5059,7 @@ func TestListenBacklogFull(t *testing.T) {
case <-ch:
newEP, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4942,7 +5073,7 @@ func TestListenBacklogFull(t *testing.T) {
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
}
}
@@ -5162,19 +5293,19 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
// Start listening.
listenBacklog := 1
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send two SYNs. The first one should get a SYN-ACK, the
@@ -5240,7 +5371,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
case <-ch:
newEP, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5254,7 +5385,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
pkt := c.GetPacket()
tcp = header.TCP(header.IPv4(pkt).Payload())
if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
}
}
@@ -5316,7 +5447,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5450,7 +5581,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
pkt := c.GetPacket()
tcpHdr = header.TCP(header.IPv4(pkt).Payload())
if string(tcpHdr.Payload()) != data {
- t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
}
}
@@ -5460,20 +5591,20 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
stats := c.Stack().Stats()
@@ -5494,7 +5625,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5503,7 +5634,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
}
if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
}
@@ -5514,14 +5645,14 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
srcPort := uint16(context.TestPort)
@@ -5546,10 +5677,10 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
time.Sleep(50 * time.Millisecond)
if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want)
+ t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -5564,7 +5695,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5579,28 +5710,28 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
- t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected)
+ t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected)
}
if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
- t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1)
+ t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
@@ -5617,7 +5748,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
case <-ch:
aep, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -5625,25 +5756,25 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
}
}
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
- t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected)
+ t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected)
}
// Listening endpoint remains in listen state.
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
ep.Close()
// Give worker goroutines time to receive the close notification.
time.Sleep(1 * time.Second)
if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Accepted endpoint remains open when the listen endpoint is closed.
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
}
@@ -5663,13 +5794,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 500.
const receiveBufferSize = 80 << 10 // 80 KiB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Enable auto-tuning.
if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Change the expected window scale to match the value needed for the
// maximum buffer size defined above.
@@ -5784,13 +5915,13 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80 KiB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Enable auto-tuning.
if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Change the expected window scale to match the value needed for the
// maximum buffer size used by stack.
@@ -5935,7 +6066,7 @@ func TestDelayEnabled(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
- t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err)
+ t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err)
}
checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
}
@@ -5946,7 +6077,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del
var gotDelayEnabled tcp.DelayEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
- t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err)
+ t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
}
if gotDelayEnabled != wantDelayEnabled {
t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
@@ -5954,7 +6085,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
if err != nil {
- t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err)
+ t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
}
gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
if err != nil {
@@ -6515,10 +6646,10 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.TCPFlags(header.TCPFlagRst)))
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want)
}
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
}
@@ -6715,7 +6846,7 @@ func TestTCPUserTimeout(t *testing.T) {
// Send some data and wait before ACKing it.
view := buffer.NewView(3)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -6765,11 +6896,14 @@ func TestTCPUserTimeout(t *testing.T) {
)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -6796,7 +6930,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
// Check that the connection is still alive.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Now receive 1 keepalive, but don't ACK it.
@@ -6837,10 +6971,13 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
}
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -6896,11 +7033,11 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
// ack should be sent in response to that. The window was not
// zero, but it grew to be larger than the MSS.
if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
// After reading two packets, we surely crossed MSS. See the ack:
@@ -6997,13 +7134,13 @@ func TestTCPDeferAccept(t *testing.T) {
const tcpDeferAccept = 1 * time.Second
if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
}
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Send data. This should result in an acceptable endpoint.
@@ -7026,7 +7163,7 @@ func TestTCPDeferAccept(t *testing.T) {
time.Sleep(50 * time.Millisecond)
aep, _, err := c.EP.Accept()
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
}
aep.Close()
@@ -7054,13 +7191,13 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
const tcpDeferAccept = 1 * time.Second
if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
}
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Sleep for part of the tcpDeferAccept timeout.
@@ -7094,7 +7231,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
time.Sleep(50 * time.Millisecond)
aep, _, err := c.EP.Accept()
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
}
aep.Close()
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 7b1d72cf4..37e7767d6 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -143,13 +143,15 @@ func New(t *testing.T, mtu uint32) *Context {
TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
})
+ const sendBufferSize = 1 << 20 // 1 MiB
+ const recvBufferSize = 1 << 20 // 1 MiB
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Increase minimum RTO in tests to avoid test flakes due to early
@@ -202,7 +204,7 @@ func New(t *testing.T, mtu uint32) *Context {
t: t,
s: s,
linkEP: ep,
- WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
+ WindowScale: uint8(tcp.FindWndScale(recvBufferSize)),
}
}
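Note that the context's advertised window scale is now derived from the local recvBufferSize constant rather than the package default. A minimal sketch of the relationship, using the tcp.FindWndScale call from the code above:

	// FindWndScale picks the smallest window scale that lets a 16-bit
	// window field describe the whole receive buffer. Tests that
	// hand-craft segments must advertise the same value.
	const recvBufferSize = 1 << 20 // Mirrors the constant in New.
	wndScale := uint8(tcp.FindWndScale(recvBufferSize))
	_ = wndScale // Stored as Context.WindowScale.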
@@ -316,7 +318,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -372,7 +374,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: s,
})
}
@@ -380,7 +382,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) {
// SendPacket builds and sends a TCP segment (with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: c.BuildSegment(payload, h),
})
}
@@ -389,7 +391,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) {
// & TCP headers) in an IPv4 packet via the link layer endpoint using the
// provided source and destination IPv4 addresses.
func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
})
}
@@ -564,7 +566,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index c70525f27..7981d469b 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -85,6 +85,7 @@ func (t *timer) init(w *sleep.Waker) {
// cleanup frees all resources associated with the timer.
func (t *timer) cleanup() {
t.timer.Stop()
+ *t = timer{}
}
// checkExpiration checks if the given timer has actually expired; it should be
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
new file mode 100644
index 000000000..dbd6dff54
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer_test.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+)
+
+func TestCleanup(t *testing.T) {
+ const (
+ timerDurationSeconds = 2
+ isAssertedTimeoutSeconds = timerDurationSeconds + 1
+ )
+
+ tmr := timer{}
+ w := sleep.Waker{}
+ tmr.init(&w)
+ tmr.enable(timerDurationSeconds * time.Second)
+ tmr.cleanup()
+
+ if want := (timer{}); tmr != want {
+ t.Errorf("got tmr = %+v, want = %+v", tmr, want)
+ }
+
+ // The waker should not be asserted.
+ for i := 0; i < isAssertedTimeoutSeconds; i++ {
+ time.Sleep(time.Second)
+ if w.IsAsserted() {
+ t.Fatalf("waker asserted unexpectedly")
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 12bc1b5b5..558b06df0 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
return st
}
+// State returns the current state of the TCB.
+func (t *TCB) State() Result {
+ return t.state
+}
+
// IsAlive returns true as long as the connection is in the established (Alive)
// or connecting state.
func (t *TCB) IsAlive() bool {
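A hedged usage sketch for the new accessor; ResultAlive is one of the existing Result values in this package, and the helper name is illustrative:

	// isEstablished reports whether a tracked connection has completed
	// its handshake, without feeding the TCB another segment.
	func isEstablished(t *tcpconntrack.TCB) bool {
		return t.State() == tcpconntrack.ResultAlive
	}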
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 647b2067a..b7d735889 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,6 +15,9 @@
package udp
import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -93,6 +96,7 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
sndBufSize int
+ sndBufSizeMax int
state EndpointState
route stack.Route `state:"manual"`
dstPort uint16
@@ -102,9 +106,10 @@ type endpoint struct {
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
multicastLoop bool
- reusePort bool
+ portFlags ports.Flags
bindToDevice tcpip.NICID
broadcast bool
+ noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -158,7 +163,7 @@ type multicastMembership struct {
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
- return &endpoint{
+ e := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
@@ -180,10 +185,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
multicastTTL: 1,
multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
state: StateInitial,
uniqueID: s.UniqueID(),
}
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
+ }
+
+ return e
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
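With the override above, stack-wide options seed every new UDP endpoint. A minimal sketch, assuming a *stack.Stack s and a SetOption counterpart to the Option call used in newEndpoint:

	// Configure stack-wide UDP buffer defaults before creating
	// endpoints; endpoints created afterwards inherit Default.
	opt := stack.SendBufferSizeOption{Min: 4 << 10, Default: 64 << 10, Max: 4 << 20}
	if err := s.SetOption(opt); err != nil {
		panic(fmt.Sprintf("s.SetOption(%#v) = %s", opt, err))
	}
	// A UDP endpoint created now starts with a 64 KiB send buffer.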
@@ -213,8 +231,8 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
}
@@ -247,11 +265,6 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (stack.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -430,24 +443,33 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
var route *stack.Route
+ var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
var dstPort uint16
if to == nil {
route = &e.route
dstPort = e.dstPort
-
- if route.IsResolutionRequired() {
- // Promote lock to exclusive if using a shared route, given that it may need to
- // change in Route.Resolve() call below.
+ resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
+ // Promote lock to exclusive if using a shared route, given that it may
+ // need to change in Route.Resolve() call below.
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
if e.state != StateConnected {
- return 0, nil, tcpip.ErrInvalidEndpointState
+ err = tcpip.ErrInvalidEndpointState
+ }
+ if err == nil && route.IsResolutionRequired() {
+ ch, err = route.Resolve(waker)
}
+
+ e.mu.Unlock()
+ e.mu.RLock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != StateConnected {
+ err = tcpip.ErrInvalidEndpointState
+ }
+ return
}
} else {
// Reject destination address if it goes through a different
@@ -461,10 +483,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
- if to.Addr == header.IPv4Broadcast && !e.broadcast {
- return 0, nil, tcpip.ErrBroadcastDisabled
- }
-
dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
@@ -478,10 +496,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
route = &r
dstPort = dst.Port
+ resolve = route.Resolve
+ }
+
+ if !e.broadcast && route.IsBroadcast() {
+ return 0, nil, tcpip.ErrBroadcastDisabled
}
if route.IsResolutionRequired() {
- if ch, err := route.Resolve(nil); err != nil {
+ if ch, err := resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
return 0, ch, tcpip.ErrNoLinkAddress
}
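The connected path above promotes the endpoint's read lock to a write lock for resolution and then demotes it, rechecking state after every reacquisition. The pattern in isolation, distilled from the code above rather than new API:

	e.mu.RUnlock() // Drop the shared lock.
	e.mu.Lock()    // State may have changed in the gap; recheck.
	if e.state != StateConnected {
		err = tcpip.ErrInvalidEndpointState
	}
	// ... exclusive-only work, e.g. route.Resolve(waker) ...
	e.mu.Unlock()
	e.mu.RLock() // Demote; a writer may have run, so recheck again.
	if e.state != StateConnected {
		err = tcpip.ErrInvalidEndpointState
	}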
@@ -507,7 +530,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -531,6 +554,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.multicastLoop = v
e.mu.Unlock()
+ case tcpip.NoChecksumOption:
+ e.mu.Lock()
+ e.noChecksum = v
+ e.mu.Unlock()
+
case tcpip.ReceiveTOSOption:
e.mu.Lock()
e.receiveTOS = v
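A usage sketch for the new NoChecksumOption; the SO_NO_CHECK analogy is ours, and per the sendUDP logic further below the setting only affects IPv4 transmit:

	// Opt an IPv4 UDP endpoint out of checksum generation (the Linux
	// analogue is SO_NO_CHECK). On IPv6 the checksum stays mandatory.
	if err := ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
		// Not expected for this option.
	}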
@@ -552,10 +580,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.mu.Unlock()
case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.portFlags.MostRecent = v
+ e.mu.Unlock()
case tcpip.ReusePortOption:
e.mu.Lock()
- e.reusePort = v
+ e.portFlags.LoadBalanced = v
e.mu.Unlock()
case tcpip.V6OnlyOption:
@@ -581,6 +612,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return ErrNotSupported unless the caller is disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
e.multicastTTL = uint8(v)
@@ -602,8 +640,43 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.mu.Unlock()
case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err))
+ }
+
+ if v < rs.Min {
+ v = rs.Min
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ e.mu.Lock()
+ e.rcvBufSizeMax = v
+ e.mu.Unlock()
+ return nil
case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err))
+ }
+
+ if v < ss.Min {
+ v = ss.Min
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
}
return nil
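Since SetSockOptInt now clamps rather than rejects out-of-range buffer sizes, a caller observes the clamped value via GetSockOptInt instead of an error. A sketch, assuming ep is a UDP tcpip.Endpoint:

	// Request an oversized receive buffer; the endpoint clamps it to
	// the stack's ReceiveBufferSizeOption.Max instead of failing.
	if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1<<30); err != nil {
		// Unreachable with the clamping logic above.
	}
	if got, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption); err == nil {
		_ = got // rs.Max; 4 MiB with the defaults added in protocol.go.
	}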
@@ -743,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Lock()
e.bindToDevice = id
e.mu.Unlock()
+
+ case tcpip.SocketDetachFilterOption:
+ return nil
}
return nil
}
@@ -765,6 +841,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.NoChecksumOption:
+ e.mu.RLock()
+ v := e.noChecksum
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.ReceiveTOSOption:
e.mu.RLock()
v := e.receiveTOS
@@ -789,11 +871,15 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
case tcpip.ReuseAddressOption:
- return false, nil
+ e.mu.RLock()
+ v := e.portFlags.MostRecent
+ e.mu.RUnlock()
+
+ return v, nil
case tcpip.ReusePortOption:
e.mu.RLock()
- v := e.reusePort
+ v := e.portFlags.LoadBalanced
e.mu.RUnlock()
return v, nil
@@ -830,6 +916,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
v := int(e.multicastTTL)
@@ -848,7 +938,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -895,7 +985,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -909,8 +999,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
Length: length,
})
- // Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, the UDP checksum is optional, and a zero value indicates
+ // the transmitter skipped checksum generation (RFC 768).
+ // On IPv6, the UDP checksum is not optional (RFC 2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 &&
+ (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
@@ -921,7 +1015,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
if useDefaultTTL {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: ProtocolNumber,
+ TTL: ttl,
+ TOS: tos,
+ }, &stack.PacketBuffer{
Header: hdr,
Data: data,
TransportHeader: buffer.View(udp),
@@ -958,6 +1056,11 @@ func (e *endpoint) Disconnect() *tcpip.Error {
id stack.TransportEndpointID
btd tcpip.NICID
)
+
+ // We change this value below, and we need the old value to
+ // unregister the endpoint.
+ boundPortFlags := e.boundPortFlags
+
// Exclude ephemerally bound endpoints.
if e.BindNICID != 0 || e.ID.LocalAddress == "" {
var err *tcpip.Error
@@ -970,16 +1073,17 @@ func (e *endpoint) Disconnect() *tcpip.Error {
return err
}
e.state = StateBound
+ boundPortFlags = e.boundPortFlags
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
@@ -1051,6 +1155,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
}
+ oldPortFlags := e.boundPortFlags
+
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
@@ -1058,7 +1164,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
}
e.ID = id
@@ -1122,22 +1228,17 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
- flags := ports.Flags{
- LoadBalanced: e.reusePort,
- // FIXME(b/129164367): Support SO_REUSEADDR.
- MostRecent: false,
- }
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
if err != nil {
return id, e.bindToDevice, err
}
- e.boundPortFlags = flags
id.LocalPort = port
}
+ e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
return id, e.bindToDevice, err
@@ -1269,22 +1370,47 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Get the header then trim it from the view.
- hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() {
+ hdr := header.UDP(pkt.TransportHeader)
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
- pkt.Data.TrimFront(header.UDPMinimumSize)
+ // Never receive from a multicast address.
+ if header.IsV4MulticastAddress(id.RemoteAddress) ||
+ header.IsV6MulticastAddress(id.RemoteAddress) {
+ e.stack.Stats().UDP.InvalidSourceAddress.Increment()
+ e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ return
+ }
+
+ // Verify checksum unless RX checksum offload is enabled.
+ // On IPv4, the UDP checksum is optional, and a zero value means
+ // the transmitter omitted checksum generation (RFC 768).
+ // On IPv6, the UDP checksum is not optional (RFC 2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
+ (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ if hdr.CalculateChecksum(xsum) != 0xffff {
+ // Checksum Error.
+ e.stack.Stats().UDP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
+ return
+ }
+ }
- e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
e.stats.PacketsReceived.Increment()
+ e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
if !e.rcvReady || e.rcvClosed {
e.rcvMu.Unlock()
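Receive and transmit now apply the same checksum policy. The transmit-side decision, condensed from sendUDP above using its local variables (not new API):

	// Offload skips the computation; otherwise IPv6 always checksums,
	// and IPv4 checksums only when noChecksum is unset (a zero field
	// is legal per RFC 768).
	wantChecksum := r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 &&
		(!noChecksum || r.NetProto == header.IPv6ProtocolNumber)
	if wantChecksum {
		xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
		for _, v := range data.Views() {
			xsum = header.Checksum(v, xsum)
		}
		udp.SetChecksum(^udp.CalculateChecksum(xsum))
	}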
@@ -1325,7 +1451,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
}
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
@@ -1336,7 +1462,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
e.mu.RLock()
defer e.mu.RUnlock()
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a674ceb68..c67e0ba95 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
route: r,
@@ -61,7 +61,7 @@ type ForwarderRequest struct {
stack *stack.Stack
route *stack.Route
id stack.TransportEndpointID
- pkt stack.PacketBuffer
+ pkt *stack.PacketBuffer
}
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
@@ -73,7 +73,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
@@ -82,6 +82,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.route = r.route.Clone()
ep.dstPort = r.id.RemotePort
ep.RegisterNICID = r.route.NICID()
+ ep.boundPortFlags = ep.portFlags
ep.state = StateConnected
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 52af6de22..0e7464e3a 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -32,9 +32,24 @@ import (
const (
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
+
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4 KiB.
+
+ // DefaultSendBufferSize is the default size of the send buffer for
+ // an endpoint.
+ DefaultSendBufferSize = 32 << 10 // 32 KiB.
+
+ // DefaultReceiveBufferSize is the default size of the receive buffer
+ // for an endpoint.
+ DefaultReceiveBufferSize = 32 << 10 // 32 KiB.
+
+ // MaxBufferSize is the largest size a receive/send buffer can grow to.
+ MaxBufferSize = 4 << 20 // 4 MiB.
)
-type protocol struct{}
+type protocol struct {
+}
// Number returns the udp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
@@ -66,15 +81,9 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
- // Get the header then trim it from the view.
- h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok {
- // Malformed packet.
- r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
- return true
- }
- if int(header.UDP(h).Length()) > pkt.Data.Size() {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ hdr := header.UDP(pkt.TransportHeader)
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
return true
@@ -121,7 +130,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
available := int(mtu) - headerLen
- payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
+ payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
@@ -130,9 +139,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
// For example, a raw or packet socket may use what UDP
// considers an unreachable destination. Thus we deep copy pkt
// to prevent multiple ownership and SR errors.
- newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...)
- payload := newNetHeader.ToVectorisedView()
- payload.Append(pkt.Data.ToView().ToVectorisedView())
+ newHeader := append(buffer.View(nil), pkt.NetworkHeader...)
+ newHeader = append(newHeader, pkt.TransportHeader...)
+ payload := newHeader.ToVectorisedView()
+ payload.AppendView(pkt.Data.ToView())
payload.CapLength(payloadLen)
hdr := buffer.NewPrependable(headerLen)
@@ -140,9 +150,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv4DstUnreachable)
pkt.SetCode(header.ICMPv4PortUnreachable)
pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload,
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ TransportHeader: buffer.View(pkt),
+ Data: payload,
})
case header.IPv6AddressSize:
@@ -164,11 +175,11 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
available := int(mtu) - headerLen
- payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
+ payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
- payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader})
+ payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader})
payload.Append(pkt.Data)
payload.CapLength(payloadLen)
@@ -177,21 +188,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv6DstUnreachable)
pkt.SetCode(header.ICMPv6PortUnreachable)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload,
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+ Header: hdr,
+ TransportHeader: buffer.View(pkt),
+ Data: payload,
})
}
return true
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
@@ -201,6 +213,18 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
+ if !ok {
+ // Packet is too small.
+ return false
+ }
+ pkt.TransportHeader = h
+ pkt.Data.TrimFront(header.UDPMinimumSize)
+ return true
+}
+
// NewProtocol returns a UDP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 8acaa607a..66e8911c8 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -83,16 +83,18 @@ type header4Tuple struct {
type testFlow int
const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ reverseMulticast4 // V4 multicast src. Must fail.
+ reverseMulticast6 // V6 multicast src. Must fail.
)
func (flow testFlow) String() string {
@@ -117,6 +119,10 @@ func (flow testFlow) String() string {
return "broadcast"
case broadcastIn6:
return "broadcastIn6"
+ case reverseMulticast4:
+ return "reverseMulticast4"
+ case reverseMulticast6:
+ return "reverseMulticast6"
default:
return "unknown"
}
@@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
h.dstAddr.Addr = multicastV6Addr
}
}
+ if flow.isReverseMulticast() {
+ h.srcAddr.Addr = flow.getMcastAddr()
+ }
return h
}
@@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
// endpoint for this flow.
func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast:
+ case unicastV4, multicastV4, broadcast, reverseMulticast4:
return ipv4.ProtocolNumber
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool {
switch flow {
case unicastV6Only, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool {
switch flow {
case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool {
switch flow {
case broadcast, broadcastIn6:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool {
switch flow {
case unicastV4in6, multicastV4in6, broadcastIn6:
return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
}
}
+func (flow testFlow) isReverseMulticast() bool {
+ switch flow {
+ case reverseMulticast4, reverseMulticast6:
+ return true
+ default:
+ return false
+ }
+}
+
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -292,15 +310,15 @@ func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Optio
wep = sniffer.New(ep)
}
if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatalf("CreateNIC failed: %s", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -391,17 +409,21 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
h := flow.header4Tuple(incoming)
if flow.isV4() {
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
} else {
- c.injectV6Packet(payload, &h, true /* valid */)
+ buf := c.buildV6Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
}
-// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV6Packet creates a V6 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -420,16 +442,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
- l := uint16(header.UDPMinimumSize + len(payload))
- if !valid {
- // Change the UDP payload length to corrupt the header
- // as requested by the caller.
- l++
- }
u.Encode(&header.UDPFields{
SrcPort: h.srcAddr.Port,
DstPort: h.dstAddr.Port,
- Length: l,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
@@ -439,19 +455,12 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
- })
+ return buf
}
-// injectV4Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV4Packet creates a V4 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -485,13 +494,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
-
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
- })
+ return buf
}
func newPayload() []byte {
@@ -513,7 +516,7 @@ func TestBindToDeviceOption(t *testing.T) {
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
@@ -647,7 +650,7 @@ func TestBindEphemeralPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}
@@ -658,19 +661,19 @@ func TestBindReservedPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
addr, err := c.ep.GetLocalAddress()
if err != nil {
- t.Fatalf("GetLocalAddress failed: %v", err)
+ t.Fatalf("GetLocalAddress failed: %s", err)
}
// We can't bind the address reserved by the connected endpoint above.
{
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
@@ -681,7 +684,7 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// We can't bind ipv4-any on the port reserved by the connected endpoint
@@ -691,7 +694,7 @@ func TestBindReservedPort(t *testing.T) {
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
@@ -701,11 +704,11 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
}
@@ -718,7 +721,7 @@ func TestV4ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -733,7 +736,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
// Bind to v4 mapped wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -748,7 +751,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -763,7 +766,7 @@ func TestV6ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -800,7 +803,10 @@ func TestV4ReadSelfSource(t *testing.T) {
h := unicastV4.header4Tuple(incoming)
h.srcAddr = h.dstAddr
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
@@ -821,7 +827,7 @@ func TestV4ReadOnV4(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -884,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
}
}
+// TestReadFromMulticast checks that an endpoint will NOT receive a packet
+// that was sent with a multicast SOURCE address.
+func TestReadFromMulticast(t *testing.T) {
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ testFailingRead(c, flow, false /* expectReadError */)
+ })
+ }
+}
+
+// TestReadFromMulticastStats checks that a discarded packet that was sent
+// with a multicast SOURCE address increments the correct counters and that
+// a regular packet does not.
+func TestReadFromMulticastStats(t *testing.T) {
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ c.injectPacket(flow, payload)
+
+ var want uint64 = 0
+ if flow.isReverseMulticast() {
+ want = 1
+ }
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want {
+ t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want {
+ t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
// and receive broadcast and unicast data.
func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
@@ -959,7 +1019,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
payload := buffer.View(newPayload())
n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ c.t.Fatalf("Write failed: %s", err)
}
if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
@@ -1009,7 +1069,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
p := testDualWrite(c)
@@ -1026,7 +1086,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV6)
@@ -1047,7 +1107,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV4in6)
@@ -1074,7 +1134,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Write to v6 address.
@@ -1089,7 +1149,7 @@ func TestV6WriteOnConnected(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV6)
@@ -1103,7 +1163,7 @@ func TestV4WriteOnConnected(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV4)
@@ -1238,7 +1298,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testRead(c, unicastV4)
@@ -1263,6 +1323,30 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
}
}
+func TestNoChecksum(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Disable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ // This option is effective on IPv4 only.
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
+
+ // Enable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
+ })
+ }
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1510,12 +1594,12 @@ func TestMulticastInterfaceOption(t *testing.T) {
Port: stackPort,
}
if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
}
if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ c.t.Fatalf("SetSockOpt failed: %s", err)
}
// Verify multicast interface addr and NIC were set correctly.
@@ -1523,7 +1607,7 @@ func TestMulticastInterfaceOption(t *testing.T) {
ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
var ifoptGot tcpip.MulticastInterfaceOption
if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
+ c.t.Fatalf("GetSockOpt failed: %s", err)
}
if ifoptGot != ifoptWant {
c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
@@ -1695,7 +1779,7 @@ func TestV6UnknownDestination(t *testing.T) {
}
// TestIncrementMalformedPacketsReceived verifies if the malformed received
-// global and endpoint stats get incremented.
+// global and endpoint stats are incremented.
func TestIncrementMalformedPacketsReceived(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1703,20 +1787,271 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
payload := newPayload()
- c.t.Helper()
h := unicastV6.header4Tuple(incoming)
- c.injectV6Packet(payload, &h, false /* !valid */)
+ buf := c.buildV6Packet(payload, &h)
- var want uint64 = 1
+ // Invalidate the UDP header length field.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetLength(u.Length() + 1)
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
}
if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+}
+
+// TestShortHeader verifies that when a packet with a too-short UDP header is
+// received, the malformed received global stat gets incremented.
+func TestShortHeader(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ h := unicastV6.header4Tuple(incoming)
+
+ // Allocate a buffer for an IPv6 and too-short UDP header.
+ const udpSize = header.UDPMinimumSize - 1
+ buf := buffer.NewView(header.IPv6MinimumSize + udpSize)
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
+ PayloadLength: uint16(udpSize),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
+ })
+
+ // Initialize the UDP header.
+ udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize))
+ udpHdr.Encode(&header.UDPFields{
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
+ Length: header.UDPMinimumSize,
+ })
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr)))
+ udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
+ // Copy all but the last byte of the UDP header into the packet.
+ copy(buf[header.IPv6MinimumSize:], udpHdr)
+
+ // Inject packet.
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want {
+ t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want)
+ }
+}
+
+// TestIncrementChecksumErrorsV4 verifies that, if a checksum error is
+// detected, global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+
+ // Invalidate the UDP header checksum field, taking care to avoid
+ // overflow to zero, which would disable checksum validation.
+ for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
+ u.SetChecksum(u.Checksum() + 1)
+ if u.Checksum() != 0 {
+ break
+ }
+ }
+
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestIncrementChecksumErrorsV6 verifies that, if a checksum error is
+// detected, global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+
+ // Invalidate the UDP header checksum field.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(u.Checksum() + 1)
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV4 verifies that, if the payload is modified so the
+// checksum no longer matches, global and endpoint stats are incremented.
+func TestPayloadModifiedV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV6 verifies that, if the payload is modified so the
+// checksum no longer matches, global and endpoint stats are incremented.
+func TestPayloadModifiedV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV4 verifies that if the checksum value is zero, global and
+// endpoint stats are *not* incremented (the UDP checksum is optional on IPv4).
+func TestChecksumZeroV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 0
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV6 verifies that if the checksum value is zero, global and
+// endpoint stats are incremented (the UDP checksum is *not* optional on IPv6).
+func TestChecksumZeroV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
}
}
@@ -1730,15 +2065,15 @@ func TestShutdownRead(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingRead(c, unicastV6, true /* expectReadError */)
@@ -1761,11 +2096,11 @@ func TestShutdownWrite(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
@@ -1807,3 +2142,192 @@ func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEn
c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
}
}
+
+func TestOutgoingSubnetBroadcast(t *testing.T) {
+ const nicID1 = 1
+
+ ipv4Addr := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 24,
+ }
+ ipv4Subnet := ipv4Addr.Subnet()
+ ipv4SubnetBcast := ipv4Subnet.Broadcast()
+ ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 31,
+ }
+ ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
+ ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
+ ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 32,
+ }
+ ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
+ ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
+ ipv6Addr := tcpip.AddressWithPrefix{
+ Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ PrefixLen: 64,
+ }
+ ipv6Subnet := ipv6Addr.Subnet()
+ ipv6SubnetBcast := ipv6Subnet.Broadcast()
+ remNetAddr := tcpip.AddressWithPrefix{
+ Address: "\x64\x0a\x7b\x18",
+ PrefixLen: 24,
+ }
+ remNetSubnet := remNetAddr.Subnet()
+ remNetSubnetBcast := remNetSubnet.Broadcast()
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ requiresBroadcastOpt bool
+ }{
+ {
+ name: "IPv4 Broadcast to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4SubnetBcast,
+ requiresBroadcastOpt: true,
+ },
+ {
+ name: "IPv4 Broadcast to local /31 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix31,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet31,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet31Bcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to local /32 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix32,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet32,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet32Bcast,
+ requiresBroadcastOpt: false,
+ },
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 'Broadcast' to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv6Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv6SubnetBcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to remote subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: remNetSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ requiresBroadcastOpt: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ }
+
+ s.SetRouteTable(test.routes)
+
+ var netProto tcpip.NetworkProtocolNumber
+ switch l := len(test.remoteAddr); l {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ netProto = header.IPv6ProtocolNumber
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err)
+ }
+ defer ep.Close()
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ to := tcpip.FullAddress{
+ Addr: test.remoteAddr,
+ Port: 80,
+ }
+ opts := tcpip.WriteOptions{To: &to}
+ expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled
+ if !test.requiresBroadcastOpt {
+ expectedErrWithoutBcastOpt = nil
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != nil {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+ })
+ }
+}
diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go
index bebebb48e..70945f234 100644
--- a/pkg/test/criutil/criutil.go
+++ b/pkg/test/criutil/criutil.go
@@ -22,6 +22,9 @@ import (
"fmt"
"os"
"os/exec"
+ "path"
+ "regexp"
+ "strconv"
"strings"
"time"
@@ -33,28 +36,44 @@ import (
type Crictl struct {
logger testutil.Logger
endpoint string
+ runpArgs []string
cleanup []func()
}
-// resolvePath attempts to find binary paths. It may set the path to invalid,
+// ResolvePath attempts to find binary paths. It may set the path to invalid,
// which will cause the execution to fail with a sensible error.
-func resolvePath(executable string) string {
+func ResolvePath(executable string) string {
+ runtime, err := dockerutil.RuntimePath()
+ if err == nil {
+ // Check first the directory of the runtime itself.
+ // First, check the directory of the runtime itself.
+ guess := path.Join(dir, executable)
+ if fi, err := os.Stat(guess); err == nil && (fi.Mode()&0111) != 0 {
+ return guess
+ }
+ }
+ }
+
+ // Try to find via the path.
guess, err := exec.LookPath(executable)
- if err != nil {
- guess = fmt.Sprintf("/usr/local/bin/%s", executable)
+ if err == nil {
+ return guess
}
- return guess
+
+ // Return a default path.
+ return fmt.Sprintf("/usr/local/bin/%s", executable)
}
// NewCrictl returns a Crictl configured with a timeout and an endpoint over
// which it will talk to containerd.
-func NewCrictl(logger testutil.Logger, endpoint string) *Crictl {
+func NewCrictl(logger testutil.Logger, endpoint string, runpArgs []string) *Crictl {
// Attempt to find the executable, but don't bother propagating the
// error at this point. The first command executed will return with a
// binary not found error.
return &Crictl{
logger: logger,
endpoint: endpoint,
+ runpArgs: runpArgs,
}
}
@@ -67,8 +86,8 @@ func (cc *Crictl) CleanUp() {
}
// RunPod creates a sandbox. It corresponds to `crictl runp`.
-func (cc *Crictl) RunPod(sbSpecFile string) (string, error) {
- podID, err := cc.run("runp", sbSpecFile)
+func (cc *Crictl) RunPod(runtime, sbSpecFile string) (string, error) {
+ podID, err := cc.run("runp", "--runtime", runtime, sbSpecFile)
if err != nil {
return "", fmt.Errorf("runp failed: %v", err)
}
@@ -79,10 +98,42 @@ func (cc *Crictl) RunPod(sbSpecFile string) (string, error) {
// Create creates a container within a sandbox. It corresponds to `crictl
// create`.
func (cc *Crictl) Create(podID, contSpecFile, sbSpecFile string) (string, error) {
- podID, err := cc.run("create", podID, contSpecFile, sbSpecFile)
+ // In version 1.16.0, crictl annoyingly started attempting to pull the
+ // container image, even if it was already available locally. We therefore
+ // need to parse the version and add an appropriate --no-pull argument,
+ // since the image has already been loaded locally.
+ out, err := cc.run("-v")
+ if err != nil {
+ return "", err
+ }
+ r := regexp.MustCompile("crictl version ([0-9]+)\\.([0-9]+)\\.([0-9]+)")
+ vs := r.FindStringSubmatch(out)
+ if len(vs) != 4 {
+ return "", fmt.Errorf("crictl -v had unexpected output: %s", out)
+ }
+ major, err := strconv.ParseUint(vs[1], 10, 64)
+ if err != nil {
+ return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out)
+ }
+ minor, err := strconv.ParseUint(vs[2], 10, 64)
if err != nil {
+ return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out)
+ }
+
+ args := []string{"create"}
+ if (major == 1 && minor >= 16) || major > 1 {
+ args = append(args, "--no-pull")
+ }
+ args = append(args, podID)
+ args = append(args, contSpecFile)
+ args = append(args, sbSpecFile)
+
+ podID, err = cc.run(args...)
+ if err != nil {
return "", fmt.Errorf("create failed: %v", err)
}
+
// Strip the trailing newline from crictl output.
return strings.TrimSpace(podID), nil
}
@@ -113,6 +164,17 @@ func (cc *Crictl) Exec(contID string, args ...string) (string, error) {
return output, nil
}
+// Logs retrieves the container logs. It corresponds to `crictl logs`.
+func (cc *Crictl) Logs(contID string, args ...string) (string, error) {
+ a := []string{"logs", contID}
+ a = append(a, args...)
+ output, err := cc.run(a...)
+ if err != nil {
+ return "", fmt.Errorf("logs failed: %v", err)
+ }
+ return output, nil
+}
+
// Rm removes a container. It corresponds to `crictl rm`.
func (cc *Crictl) Rm(contID string) error {
_, err := cc.run("rm", contID)
@@ -168,7 +230,7 @@ func (cc *Crictl) Import(image string) error {
// be pushing a lot of bytes in order to import the image. The connect
// timeout stays the same and is inherited from the Crictl instance.
cmd := testutil.Command(cc.logger,
- resolvePath("ctr"),
+ ResolvePath("ctr"),
fmt.Sprintf("--connect-timeout=%s", 30*time.Second),
fmt.Sprintf("--address=%s", cc.endpoint),
"-n", "k8s.io", "images", "import", "-")
@@ -249,7 +311,7 @@ func (cc *Crictl) StopContainer(contID string) error {
// StartPodAndContainer starts a sandbox and container in that sandbox. It
// returns the pod ID and container ID.
-func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) {
+func (cc *Crictl) StartPodAndContainer(runtime, image, sbSpec, contSpec string) (string, string, error) {
if err := cc.Import(image); err != nil {
return "", "", err
}
@@ -266,7 +328,7 @@ func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string,
}
cc.cleanup = append(cc.cleanup, cleanup)
- podID, err := cc.RunPod(sbSpecFile)
+ podID, err := cc.RunPod(runtime, sbSpecFile)
if err != nil {
return "", "", err
}
@@ -296,7 +358,7 @@ func (cc *Crictl) StopPodAndContainer(podID, contID string) error {
// run runs crictl with the given args.
func (cc *Crictl) run(args ...string) (string, error) {
defaultArgs := []string{
- resolvePath("crictl"),
+ ResolvePath("crictl"),
"--image-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
"--runtime-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
}
diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD
index 7c8758e35..a5e84658a 100644
--- a/pkg/test/dockerutil/BUILD
+++ b/pkg/test/dockerutil/BUILD
@@ -1,14 +1,42 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "dockerutil",
testonly = 1,
- srcs = ["dockerutil.go"],
+ srcs = [
+ "container.go",
+ "dockerutil.go",
+ "exec.go",
+ "network.go",
+ "profile.go",
+ ],
visibility = ["//:sandbox"],
deps = [
"//pkg/test/testutil",
- "@com_github_kr_pty//:go_default_library",
+ "@com_github_docker_docker//api/types:go_default_library",
+ "@com_github_docker_docker//api/types/container:go_default_library",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
+ "@com_github_docker_docker//api/types/network:go_default_library",
+ "@com_github_docker_docker//client:go_default_library",
+ "@com_github_docker_docker//pkg/stdcopy:go_default_library",
+ "@com_github_docker_go_connections//nat:go_default_library",
+ ],
+)
+
+go_test(
+ name = "profile_test",
+ size = "large",
+ srcs = [
+ "profile_test.go",
+ ],
+ library = ":dockerutil",
+ tags = [
+ # Requires docker and runsc to be configured before test runs.
+ # Also requires the test to be run as root.
+ "manual",
+ "local",
],
+ visibility = ["//:sandbox"],
)
diff --git a/pkg/test/dockerutil/README.md b/pkg/test/dockerutil/README.md
new file mode 100644
index 000000000..870292096
--- /dev/null
+++ b/pkg/test/dockerutil/README.md
@@ -0,0 +1,86 @@
+# dockerutil
+
+This package is for creating and controlling docker containers for testing
+runsc, gVisor's OCI runtime for Docker and Kubernetes. A simple test may look
+like:
+
+```
+ func TestSuperCool(t *testing.T) {
+ ctx := context.Background()
+ c := dockerutil.MakeContainer(ctx, t)
+ got, err := c.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "echo", "super cool")
+ if err != nil {
+ t.Fatalf("err was not nil: %v", err)
+ }
+ want := "super cool"
+ if !strings.Contains(got, want) {
+ t.Fatalf("want: %s, got: %s", want, got)
+ }
+ }
+```
+
+For further examples, see many of our end to end tests elsewhere in the repo,
+such as those in //test/e2e or benchmarks at //test/benchmarks.
+
+dockerutil uses the official Docker Go API, which is
+[very powerful](https://godoc.org/github.com/docker/docker/client). dockerutil
+is a thin wrapper around this API, so new use cases are easy to implement, as
+in the sketch below.
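+
+For example, a hypothetical `Kill` helper might look like the following. This
+is a sketch only: it assumes a method added inside the dockerutil package,
+where the unexported `client` and `id` fields are visible, and it uses
+`ContainerKill` from the official Docker client API:
+
+```
+// Kill is a hypothetical extension method, analogous to 'docker kill'.
+func (c *Container) Kill(ctx context.Context) error {
+    return c.client.ContainerKill(ctx, c.id, "SIGKILL")
+}
+```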
+
+## Profiling
+
+dockerutil is capable of generating profiles. Currently, the only option is to
+use pprof profiles generated by `runsc debug`. The profiler will generate Block,
+CPU, Heap, Goroutine, and Mutex profiles. To generate profiles:
+
+* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc
+  ARGS="--profile"`. Add other flags with ARGS as needed, e.g.
+  `--platform=kvm` or `--vfs2`.
+* Restart docker: `sudo service docker restart`
+
+To run tests and generate CPU profiles, run:
+
+```
+make sudo TARGETS=//path/to:target \
+ ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt"
+```
+
+Profiles will be written to: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof`
+
+The container name in most gVisor tests and benchmarks is the test name plus
+some random characters, like so:
+`BenchmarkABSL-CleanCache-JF2J2ZYF3U7SL47QAA727CSJI3C4ZAW2`
+
+Profiling requires root, as `runsc debug` inspects running containers in
+/var/run, among other things.
+
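+Generated profiles are standard Go pprof files; once a run completes, they can
+be inspected with the usual tooling, e.g.:
+
+```
+go tool pprof /tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof
+```
+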
+### Writing for Profiling
+
+The below shows an example of using profiles with dockerutil.
+
+```
+func TestSuperCool(t *testing.T) {
+    ctx := context.Background()
+    // Profiled, using the runtime from the dockerutil.runtime flag.
+    profiled := dockerutil.MakeContainer(ctx, t)
+
+    // Not profiled, using the native runc runtime.
+    native := dockerutil.MakeNativeContainer(ctx, t)
+
+    if err := profiled.Spawn(ctx, dockerutil.RunOpts{
+        Image: "some/image",
+    }, "sleep", "100000"); err != nil {
+        t.Fatalf("Spawn failed: %v", err)
+    }
+    // Profiling has begun here.
+
+    // ... expensive setup that we don't want to profile ...
+
+    profiled.RestartProfiles()
+    // Profiled activity goes here.
+}
+```
+
+In the above example, `profiled` would be profiled and `native` would not. The
+call to `RestartProfiles()` restarts the clock on profiling. This is useful if
+the main activity being tested is done with `docker exec` or `container.Spawn()`
+followed by one or more `container.Exec()` calls, as in the sketch below.
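+
+A minimal sketch of that pattern, assuming the `Exec`/`ExecOpts` API from this
+package's exec.go and a hypothetical workload path:
+
+```
+// Profiling starts when the container starts.
+if err := profiled.Spawn(ctx, dockerutil.RunOpts{
+    Image: "some/image",
+}, "sleep", "100000"); err != nil {
+    t.Fatalf("Spawn failed: %v", err)
+}
+
+// Setup that we don't want profiled goes here.
+
+// Restart the profiling clock, then run the measured workload.
+if err := profiled.RestartProfiles(); err != nil {
+    t.Fatalf("RestartProfiles failed: %v", err)
+}
+if _, err := profiled.Exec(ctx, dockerutil.ExecOpts{}, "/path/to/workload"); err != nil {
+    t.Fatalf("Exec failed: %v", err)
+}
+```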
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
new file mode 100644
index 000000000..5a2157951
--- /dev/null
+++ b/pkg/test/dockerutil/container.go
@@ -0,0 +1,558 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/api/types/container"
+ "github.com/docker/docker/api/types/mount"
+ "github.com/docker/docker/api/types/network"
+ "github.com/docker/docker/client"
+ "github.com/docker/docker/pkg/stdcopy"
+ "github.com/docker/go-connections/nat"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Container represents a Docker container, allowing the user to configure
+// and control it as one would with the 'docker' client. Container is backed
+// by the official Docker Go API.
+// See: https://pkg.go.dev/github.com/docker/docker.
+type Container struct {
+ Name string
+ runtime string
+
+ logger testutil.Logger
+ client *client.Client
+ id string
+ mounts []mount.Mount
+ links []string
+ copyErr error
+ cleanups []func()
+
+ // Profiles are profiles added to this container. They contain methods
+ // that are run after Creation, Start, and Cleanup of this Container, along
+ // with a handle to restart the profile. Generally, tests/benchmarks using
+ // profiles need to run as root.
+ profiles []Profile
+
+ // Stores streams attached to the container. Used by WaitForOutputSubmatch.
+ streams types.HijackedResponse
+
+ // streamBuf stores previously read data from the attached streams.
+ streamBuf bytes.Buffer
+}
+
+// RunOpts are options for running a container.
+type RunOpts struct {
+ // Image is the image relative to images/. This will be mangled
+ // appropriately, to ensure that only first-party images are used.
+ Image string
+
+ // Memory is the memory limit in bytes.
+ Memory int
+
+ // CpusetCpus is the set of CPUs on which to allow execution ("0", "1", "0-2").
+ CpusetCpus string
+
+ // Ports are the ports to be allocated.
+ Ports []int
+
+ // WorkDir sets the working directory.
+ WorkDir string
+
+ // ReadOnly sets the read-only flag.
+ ReadOnly bool
+
+ // Env are additional environment variables.
+ Env []string
+
+ // User is the user to use.
+ User string
+
+ // Privileged enables privileged mode.
+ Privileged bool
+
+ // CapAdd are the extra set of capabilities to add.
+ CapAdd []string
+
+ // CapDrop are the extra set of capabilities to drop.
+ CapDrop []string
+
+ // Mounts is the list of directories/files to be mounted inside the container.
+ Mounts []mount.Mount
+
+ // Links is the list of containers to be connected to the container.
+ Links []string
+}
+
+// MakeContainer sets up the struct for a Docker container.
+//
+// Names of containers will be unique.
+// Containers will check flags for profiling requests.
+func MakeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ c := MakeNativeContainer(ctx, logger)
+ c.runtime = *runtime
+ if p := MakePprofFromFlags(c); p != nil {
+ c.AddProfile(p)
+ }
+ return c
+}
+
+// MakeNativeContainer sets up the struct for a Docker container using the
+// native runc runtime. Native containers aren't profiled.
+func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ // Slashes are not allowed in container names.
+ name := testutil.RandomID(logger.Name())
+ name = strings.ReplaceAll(name, "/", "-")
+ client, err := client.NewClientWithOpts(client.FromEnv)
+ if err != nil {
+ return nil
+ }
+ client.NegotiateAPIVersion(ctx)
+ return &Container{
+ logger: logger,
+ Name: name,
+ runtime: "",
+ client: client,
+ }
+}
+
+// AddProfile adds a profile to this container.
+func (c *Container) AddProfile(p Profile) {
+ c.profiles = append(c.profiles, p)
+}
+
+// RestartProfiles calls Restart on all profiles for this container.
+func (c *Container) RestartProfiles() error {
+ for _, profile := range c.profiles {
+ if err := profile.Restart(c); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Spawn is analogous to 'docker run -d'.
+func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error {
+ if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil {
+ return err
+ }
+ return c.Start(ctx)
+}
+
+// SpawnProcess is analogous to 'docker run -it'. It returns a Process
+// representing the container's root process.
+func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) (Process, error) {
+ config, hostconf, netconf := c.ConfigsFrom(r, args...)
+ config.Tty = true
+ config.OpenStdin = true
+
+ if err := c.CreateFrom(ctx, config, hostconf, netconf); err != nil {
+ return Process{}, err
+ }
+
+ if err := c.Start(ctx); err != nil {
+ return Process{}, err
+ }
+
+ return Process{container: c, conn: c.streams}, nil
+}
+
+// Run is analogous to 'docker run'.
+func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) {
+ if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil {
+ return "", err
+ }
+
+ if err := c.Start(ctx); err != nil {
+ return "", err
+ }
+
+ if err := c.Wait(ctx); err != nil {
+ return "", err
+ }
+
+ return c.Logs(ctx)
+}
+
+// ConfigsFrom returns container configs from RunOpts and args. The caller should call 'CreateFrom'
+// and Start.
+func (c *Container) ConfigsFrom(r RunOpts, args ...string) (*container.Config, *container.HostConfig, *network.NetworkingConfig) {
+ return c.config(r, args), c.hostConfig(r), &network.NetworkingConfig{}
+}
+
+// MakeLink formats a link to add to a RunOpts.
+func (c *Container) MakeLink(target string) string {
+ return fmt.Sprintf("%s:%s", c.Name, target)
+}
+
+// CreateFrom creates a container from the given configs.
+func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
+ return c.create(ctx, conf, hostconf, netconf)
+}
+
+// Create is analogous to 'docker create'.
+func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error {
+ return c.create(ctx, c.config(r, args), c.hostConfig(r), nil)
+}
+
+func (c *Container) create(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
+ cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name)
+ if err != nil {
+ return err
+ }
+ c.id = cont.ID
+ for _, profile := range c.profiles {
+ if err := profile.OnCreate(c); err != nil {
+ return fmt.Errorf("OnCreate method failed with: %v", err)
+ }
+ }
+ return nil
+}
+
+func (c *Container) config(r RunOpts, args []string) *container.Config {
+ ports := nat.PortSet{}
+ for _, p := range r.Ports {
+ port := nat.Port(fmt.Sprintf("%d", p))
+ ports[port] = struct{}{}
+ }
+ env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name))
+
+ return &container.Config{
+ Image: testutil.ImageByName(r.Image),
+ Cmd: args,
+ ExposedPorts: ports,
+ Env: env,
+ WorkingDir: r.WorkDir,
+ User: r.User,
+ }
+}
+
+func (c *Container) hostConfig(r RunOpts) *container.HostConfig {
+ c.mounts = append(c.mounts, r.Mounts...)
+
+ return &container.HostConfig{
+ Runtime: c.runtime,
+ Mounts: c.mounts,
+ PublishAllPorts: true,
+ Links: r.Links,
+ CapAdd: r.CapAdd,
+ CapDrop: r.CapDrop,
+ Privileged: r.Privileged,
+ ReadonlyRootfs: r.ReadOnly,
+ Resources: container.Resources{
+ Memory: int64(r.Memory), // In bytes.
+ CpusetCpus: r.CpusetCpus,
+ },
+ }
+}
+
+// Start is analogous to 'docker start'.
+func (c *Container) Start(ctx context.Context) error {
+ // Open a connection to the container for parsing logs and for TTY.
+ streams, err := c.client.ContainerAttach(ctx, c.id,
+ types.ContainerAttachOptions{
+ Stream: true,
+ Stdin: true,
+ Stdout: true,
+ Stderr: true,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to connect to container: %v", err)
+ }
+
+ c.streams = streams
+ c.cleanups = append(c.cleanups, func() {
+ c.streams.Close()
+ })
+ if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil {
+ return fmt.Errorf("ContainerStart failed: %v", err)
+ }
+ for _, profile := range c.profiles {
+ if err := profile.OnStart(c); err != nil {
+ return fmt.Errorf("OnStart method failed: %v", err)
+ }
+ }
+ return nil
+}
+
+// Stop is analogous to 'docker stop'.
+func (c *Container) Stop(ctx context.Context) error {
+ return c.client.ContainerStop(ctx, c.id, nil)
+}
+
+// Pause is analogous to 'docker pause'.
+func (c *Container) Pause(ctx context.Context) error {
+ return c.client.ContainerPause(ctx, c.id)
+}
+
+// Unpause is analogous to 'docker unpause'.
+func (c *Container) Unpause(ctx context.Context) error {
+ return c.client.ContainerUnpause(ctx, c.id)
+}
+
+// Checkpoint is analogous to 'docker checkpoint'.
+func (c *Container) Checkpoint(ctx context.Context, name string) error {
+ return c.client.CheckpointCreate(ctx, c.Name, types.CheckpointCreateOptions{CheckpointID: name, Exit: true})
+}
+
+// Restore is analogous to 'docker start --checkpoint [name]'.
+func (c *Container) Restore(ctx context.Context, name string) error {
+ return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{CheckpointID: name})
+}
+
+// Logs is analogous to 'docker logs'.
+func (c *Container) Logs(ctx context.Context) (string, error) {
+ var out bytes.Buffer
+ err := c.logs(ctx, &out, &out)
+ return out.String(), err
+}
+
+func (c *Container) logs(ctx context.Context, stdout, stderr *bytes.Buffer) error {
+ opts := types.ContainerLogsOptions{ShowStdout: true, ShowStderr: true}
+ writer, err := c.client.ContainerLogs(ctx, c.id, opts)
+ if err != nil {
+ return err
+ }
+ defer writer.Close()
+ _, err = stdcopy.StdCopy(stdout, stderr, writer)
+
+ return err
+}
+
+// ID returns the container id.
+func (c *Container) ID() string {
+ return c.id
+}
+
+// SandboxPid returns the PID of the container's sandbox process.
+func (c *Container) SandboxPid(ctx context.Context) (int, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return -1, err
+ }
+ return resp.ContainerJSONBase.State.Pid, nil
+}
+
+// FindIP returns the IP address of the container.
+func (c *Container) FindIP(ctx context.Context, ipv6 bool) (net.IP, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return nil, err
+ }
+
+ var addr string
+ if ipv6 {
+ addr = resp.NetworkSettings.DefaultNetworkSettings.GlobalIPv6Address
+ } else {
+ addr = resp.NetworkSettings.DefaultNetworkSettings.IPAddress
+ }
+ ip := net.ParseIP(addr)
+ if ip == nil {
+ return net.IP{}, fmt.Errorf("invalid IP: %q", addr)
+ }
+ return ip, nil
+}
+
+// FindPort returns the host port that is mapped to 'sandboxPort'.
+func (c *Container) FindPort(ctx context.Context, sandboxPort int) (int, error) {
+ desc, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return -1, fmt.Errorf("error retrieving port: %v", err)
+ }
+
+ format := fmt.Sprintf("%d/tcp", sandboxPort)
+ ports, ok := desc.NetworkSettings.Ports[nat.Port(format)]
+ if !ok {
+ return -1, fmt.Errorf("port %d/tcp is not mapped", sandboxPort)
+ }
+
+ port, err := strconv.Atoi(ports[0].HostPort)
+ if err != nil {
+ return -1, fmt.Errorf("error parsing port %q: %v", ports[0].HostPort, err)
+ }
+ return port, nil
+}
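+
+// A sketch of dialing a published port, assuming RunOpts.Ports included 8080:
+//
+//    hostPort, err := c.FindPort(ctx, 8080)
+//    if err != nil {
+//        t.Fatalf("FindPort failed: %v", err)
+//    }
+//    conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", hostPort))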
+
+// CopyFiles copies the given files into a temporary directory and bind-mounts
+// it at target. Only the copies are exposed to the container, so the sources
+// are never modified.
+func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) {
+ dir, err := ioutil.TempDir("", c.Name)
+ if err != nil {
+ c.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err)
+ return
+ }
+ c.cleanups = append(c.cleanups, func() { os.RemoveAll(dir) })
+ if err := os.Chmod(dir, 0755); err != nil {
+ c.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err)
+ return
+ }
+ for _, name := range sources {
+ src, err := testutil.FindFile(name)
+ if err != nil {
+ c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err)
+ return
+ }
+ dst := path.Join(dir, path.Base(name))
+ if err := testutil.Copy(src, dst); err != nil {
+ c.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
+ return
+ }
+ c.logger.Logf("copy: %s -> %s", src, dst)
+ }
+ opts.Mounts = append(opts.Mounts, mount.Mount{
+ Type: mount.TypeBind,
+ Source: dir,
+ Target: target,
+ ReadOnly: false,
+ })
+}
+
+// Status inspects the container and returns its status.
+func (c *Container) Status(ctx context.Context) (types.ContainerState, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return types.ContainerState{}, err
+ }
+ return *resp.State, err
+}
+
+// Wait waits for the container to exit.
+func (c *Container) Wait(ctx context.Context) error {
+ statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning)
+ select {
+ case err := <-errChan:
+ return err
+ case <-statusChan:
+ return nil
+ }
+}
+
+// WaitTimeout waits for the container to exit with a timeout.
+func (c *Container) WaitTimeout(ctx context.Context, timeout time.Duration) error {
+ ctx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+ statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning)
+ select {
+ case <-ctx.Done():
+ if ctx.Err() == context.DeadlineExceeded {
+ return fmt.Errorf("container %s timed out after %v seconds", c.Name, timeout.Seconds())
+ }
+ return nil
+ case err := <-errChan:
+ return err
+ case <-statusChan:
+ return nil
+ }
+}
+
+// WaitForOutput searches the container's output for the given pattern and
+// returns the first match, or times out.
+func (c *Container) WaitForOutput(ctx context.Context, pattern string, timeout time.Duration) (string, error) {
+ matches, err := c.WaitForOutputSubmatch(ctx, pattern, timeout)
+ if err != nil {
+ return "", err
+ }
+ if len(matches) == 0 {
+ return "", fmt.Errorf("didn't find pattern %s logs", pattern)
+ }
+ return matches[0], nil
+}
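+
+// For example, waiting for a hypothetical server banner:
+//
+//    if _, err := c.WaitForOutput(ctx, "listening on port", 30*time.Second); err != nil {
+//        t.Fatalf("server never became ready: %v", err)
+//    }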
+
+// WaitForOutputSubmatch searches container logs for the given
+// pattern or times out. It returns any regexp submatches as well.
+func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, timeout time.Duration) ([]string, error) {
+ re := regexp.MustCompile(pattern)
+ if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil {
+ return matches, nil
+ }
+
+ for exp := time.Now().Add(timeout); time.Now().Before(exp); {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+
+ c.streams.Conn.SetDeadline(time.Now().Add(50 * time.Millisecond))
+ _, err := stdcopy.StdCopy(&c.streamBuf, &c.streamBuf, c.streams.Reader)
+
+ if err != nil {
+ // Tolerate read deadline timeouts; they just mean no new output yet.
+ if nerr, ok := err.(net.Error); !ok || !nerr.Timeout() {
+ return nil, err
+ }
+ }
+
+ if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil {
+ return matches, nil
+ }
+ }
+
+ return nil, fmt.Errorf("timeout waiting for output %q: out: %s", re.String(), c.streamBuf.String())
+}
+
+// Kill kills the container.
+func (c *Container) Kill(ctx context.Context) error {
+ return c.client.ContainerKill(ctx, c.id, "")
+}
+
+// Remove is analogous to 'docker rm'.
+func (c *Container) Remove(ctx context.Context) error {
+ // Remove the container along with its volumes and links.
+ remove := types.ContainerRemoveOptions{
+ RemoveVolumes: c.mounts != nil,
+ RemoveLinks: c.links != nil,
+ Force: true,
+ }
+ return c.client.ContainerRemove(ctx, c.Name, remove)
+}
+
+// CleanUp kills and deletes the container (best effort).
+func (c *Container) CleanUp(ctx context.Context) {
+ // Execute profile cleanups before the container goes down.
+ for _, profile := range c.profiles {
+ profile.OnCleanUp(c)
+ }
+ // Forget profiles.
+ c.profiles = nil
+ // Kill the container.
+ if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") {
+ // Just log; can't do anything here.
+ c.logger.Logf("error killing container %q: %v", c.Name, err)
+ }
+ // Remove the container.
+ if err := c.Remove(ctx); err != nil {
+ c.logger.Logf("error removing container %q: %v", c.Name, err)
+ }
+ // Forget all mounts.
+ c.mounts = nil
+ // Execute all cleanups.
+ for _, cleanup := range c.cleanups {
+ cleanup()
+ }
+ c.cleanups = nil
+}
diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go
index 06f81d28d..5a9dd8bd8 100644
--- a/pkg/test/dockerutil/dockerutil.go
+++ b/pkg/test/dockerutil/dockerutil.go
@@ -22,17 +22,11 @@ import (
"io"
"io/ioutil"
"log"
- "net"
- "os"
"os/exec"
- "path"
"regexp"
"strconv"
- "strings"
- "syscall"
"time"
- "github.com/kr/pty"
"gvisor.dev/gvisor/pkg/test/testutil"
)
@@ -49,6 +43,26 @@ var (
// config is the default Docker daemon configuration path.
config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths")
+
+ // The following flags are for the "pprof" profiler tool.
+
+ // pprofBaseDir allows the user to change the directory to which profiles are
+ // written. By default, profiles will appear under:
+ // /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof.
+ pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory for profiles: BASEDIR/RUNTIME/CONTAINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)")
+
+ // duration is the max duration `runsc debug` will run and capture profiles.
+ // If the container's clean up method is called prior to duration, the
+ // profiling process will be killed.
+ duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run each profile (e.g. 10s or 1m)")
+
+ // The below flags enable each type of profile. Multiple profiles can be
+ // enabled for each run.
+ pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug")
+ pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug")
+ pprofGo = flag.Bool("pprof-go", false, "enables goroutine profiling with runsc debug")
+ pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug")
+ pprofMutex = flag.Bool("pprof-mutex", false, "enables mutex profiling with runsc debug")
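+
+ // For example, a hypothetical invocation collecting CPU and heap profiles:
+ //
+ //   go test --runtime=runsc --pprof-cpu --pprof-heap --pprof-duration=15s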
)
// EnsureSupportedDockerVersion checks if correct docker is installed.
@@ -127,462 +141,7 @@ func Save(logger testutil.Logger, image string, w io.Writer) error {
return cmd.Run()
}
-// MountMode describes if the mount should be ro or rw.
-type MountMode int
-
-const (
- // ReadOnly is what the name says.
- ReadOnly MountMode = iota
- // ReadWrite is what the name says.
- ReadWrite
-)
-
-// String returns the mount mode argument for this MountMode.
-func (m MountMode) String() string {
- switch m {
- case ReadOnly:
- return "ro"
- case ReadWrite:
- return "rw"
- }
- panic(fmt.Sprintf("invalid mode: %d", m))
-}
-
-// Docker contains the name and the runtime of a docker container.
-type Docker struct {
- logger testutil.Logger
- Runtime string
- Name string
- copyErr error
- mounts []string
- cleanups []func()
-}
-
-// MakeDocker sets up the struct for a Docker container.
-//
-// Names of containers will be unique.
-func MakeDocker(logger testutil.Logger) *Docker {
- // Slashes are not allowed in container names.
- name := testutil.RandomID(logger.Name())
- name = strings.ReplaceAll(name, "/", "-")
-
- return &Docker{
- logger: logger,
- Name: name,
- Runtime: *runtime,
- }
-}
-
-// Mount mounts the given source and makes it available in the container.
-func (d *Docker) Mount(target, source string, mode MountMode) {
- d.mounts = append(d.mounts, fmt.Sprintf("-v=%s:%s:%v", source, target, mode))
-}
-
-// CopyFiles copies in and mounts the given files. They are always ReadOnly.
-func (d *Docker) CopyFiles(target string, sources ...string) {
- dir, err := ioutil.TempDir("", d.Name)
- if err != nil {
- d.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err)
- return
- }
- d.cleanups = append(d.cleanups, func() { os.RemoveAll(dir) })
- if err := os.Chmod(dir, 0755); err != nil {
- d.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err)
- return
- }
- for _, name := range sources {
- src, err := testutil.FindFile(name)
- if err != nil {
- d.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err)
- return
- }
- dst := path.Join(dir, path.Base(name))
- if err := testutil.Copy(src, dst); err != nil {
- d.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
- return
- }
- d.logger.Logf("copy: %s -> %s", src, dst)
- }
- d.Mount(target, dir, ReadOnly)
-}
-
-// Link links the given target.
-func (d *Docker) Link(target string, source *Docker) {
- d.mounts = append(d.mounts, fmt.Sprintf("--link=%s:%s", source.Name, target))
-}
-
-// RunOpts are options for running a container.
-type RunOpts struct {
- // Image is the image relative to images/. This will be mangled
- // appropriately, to ensure that only first-party images are used.
- Image string
-
- // Memory is the memory limit in kB.
- Memory int
-
- // Ports are the ports to be allocated.
- Ports []int
-
- // WorkDir sets the working directory.
- WorkDir string
-
- // ReadOnly sets the read-only flag.
- ReadOnly bool
-
- // Env are additional environment variables.
- Env []string
-
- // User is the user to use.
- User string
-
- // Privileged enables privileged mode.
- Privileged bool
-
- // CapAdd are the extra set of capabilities to add.
- CapAdd []string
-
- // CapDrop are the extra set of capabilities to drop.
- CapDrop []string
-
- // Pty indicates that a pty will be allocated. If this is non-nil, then
- // this will run after start-up with the *exec.Command and Pty file
- // passed in to the function.
- Pty func(*exec.Cmd, *os.File)
-
- // Foreground indicates that the container should be run in the
- // foreground. If this is true, then the output will be available as a
- // return value from the Run function.
- Foreground bool
-
- // Extra are extra arguments that may be passed.
- Extra []string
-}
-
-// args returns common arguments.
-//
-// Note that this does not define the complete behavior.
-func (d *Docker) argsFor(r *RunOpts, command string, p []string) (rv []string) {
- isExec := command == "exec"
- isRun := command == "run"
-
- if isRun || isExec {
- rv = append(rv, "-i")
- }
- if r.Pty != nil {
- rv = append(rv, "-t")
- }
- if r.User != "" {
- rv = append(rv, fmt.Sprintf("--user=%s", r.User))
- }
- if r.Privileged {
- rv = append(rv, "--privileged")
- }
- for _, c := range r.CapAdd {
- rv = append(rv, fmt.Sprintf("--cap-add=%s", c))
- }
- for _, c := range r.CapDrop {
- rv = append(rv, fmt.Sprintf("--cap-drop=%s", c))
- }
- for _, e := range r.Env {
- rv = append(rv, fmt.Sprintf("--env=%s", e))
- }
- if r.WorkDir != "" {
- rv = append(rv, fmt.Sprintf("--workdir=%s", r.WorkDir))
- }
- if !isExec {
- if r.Memory != 0 {
- rv = append(rv, fmt.Sprintf("--memory=%dk", r.Memory))
- }
- for _, p := range r.Ports {
- rv = append(rv, fmt.Sprintf("--publish=%d", p))
- }
- if r.ReadOnly {
- rv = append(rv, fmt.Sprintf("--read-only"))
- }
- if len(p) > 0 {
- rv = append(rv, "--entrypoint=")
- }
- }
-
- // Always attach the test environment & Extra.
- rv = append(rv, fmt.Sprintf("--env=RUNSC_TEST_NAME=%s", d.Name))
- rv = append(rv, r.Extra...)
-
- // Attach necessary bits.
- if isExec {
- rv = append(rv, d.Name)
- } else {
- rv = append(rv, d.mounts...)
- rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime))
- rv = append(rv, fmt.Sprintf("--name=%s", d.Name))
- rv = append(rv, testutil.ImageByName(r.Image))
- }
-
- // Attach other arguments.
- rv = append(rv, p...)
- return rv
-}
-
-// run runs a complete command.
-func (d *Docker) run(r RunOpts, command string, p ...string) (string, error) {
- if d.copyErr != nil {
- return "", d.copyErr
- }
- basicArgs := []string{"docker"}
- if command == "spawn" {
- command = "run"
- basicArgs = append(basicArgs, command)
- basicArgs = append(basicArgs, "-d")
- } else {
- basicArgs = append(basicArgs, command)
- }
- customArgs := d.argsFor(&r, command, p)
- cmd := testutil.Command(d.logger, append(basicArgs, customArgs...)...)
- if r.Pty != nil {
- // If allocating a terminal, then we just ignore the output
- // from the command.
- ptmx, err := pty.Start(cmd.Cmd)
- if err != nil {
- return "", err
- }
- defer cmd.Wait() // Best effort.
- r.Pty(cmd.Cmd, ptmx)
- } else {
- // Can't support PTY or streaming.
- out, err := cmd.CombinedOutput()
- return string(out), err
- }
- return "", nil
-}
-
-// Create calls 'docker create' with the arguments provided.
-func (d *Docker) Create(r RunOpts, args ...string) error {
- out, err := d.run(r, "create", args...)
- if strings.Contains(out, "Unable to find image") {
- return fmt.Errorf("unable to find image, did you remember to `make load-%s`: %w", r.Image, err)
- }
- return err
-}
-
-// Start calls 'docker start'.
-func (d *Docker) Start() error {
- return testutil.Command(d.logger, "docker", "start", d.Name).Run()
-}
-
-// Stop calls 'docker stop'.
-func (d *Docker) Stop() error {
- return testutil.Command(d.logger, "docker", "stop", d.Name).Run()
-}
-
-// Run calls 'docker run' with the arguments provided.
-func (d *Docker) Run(r RunOpts, args ...string) (string, error) {
- return d.run(r, "run", args...)
-}
-
-// Spawn starts the container and detaches.
-func (d *Docker) Spawn(r RunOpts, args ...string) error {
- _, err := d.run(r, "spawn", args...)
- return err
-}
-
-// Logs calls 'docker logs'.
-func (d *Docker) Logs() (string, error) {
- // Don't capture the output; since it will swamp the logs.
- out, err := exec.Command("docker", "logs", d.Name).CombinedOutput()
- return string(out), err
-}
-
-// Exec calls 'docker exec' with the arguments provided.
-func (d *Docker) Exec(r RunOpts, args ...string) (string, error) {
- return d.run(r, "exec", args...)
-}
-
-// Pause calls 'docker pause'.
-func (d *Docker) Pause() error {
- return testutil.Command(d.logger, "docker", "pause", d.Name).Run()
-}
-
-// Unpause calls 'docker pause'.
-func (d *Docker) Unpause() error {
- return testutil.Command(d.logger, "docker", "unpause", d.Name).Run()
-}
-
-// Checkpoint calls 'docker checkpoint'.
-func (d *Docker) Checkpoint(name string) error {
- return testutil.Command(d.logger, "docker", "checkpoint", "create", d.Name, name).Run()
-}
-
-// Restore calls 'docker start --checkname [name]'.
-func (d *Docker) Restore(name string) error {
- return testutil.Command(d.logger, "docker", "start", fmt.Sprintf("--checkpoint=%s", name), d.Name).Run()
-}
-
-// Kill calls 'docker kill'.
-func (d *Docker) Kill() error {
- // Skip logging this command, it will likely be an error.
- out, err := exec.Command("docker", "kill", d.Name).CombinedOutput()
- if err != nil && !strings.Contains(string(out), "is not running") {
- return err
- }
- return nil
-}
-
-// Remove calls 'docker rm'.
-func (d *Docker) Remove() error {
- return testutil.Command(d.logger, "docker", "rm", d.Name).Run()
-}
-
-// CleanUp kills and deletes the container (best effort).
-func (d *Docker) CleanUp() {
- // Kill the container.
- if err := d.Kill(); err != nil {
- // Just log; can't do anything here.
- d.logger.Logf("error killing container %q: %v", d.Name, err)
- }
- // Remove the image.
- if err := d.Remove(); err != nil {
- d.logger.Logf("error removing container %q: %v", d.Name, err)
- }
- // Forget all mounts.
- d.mounts = nil
- // Execute all cleanups.
- for _, c := range d.cleanups {
- c()
- }
- d.cleanups = nil
-}
-
-// FindPort returns the host port that is mapped to 'sandboxPort'. This calls
-// docker to allocate a free port in the host and prevent conflicts.
-func (d *Docker) FindPort(sandboxPort int) (int, error) {
- format := fmt.Sprintf(`{{ (index (index .NetworkSettings.Ports "%d/tcp") 0).HostPort }}`, sandboxPort)
- out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput()
- if err != nil {
- return -1, fmt.Errorf("error retrieving port: %v", err)
- }
- port, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
- if err != nil {
- return -1, fmt.Errorf("error parsing port %q: %v", out, err)
- }
- return port, nil
-}
-
-// FindIP returns the IP address of the container.
-func (d *Docker) FindIP() (net.IP, error) {
- const format = `{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}`
- out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput()
- if err != nil {
- return net.IP{}, fmt.Errorf("error retrieving IP: %v", err)
- }
- ip := net.ParseIP(strings.TrimSpace(string(out)))
- if ip == nil {
- return net.IP{}, fmt.Errorf("invalid IP: %q", string(out))
- }
- return ip, nil
-}
-
-// SandboxPid returns the PID to the sandbox process.
-func (d *Docker) SandboxPid() (int, error) {
- out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.State.Pid}}", d.Name).CombinedOutput()
- if err != nil {
- return -1, fmt.Errorf("error retrieving pid: %v", err)
- }
- pid, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
- if err != nil {
- return -1, fmt.Errorf("error parsing pid %q: %v", out, err)
- }
- return pid, nil
-}
-
-// ID returns the container ID.
-func (d *Docker) ID() (string, error) {
- out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.Id}}", d.Name).CombinedOutput()
- if err != nil {
- return "", fmt.Errorf("error retrieving ID: %v", err)
- }
- return strings.TrimSpace(string(out)), nil
-}
-
-// Wait waits for container to exit, up to the given timeout. Returns error if
-// wait fails or timeout is hit. Returns the application return code otherwise.
-// Note that the application may have failed even if err == nil, always check
-// the exit code.
-func (d *Docker) Wait(timeout time.Duration) (syscall.WaitStatus, error) {
- timeoutChan := time.After(timeout)
- waitChan := make(chan (syscall.WaitStatus))
- errChan := make(chan (error))
-
- go func() {
- out, err := testutil.Command(d.logger, "docker", "wait", d.Name).CombinedOutput()
- if err != nil {
- errChan <- fmt.Errorf("error waiting for container %q: %v", d.Name, err)
- }
- exit, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
- if err != nil {
- errChan <- fmt.Errorf("error parsing exit code %q: %v", out, err)
- }
- waitChan <- syscall.WaitStatus(uint32(exit))
- }()
-
- select {
- case ws := <-waitChan:
- return ws, nil
- case err := <-errChan:
- return syscall.WaitStatus(1), err
- case <-timeoutChan:
- return syscall.WaitStatus(1), fmt.Errorf("timeout waiting for container %q", d.Name)
- }
-}
-
-// WaitForOutput calls 'docker logs' to retrieve containers output and searches
-// for the given pattern.
-func (d *Docker) WaitForOutput(pattern string, timeout time.Duration) (string, error) {
- matches, err := d.WaitForOutputSubmatch(pattern, timeout)
- if err != nil {
- return "", err
- }
- if len(matches) == 0 {
- return "", nil
- }
- return matches[0], nil
-}
-
-// WaitForOutputSubmatch calls 'docker logs' to retrieve containers output and
-// searches for the given pattern. It returns any regexp submatches as well.
-func (d *Docker) WaitForOutputSubmatch(pattern string, timeout time.Duration) ([]string, error) {
- re := regexp.MustCompile(pattern)
- var (
- lastOut string
- stopped bool
- )
- for exp := time.Now().Add(timeout); time.Now().Before(exp); {
- out, err := d.Logs()
- if err != nil {
- return nil, err
- }
- if out != lastOut {
- if lastOut == "" {
- d.logger.Logf("output (start): %s", out)
- } else if strings.HasPrefix(out, lastOut) {
- d.logger.Logf("output (contn): %s", out[len(lastOut):])
- } else {
- d.logger.Logf("output (trunc): %s", out)
- }
- lastOut = out // Save for future.
- if matches := re.FindStringSubmatch(lastOut); matches != nil {
- return matches, nil // Success!
- }
- } else if stopped {
- // The sandbox stopped and we looked at the
- // logs at least once since determining that.
- return nil, fmt.Errorf("no longer running: %v", err)
- } else if pid, err := d.SandboxPid(); pid == 0 || err != nil {
- // The sandbox may have stopped, but it's
- // possible that it has emitted the terminal
- // line between the last call to Logs and here.
- stopped = true
- }
- time.Sleep(100 * time.Millisecond)
- }
- return nil, fmt.Errorf("timeout waiting for output %q: %s", re.String(), lastOut)
+// Runtime returns the value of the runtime flag.
+func Runtime() string {
+ return *runtime
}
diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go
new file mode 100644
index 000000000..4c739c9e9
--- /dev/null
+++ b/pkg/test/dockerutil/exec.go
@@ -0,0 +1,193 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/pkg/stdcopy"
+)
+
+// ExecOpts holds arguments for Exec calls.
+type ExecOpts struct {
+ // Env are additional environment variables.
+ Env []string
+
+ // Privileged enables privileged mode.
+ Privileged bool
+
+ // User is the user to use.
+ User string
+
+ // UseTTY enables a TTY and attaches stdin to the created process.
+ UseTTY bool
+
+ // WorkDir is the working directory of the process.
+ WorkDir string
+}
+
+// Exec creates a process inside the container.
+func (c *Container) Exec(ctx context.Context, opts ExecOpts, args ...string) (string, error) {
+ p, err := c.doExec(ctx, opts, args)
+ if err != nil {
+ return "", err
+ }
+
+ if exitStatus, err := p.WaitExitStatus(ctx); err != nil {
+ return "", err
+ } else if exitStatus != 0 {
+ out, _ := p.Logs()
+ return out, fmt.Errorf("process terminated with status: %d", exitStatus)
+ }
+
+ return p.Logs()
+}
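+
+// A minimal sketch of usage against a running container (error handling
+// elided; the command is illustrative):
+//
+//    out, _ := c.Exec(ctx, ExecOpts{User: "root"}, "cat", "/proc/self/status")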
+
+// ExecProcess creates a process inside the container and returns a process struct
+// for the caller to use.
+func (c *Container) ExecProcess(ctx context.Context, opts ExecOpts, args ...string) (Process, error) {
+ return c.doExec(ctx, opts, args)
+}
+
+func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Process, error) {
+ config := c.execConfig(r, args)
+ resp, err := c.client.ContainerExecCreate(ctx, c.id, config)
+ if err != nil {
+ return Process{}, fmt.Errorf("exec create failed with err: %v", err)
+ }
+
+ hijack, err := c.client.ContainerExecAttach(ctx, resp.ID, types.ExecStartCheck{})
+ if err != nil {
+ return Process{}, fmt.Errorf("exec attach failed with err: %v", err)
+ }
+
+ if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil {
+ hijack.Close()
+ return Process{}, fmt.Errorf("exec start failed with err: %v", err)
+ }
+
+ return Process{
+ container: c,
+ execid: resp.ID,
+ conn: hijack,
+ }, nil
+}
+
+func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig {
+ env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name))
+ return types.ExecConfig{
+ AttachStdin: r.UseTTY,
+ AttachStderr: true,
+ AttachStdout: true,
+ Cmd: cmd,
+ Privileged: r.Privileged,
+ WorkingDir: r.WorkDir,
+ Env: env,
+ Tty: r.UseTTY,
+ User: r.User,
+ }
+}
+
+// Process represents a containerized process.
+type Process struct {
+ container *Container
+ execid string
+ conn types.HijackedResponse
+}
+
+// Write writes buf to the process's stdin.
+func (p *Process) Write(timeout time.Duration, buf []byte) (int, error) {
+ p.conn.Conn.SetDeadline(time.Now().Add(timeout))
+ return p.conn.Conn.Write(buf)
+}
+
+// Read returns process's stdout and stderr.
+func (p *Process) Read() (string, string, error) {
+ var stdout, stderr bytes.Buffer
+ if err := p.read(&stdout, &stderr); err != nil {
+ return "", "", err
+ }
+ return stdout.String(), stderr.String(), nil
+}
+
+// Logs returns combined stdout/stderr from the process.
+func (p *Process) Logs() (string, error) {
+ var out bytes.Buffer
+ if err := p.read(&out, &out); err != nil {
+ return "", err
+ }
+ return out.String(), nil
+}
+
+func (p *Process) read(stdout, stderr *bytes.Buffer) error {
+ _, err := stdcopy.StdCopy(stdout, stderr, p.conn.Reader)
+ return err
+}
+
+// ExitCode returns the process's exit code.
+func (p *Process) ExitCode(ctx context.Context) (int, error) {
+ _, exitCode, err := p.runningExitCode(ctx)
+ return exitCode, err
+}
+
+// IsRunning checks if the process is running.
+func (p *Process) IsRunning(ctx context.Context) (bool, error) {
+ running, _, err := p.runningExitCode(ctx)
+ return running, err
+}
+
+// WaitExitStatus waits until the process completes and returns its exit status.
+func (p *Process) WaitExitStatus(ctx context.Context) (int, error) {
+ waitChan := make(chan (int))
+ errChan := make(chan (error))
+
+ go func() {
+ for {
+ running, exitcode, err := p.runningExitCode(ctx)
+ if err != nil {
+ errChan <- fmt.Errorf("error waiting for process %s in container %s: %v", p.execid, p.container.Name, err)
+ }
+ if !running {
+ waitChan <- exitcode
+ }
+ time.Sleep(time.Millisecond * 500)
+ }
+ }()
+
+ select {
+ case ws := <-waitChan:
+ return ws, nil
+ case err := <-errChan:
+ return -1, err
+ }
+}
+
+// runningExitCode collects if the process is running and the exit code.
+// The exit code is only valid if the process has exited.
+func (p *Process) runningExitCode(ctx context.Context) (bool, int, error) {
+ // If execid is not empty, this is an exec'd process.
+ if p.execid != "" {
+ status, err := p.container.client.ContainerExecInspect(ctx, p.execid)
+ return status.Running, status.ExitCode, err
+ }
+ // Otherwise, this is the root process.
+ status, err := p.container.Status(ctx)
+ return status.Running, status.ExitCode, err
+}
diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go
new file mode 100644
index 000000000..047091e75
--- /dev/null
+++ b/pkg/test/dockerutil/network.go
@@ -0,0 +1,113 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "net"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/api/types/network"
+ "github.com/docker/docker/client"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Network is a docker network.
+type Network struct {
+ client *client.Client
+ id string
+ logger testutil.Logger
+ Name string
+ containers []*Container
+ Subnet *net.IPNet
+}
+
+// NewNetwork sets up the struct for a Docker network. Names of networks
+// will be unique.
+func NewNetwork(ctx context.Context, logger testutil.Logger) *Network {
+ client, err := client.NewClientWithOpts(client.FromEnv)
+ if err != nil {
+ logger.Logf("create client failed with: %v", err)
+ return nil
+ }
+ client.NegotiateAPIVersion(ctx)
+
+ return &Network{
+ logger: logger,
+ Name: testutil.RandomID(logger.Name()),
+ client: client,
+ }
+}
+
+func (n *Network) networkCreate() types.NetworkCreate {
+ var subnet string
+ if n.Subnet != nil {
+ subnet = n.Subnet.String()
+ }
+
+ ipam := network.IPAM{
+ Config: []network.IPAMConfig{{
+ Subnet: subnet,
+ }},
+ }
+
+ return types.NetworkCreate{
+ CheckDuplicate: true,
+ IPAM: &ipam,
+ }
+}
+
+// Create is analogous to 'docker network create'.
+func (n *Network) Create(ctx context.Context) error {
+ opts := n.networkCreate()
+ resp, err := n.client.NetworkCreate(ctx, n.Name, opts)
+ if err != nil {
+ return err
+ }
+ n.id = resp.ID
+ return nil
+}
+
+// Connect is analogous to 'docker network connect' with the arguments provided.
+func (n *Network) Connect(ctx context.Context, container *Container, ipv4, ipv6 string) error {
+ settings := network.EndpointSettings{
+ IPAMConfig: &network.EndpointIPAMConfig{
+ IPv4Address: ipv4,
+ IPv6Address: ipv6,
+ },
+ }
+ err := n.client.NetworkConnect(ctx, n.id, container.id, &settings)
+ if err == nil {
+ n.containers = append(n.containers, container)
+ }
+ return err
+}
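+
+// A sketch of typical usage, assuming a testutil.Logger and a running
+// Container c; empty addresses let Docker assign them:
+//
+//    n := NewNetwork(ctx, logger)
+//    if err := n.Create(ctx); err != nil {
+//        t.Fatalf("network create failed: %v", err)
+//    }
+//    defer n.Cleanup(ctx)
+//    if err := n.Connect(ctx, c, "", ""); err != nil {
+//        t.Fatalf("connect failed: %v", err)
+//    }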
+
+// Inspect returns this network's info.
+func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) {
+ return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true})
+}
+
+// Cleanup cleans up the docker network and all the containers attached to it.
+func (n *Network) Cleanup(ctx context.Context) error {
+ for _, c := range n.containers {
+ c.CleanUp(ctx)
+ }
+ n.containers = nil
+
+ return n.client.NetworkRemove(ctx, n.id)
+}
diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go
new file mode 100644
index 000000000..1fab33083
--- /dev/null
+++ b/pkg/test/dockerutil/profile.go
@@ -0,0 +1,152 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "time"
+)
+
+// Profile represents profile-like operations on a container,
+// such as running perf or pprof. It is meant to be added to containers
+// so that the container calls into the Profile at each stage of its lifecycle.
+type Profile interface {
+ // OnCreate is called just after the container is created when the container
+ // has a valid ID (e.g. c.ID()).
+ OnCreate(c *Container) error
+
+ // OnStart is called just after the container is started when the container
+ // has a valid Pid (e.g. c.SandboxPid()).
+ OnStart(c *Container) error
+
+ // Restart restarts the Profile on request.
+ Restart(c *Container) error
+
+ // OnCleanUp is called during the container's cleanup method.
+ // Cleanups should just log errors if they have them.
+ OnCleanUp(c *Container) error
+}
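+
+// A minimal no-op implementation, sketched to show the contract:
+//
+//    type nopProfile struct{}
+//
+//    func (nopProfile) OnCreate(*Container) error  { return nil }
+//    func (nopProfile) OnStart(*Container) error   { return nil }
+//    func (nopProfile) Restart(*Container) error   { return nil }
+//    func (nopProfile) OnCleanUp(*Container) error { return nil }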
+
+// Pprof is for running profiles with 'runsc debug'. Pprof workloads
+// should be run as root and ONLY against runsc sandboxes. The runtime
+// should have --profile set as an option in /etc/docker/daemon.json in
+// order for profiling to work with Pprof.
+type Pprof struct {
+ BasePath string // path to put profiles
+ BlockProfile bool
+ CPUProfile bool
+ GoRoutineProfile bool
+ HeapProfile bool
+ MutexProfile bool
+ Duration time.Duration // duration to run profiler e.g. '10s' or '1m'.
+ shouldRun bool
+ cmd *exec.Cmd
+ stdout io.ReadCloser
+ stderr io.ReadCloser
+}
+
+// MakePprofFromFlags makes a Pprof profile from flags.
+func MakePprofFromFlags(c *Container) *Pprof {
+ if !(*pprofBlock || *pprofCPU || *pprofGo || *pprofHeap || *pprofMutex) {
+ return nil
+ }
+ return &Pprof{
+ BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name),
+ BlockProfile: *pprofBlock,
+ CPUProfile: *pprofCPU,
+ GoRoutineProfile: *pprofGo,
+ HeapProfile: *pprofHeap,
+ MutexProfile: *pprofMutex,
+ Duration: *duration,
+ }
+}
+
+// OnCreate implements Profile.OnCreate.
+func (p *Pprof) OnCreate(c *Container) error {
+ return os.MkdirAll(p.BasePath, 0755)
+}
+
+// OnStart implements Profile.OnStart.
+func (p *Pprof) OnStart(c *Container) error {
+ path, err := RuntimePath()
+ if err != nil {
+ return fmt.Errorf("failed to get runtime path: %v", err)
+ }
+
+ // The root directory of this container's runtime.
+ root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime)
+ // Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`.
+ args := []string{root, "debug"}
+ args = append(args, p.makeProfileArgs(c)...)
+ args = append(args, c.ID())
+
+ // Best effort wait until container is running.
+ for now := time.Now(); time.Since(now) < 5*time.Second; {
+ if status, err := c.Status(context.Background()); err != nil {
+ return fmt.Errorf("failed to get status with: %v", err)
+
+ } else if status.Running {
+ break
+ }
+ time.Sleep(500 * time.Millisecond)
+ }
+ p.cmd = exec.Command(path, args...)
+ if err := p.cmd.Start(); err != nil {
+ return fmt.Errorf("process failed: %v", err)
+ }
+ return nil
+}
+
+// Restart implements Profile.Restart.
+func (p *Pprof) Restart(c *Container) error {
+ p.OnCleanUp(c)
+ return p.OnStart(c)
+}
+
+// OnCleanUp implements Profile.OnCleanUp.
+func (p *Pprof) OnCleanUp(c *Container) error {
+ defer func() { p.cmd = nil }()
+ // ProcessState is only non-nil once Wait has returned, so a nil
+ // ProcessState means the profiler may still be running and must be killed.
+ if p.cmd != nil && p.cmd.Process != nil && (p.cmd.ProcessState == nil || !p.cmd.ProcessState.Exited()) {
+ return p.cmd.Process.Kill()
+ }
+ return nil
+}
+
+// makeProfileArgs turns Pprof fields into runsc debug flags.
+func (p *Pprof) makeProfileArgs(c *Container) []string {
+ var ret []string
+ if p.BlockProfile {
+ ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof")))
+ }
+ if p.CPUProfile {
+ ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof")))
+ }
+ if p.GoRoutineProfile {
+ ret = append(ret, fmt.Sprintf("--profile-goroutine=%s", filepath.Join(p.BasePath, "go.pprof")))
+ }
+ if p.HeapProfile {
+ ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof")))
+ }
+ if p.MutexProfile {
+ ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof")))
+ }
+ ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration))
+ return ret
+}
diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go
new file mode 100644
index 000000000..b7b4d7618
--- /dev/null
+++ b/pkg/test/dockerutil/profile_test.go
@@ -0,0 +1,117 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+)
+
+type testCase struct {
+ name string
+ pprof Pprof
+ expectedFiles []string
+}
+
+func TestPprof(t *testing.T) {
+ // Basepath and expected file names for each type of profile.
+ basePath := "/tmp/test/profile"
+ block := "block.pprof"
+ cpu := "cpu.pprof"
+ goprofle := "go.pprof"
+ heap := "heap.pprof"
+ mutex := "mutex.pprof"
+
+ testCases := []testCase{
+ {
+ name: "Cpu",
+ pprof: Pprof{
+ BasePath: basePath,
+ CPUProfile: true,
+ Duration: 2 * time.Second,
+ },
+ expectedFiles: []string{cpu},
+ },
+ {
+ name: "All",
+ pprof: Pprof{
+ BasePath: basePath,
+ BlockProfile: true,
+ CPUProfile: true,
+ GoRoutineProfile: true,
+ HeapProfile: true,
+ MutexProfile: true,
+ Duration: 2 * time.Second,
+ },
+ expectedFiles: []string{block, cpu, goprofile, heap, mutex},
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ ctx := context.Background()
+ c := MakeContainer(ctx, t)
+ // Set basepath to include the container name so there are no conflicts.
+ tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name)
+ c.AddProfile(&tc.pprof)
+
+ func() {
+ defer c.CleanUp(ctx)
+ // Start a container.
+ if err := c.Spawn(ctx, RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "1000"); err != nil {
+ t.Fatalf("run failed with: %v", err)
+ }
+
+ if status, err := c.Status(ctx); !status.Running {
+ t.Fatalf("container is not yet running: %+v err: %v", status, err)
+ }
+
+ // End early if the expected files exist and have data.
+ for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) {
+ if err := checkFiles(tc); err == nil {
+ break
+ }
+ }
+ }()
+
+ // Check all expected files exist and have data.
+ if err := checkFiles(tc); err != nil {
+ t.Fatal(err)
+ }
+ })
+ }
+}
+
+func checkFiles(tc testCase) error {
+ for _, file := range tc.expectedFiles {
+ stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file))
+ if err != nil {
+ return fmt.Errorf("stat failed with: %v", err)
+ } else if stat.Size() < 1 {
+ return fmt.Errorf("file not written to: %+v", stat)
+ }
+ }
+ return nil
+}
+
+func TestMain(m *testing.M) {
+ EnsureSupportedDockerVersion()
+ os.Exit(m.Run())
+}
diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD
index 03b1b4677..2d8f56bc0 100644
--- a/pkg/test/testutil/BUILD
+++ b/pkg/test/testutil/BUILD
@@ -15,6 +15,6 @@ go_library(
"//runsc/boot",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
],
)
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index ee8c78014..64c292698 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -251,7 +251,10 @@ func RandomID(prefix string) string {
if _, err := rand.Read(b); err != nil {
panic("rand.Read failed: " + err.Error())
}
- return fmt.Sprintf("%s-%s", prefix, base32.StdEncoding.EncodeToString(b))
+ if prefix != "" {
+ prefix = prefix + "-"
+ }
+ return fmt.Sprintf("%s%s", prefix, base32.StdEncoding.EncodeToString(b))
}
// RandomContainerID generates a random container id for each test.
@@ -479,6 +482,21 @@ func IsStatic(filename string) (bool, error) {
return true, nil
}
+// TouchShardStatusFile indicates to Bazel that the test runner supports
+// sharding by creating or updating the last modified date of the file
+// specified by TEST_SHARD_STATUS_FILE.
+//
+// See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner.
+func TouchShardStatusFile() error {
+ if statusFile := os.Getenv("TEST_SHARD_STATUS_FILE"); statusFile != "" {
+ cmd := exec.Command("touch", statusFile)
+ if b, err := cmd.CombinedOutput(); err != nil {
+ return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error())
+ }
+ }
+ return nil
+}
+
// TestIndicesForShard returns indices for this test shard based on the
// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars.
//
diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD
deleted file mode 100644
index 2dcba84ae..000000000
--- a/pkg/tmutex/BUILD
+++ /dev/null
@@ -1,17 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tmutex",
- srcs = ["tmutex.go"],
- visibility = ["//:sandbox"],
-)
-
-go_test(
- name = "tmutex_test",
- size = "medium",
- srcs = ["tmutex_test.go"],
- library = ":tmutex",
- deps = ["//pkg/sync"],
-)
diff --git a/pkg/tmutex/tmutex.go b/pkg/tmutex/tmutex.go
deleted file mode 100644
index c4685020d..000000000
--- a/pkg/tmutex/tmutex.go
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package tmutex provides the implementation of a mutex that implements an
-// efficient TryLock function in addition to Lock and Unlock.
-package tmutex
-
-import (
- "sync/atomic"
-)
-
-// Mutex is a mutual exclusion primitive that implements TryLock in addition
-// to Lock and Unlock.
-type Mutex struct {
- v int32
- ch chan struct{}
-}
-
-// Init initializes the mutex.
-func (m *Mutex) Init() {
- m.v = 1
- m.ch = make(chan struct{}, 1)
-}
-
-// Lock acquires the mutex. If it is currently held by another goroutine, Lock
-// will wait until it has a chance to acquire it.
-func (m *Mutex) Lock() {
- // Uncontended case.
- if atomic.AddInt32(&m.v, -1) == 0 {
- return
- }
-
- for {
- // Try to acquire the mutex again, at the same time making sure
- // that m.v is negative, which indicates to the owner of the
- // lock that it is contended, which will force it to try to wake
- // someone up when it releases the mutex.
- if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 {
- return
- }
-
- // Wait for the mutex to be released before trying again.
- <-m.ch
- }
-}
-
-// TryLock attempts to acquire the mutex without blocking. If the mutex is
-// currently held by another goroutine, it fails to acquire it and returns
-// false.
-func (m *Mutex) TryLock() bool {
- v := atomic.LoadInt32(&m.v)
- if v <= 0 {
- return false
- }
- return atomic.CompareAndSwapInt32(&m.v, 1, 0)
-}
-
-// Unlock releases the mutex.
-func (m *Mutex) Unlock() {
- if atomic.SwapInt32(&m.v, 1) == 0 {
- // There were no pending waiters.
- return
- }
-
- // Wake some waiter up.
- select {
- case m.ch <- struct{}{}:
- default:
- }
-}
diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go
deleted file mode 100644
index 05540696a..000000000
--- a/pkg/tmutex/tmutex_test.go
+++ /dev/null
@@ -1,258 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tmutex
-
-import (
- "fmt"
- "runtime"
- "sync/atomic"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/sync"
-)
-
-func TestBasicLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- m.Lock()
-
- // Try blocking lock the mutex from a different goroutine. This must
- // not block because the mutex is held.
- ch := make(chan struct{}, 1)
- go func() {
- m.Lock()
- ch <- struct{}{}
- m.Unlock()
- ch <- struct{}{}
- }()
-
- select {
- case <-ch:
- t.Fatalf("Lock succeeded on locked mutex")
- case <-time.After(100 * time.Millisecond):
- }
-
- // Unlock the mutex and make sure that the goroutine waiting on Lock()
- // unblocks and succeeds.
- m.Unlock()
-
- select {
- case <-ch:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("Lock failed to acquire unlocked mutex")
- }
-
- // Make sure we can lock and unlock again.
- m.Lock()
- m.Unlock()
-}
-
-func TestTryLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Try to lock. It should succeed.
- if !m.TryLock() {
- t.Fatalf("TryLock failed on unlocked mutex")
- }
-
- // Try to lock again, it should now fail.
- if m.TryLock() {
- t.Fatalf("TryLock succeeded on locked mutex")
- }
-
- // Try blocking lock the mutex from a different goroutine. This must
- // not block because the mutex is held.
- ch := make(chan struct{}, 1)
- go func() {
- m.Lock()
- ch <- struct{}{}
- m.Unlock()
- }()
-
- select {
- case <-ch:
- t.Fatalf("Lock succeeded on locked mutex")
- case <-time.After(100 * time.Millisecond):
- }
-
- // Unlock the mutex and make sure that the goroutine waiting on Lock()
- // unblocks and succeeds.
- m.Unlock()
-
- select {
- case <-ch:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("Lock failed to acquire unlocked mutex")
- }
-}
-
-func TestMutualExclusion(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Test mutual exclusion by running "gr" goroutines concurrently, and
- // have each one increment a counter "iters" times within the critical
- // section established by the mutex.
- //
- // If at the end the counter is not gr * iters, then we know that
- // goroutines ran concurrently within the critical section.
- //
- // If one of the goroutines doesn't complete, it's likely a bug that
- // causes to it to wait forever.
- const gr = 1000
- const iters = 100000
- v := 0
- var wg sync.WaitGroup
- for i := 0; i < gr; i++ {
- wg.Add(1)
- go func() {
- for j := 0; j < iters; j++ {
- m.Lock()
- v++
- m.Unlock()
- }
- wg.Done()
- }()
- }
-
- wg.Wait()
-
- if v != gr*iters {
- t.Fatalf("Bad count: got %v, want %v", v, gr*iters)
- }
-}
-
-func TestMutualExclusionWithTryLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Similar to the previous, with the addition of some goroutines that
- // only increment the count if TryLock succeeds.
- const gr = 1000
- const iters = 100000
- total := int64(gr * iters)
- var tryTotal int64
- v := int64(0)
- var wg sync.WaitGroup
- for i := 0; i < gr; i++ {
- wg.Add(2)
- go func() {
- for j := 0; j < iters; j++ {
- m.Lock()
- v++
- m.Unlock()
- }
- wg.Done()
- }()
- go func() {
- local := int64(0)
- for j := 0; j < iters; j++ {
- if m.TryLock() {
- v++
- m.Unlock()
- local++
- }
- }
- atomic.AddInt64(&tryTotal, local)
- wg.Done()
- }()
- }
-
- wg.Wait()
-
- t.Logf("tryTotal = %d", tryTotal)
- total += tryTotal
-
- if v != total {
- t.Fatalf("Bad count: got %v, want %v", v, total)
- }
-}
-
-// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following
-// differences:
-//
-// - The number of goroutines is variable, with the maximum value depending on
-// GOMAXPROCS.
-//
-// - The number of iterations per benchmark is controlled by the benchmarking
-// framework.
-//
-// - Care is taken to ensure that all goroutines participating in the benchmark
-// have been created before the benchmark begins.
-func BenchmarkTmutex(b *testing.B) {
- for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
- b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
- var m Mutex
- m.Init()
-
- var ready sync.WaitGroup
- begin := make(chan struct{})
- var end sync.WaitGroup
- for i := 0; i < n; i++ {
- ready.Add(1)
- end.Add(1)
- go func() {
- ready.Done()
- <-begin
- for j := 0; j < b.N; j++ {
- m.Lock()
- m.Unlock()
- }
- end.Done()
- }()
- }
-
- ready.Wait()
- b.ResetTimer()
- close(begin)
- end.Wait()
- })
- }
-}
-
-// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as
-// a comparison point.
-func BenchmarkSyncMutex(b *testing.B) {
- for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
- b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
- var m sync.Mutex
-
- var ready sync.WaitGroup
- begin := make(chan struct{})
- var end sync.WaitGroup
- for i := 0; i < n; i++ {
- ready.Add(1)
- end.Add(1)
- go func() {
- ready.Done()
- <-begin
- for j := 0; j < b.N; j++ {
- m.Lock()
- m.Unlock()
- }
- end.Done()
- }()
- }
-
- ready.Wait()
- b.ResetTimer()
- close(begin)
- end.Wait()
- })
- }
-}
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
index 707eb085b..67a950444 100644
--- a/pkg/waiter/waiter.go
+++ b/pkg/waiter/waiter.go
@@ -128,13 +128,6 @@ type EntryCallback interface {
//
// +stateify savable
type Entry struct {
- // Context stores any state the waiter may wish to store in the entry
- // itself, which may be used at wake up time.
- //
- // Note that use of this field is optional and state may alternatively be
- // stored in the callback itself.
- Context interface{}
-
Callback EntryCallback
// The following fields are protected by the queue lock.
@@ -142,13 +135,14 @@ type Entry struct {
waiterEntry
}
-type channelCallback struct{}
+type channelCallback struct {
+ ch chan struct{}
+}
// Callback implements EntryCallback.Callback.
-func (*channelCallback) Callback(e *Entry) {
- ch := e.Context.(chan struct{})
+func (c *channelCallback) Callback(*Entry) {
select {
- case ch <- struct{}{}:
+ case c.ch <- struct{}{}:
default:
}
}
@@ -164,7 +158,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) {
c = make(chan struct{}, 1)
}
- return Entry{Context: c, Callback: &channelCallback{}}, c
+ return Entry{Callback: &channelCallback{ch: c}}, c
}
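+
+// A typical pattern, sketched assuming a Queue q and interest in input events:
+//
+//    e, ch := NewChannelEntry(nil)
+//    q.EventRegister(&e, EventIn)
+//    defer q.EventUnregister(&e)
+//    <-ch // Wait for a notification.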
// Queue represents the wait queue where waiters can be added and