summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.travis.yml4
-rw-r--r--BUILD23
-rw-r--r--CONTRIBUTING.md3
-rw-r--r--Dockerfile2
-rw-r--r--WORKSPACE188
-rw-r--r--benchmarks/harness/machine_producers/gcloud_producer.py28
-rw-r--r--benchmarks/runner/__init__.py13
-rw-r--r--benchmarks/runner/commands.py15
-rw-r--r--benchmarks/workloads/ruby/Gemfile.lock22
-rw-r--r--benchmarks/workloads/ruby_template/Gemfile.lock6
-rw-r--r--go.mod33
-rw-r--r--go.sum33
-rw-r--r--kokoro/benchmark_tests.cfg8
-rw-r--r--kokoro/iptables_tests.cfg2
-rw-r--r--kokoro/packetimpact_tests.cfg9
-rw-r--r--pkg/abi/linux/elf.go3
-rw-r--r--pkg/abi/linux/file.go8
-rw-r--r--pkg/abi/linux/netfilter.go109
-rw-r--r--pkg/abi/linux/netfilter_test.go1
-rw-r--r--pkg/bits/bits_template.go8
-rw-r--r--pkg/bits/uint64_test.go18
-rw-r--r--pkg/buffer/BUILD43
-rw-r--r--pkg/buffer/buffer.go94
-rw-r--r--pkg/buffer/safemem.go129
-rw-r--r--pkg/buffer/safemem_test.go170
-rw-r--r--pkg/buffer/view.go390
-rw-r--r--pkg/buffer/view_test.go467
-rw-r--r--pkg/buffer/view_unsafe.go25
-rw-r--r--pkg/context/context.go4
-rw-r--r--pkg/cpuid/cpuid_parse_x86_test.go2
-rw-r--r--pkg/cpuid/cpuid_x86.go9
-rw-r--r--pkg/cpuid/cpuid_x86_test.go2
-rw-r--r--pkg/eventchannel/event_test.go4
-rw-r--r--pkg/gate/gate_test.go3
-rw-r--r--pkg/ilist/list.go54
-rw-r--r--pkg/log/glog.go6
-rw-r--r--pkg/log/json.go2
-rw-r--r--pkg/log/json_k8s.go4
-rw-r--r--pkg/log/log.go2
-rw-r--r--pkg/log/log_test.go6
-rw-r--r--pkg/p9/client.go45
-rw-r--r--pkg/p9/client_test.go7
-rw-r--r--pkg/p9/file.go8
-rw-r--r--pkg/p9/handlers.go41
-rw-r--r--pkg/p9/messages.go14
-rw-r--r--pkg/p9/messages_test.go2
-rw-r--r--pkg/p9/transport_flipcall.go2
-rw-r--r--pkg/rand/rand_linux.go17
-rw-r--r--pkg/safecopy/memcpy_amd64.s111
-rw-r--r--pkg/safemem/seq_test.go21
-rw-r--r--pkg/safemem/seq_unsafe.go3
-rw-r--r--pkg/seccomp/seccomp_test.go2
-rw-r--r--pkg/segment/test/segment_test.go2
-rw-r--r--pkg/sentry/arch/arch.go3
-rw-r--r--pkg/sentry/arch/arch_aarch64.go39
-rw-r--r--pkg/sentry/arch/arch_arm64.go30
-rw-r--r--pkg/sentry/arch/arch_state_x86.go2
-rw-r--r--pkg/sentry/arch/arch_x86.go2
-rw-r--r--pkg/sentry/arch/arch_x86_impl.go2
-rw-r--r--pkg/sentry/arch/signal_arm64.go21
-rw-r--r--pkg/sentry/arch/signal_stack.go2
-rw-r--r--pkg/sentry/arch/stack.go3
-rw-r--r--pkg/sentry/arch/syscalls_amd64.go7
-rw-r--r--pkg/sentry/arch/syscalls_arm64.go23
-rw-r--r--pkg/sentry/contexttest/contexttest.go4
-rw-r--r--pkg/sentry/control/pprof.go34
-rw-r--r--pkg/sentry/control/proc.go6
-rw-r--r--pkg/sentry/fs/copy_up.go28
-rw-r--r--pkg/sentry/fs/dev/BUILD1
-rw-r--r--pkg/sentry/fs/dev/dev.go7
-rw-r--r--pkg/sentry/fs/dev/net_tun.go7
-rw-r--r--pkg/sentry/fs/dirent.go139
-rw-r--r--pkg/sentry/fs/dirent_cache.go2
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_test.go4
-rw-r--r--pkg/sentry/fs/file_overlay_test.go84
-rw-r--r--pkg/sentry/fs/fsutil/inode.go4
-rw-r--r--pkg/sentry/fs/gofer/file_state.go1
-rw-r--r--pkg/sentry/fs/gofer/handles.go1
-rw-r--r--pkg/sentry/fs/gofer/inode.go5
-rw-r--r--pkg/sentry/fs/gofer/inode_state.go1
-rw-r--r--pkg/sentry/fs/gofer/session_state.go1
-rw-r--r--pkg/sentry/fs/gofer/util.go16
-rw-r--r--pkg/sentry/fs/host/BUILD5
-rw-r--r--pkg/sentry/fs/host/control.go4
-rw-r--r--pkg/sentry/fs/host/descriptor.go37
-rw-r--r--pkg/sentry/fs/host/descriptor_state.go2
-rw-r--r--pkg/sentry/fs/host/descriptor_test.go4
-rw-r--r--pkg/sentry/fs/host/file.go14
-rw-r--r--pkg/sentry/fs/host/fs.go339
-rw-r--r--pkg/sentry/fs/host/fs_test.go380
-rw-r--r--pkg/sentry/fs/host/host.go59
-rw-r--r--pkg/sentry/fs/host/inode.go144
-rw-r--r--pkg/sentry/fs/host/inode_state.go32
-rw-r--r--pkg/sentry/fs/host/inode_test.go69
-rw-r--r--pkg/sentry/fs/host/ioctl_unsafe.go4
-rw-r--r--pkg/sentry/fs/host/socket_test.go6
-rw-r--r--pkg/sentry/fs/host/tty.go5
-rw-r--r--pkg/sentry/fs/host/util.go88
-rw-r--r--pkg/sentry/fs/host/util_unsafe.go41
-rw-r--r--pkg/sentry/fs/host/wait_test.go3
-rw-r--r--pkg/sentry/fs/inode.go3
-rw-r--r--pkg/sentry/fs/inode_overlay.go3
-rw-r--r--pkg/sentry/fs/inotify.go5
-rw-r--r--pkg/sentry/fs/mounts.go13
-rw-r--r--pkg/sentry/fs/proc/mounts.go3
-rw-r--r--pkg/sentry/fs/proc/net.go53
-rw-r--r--pkg/sentry/fs/proc/proc.go2
-rw-r--r--pkg/sentry/fs/proc/sys_net.go4
-rw-r--r--pkg/sentry/fs/proc/task.go158
-rw-r--r--pkg/sentry/fs/tmpfs/fs.go3
-rw-r--r--pkg/sentry/fsbridge/vfs.go30
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs.go5
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD1
-rw-r--r--pkg/sentry/fsimpl/ext/ext.go12
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go26
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go2
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD13
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go7
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go87
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go175
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer_test.go64
-rw-r--r--pkg/sentry/fsimpl/gofer/p9file.go14
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go21
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go12
-rw-r--r--pkg/sentry/fsimpl/gofer/symlink.go2
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go39
-rw-r--r--pkg/sentry/fsimpl/host/BUILD34
-rw-r--r--pkg/sentry/fsimpl/host/host.go667
-rw-r--r--pkg/sentry/fsimpl/host/ioctl_unsafe.go56
-rw-r--r--pkg/sentry/fsimpl/host/tty.go379
-rw-r--r--pkg/sentry/fsimpl/host/util.go66
-rw-r--r--pkg/sentry/fsimpl/host/util_unsafe.go (renamed from pkg/sentry/kernel/pipe/buffer_test.go)26
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD1
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go9
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go23
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go86
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go92
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go38
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go10
-rw-r--r--pkg/sentry/fsimpl/kernfs/symlink.go12
-rw-r--r--pkg/sentry/fsimpl/pipefs/BUILD20
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go148
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD7
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go11
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go25
-rw-r--r--pkg/sentry/fsimpl/proc/task.go95
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go302
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go236
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go (renamed from pkg/sentry/fsimpl/proc/tasks_net.go)107
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go63
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go95
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go122
-rw-r--r--pkg/sentry/fsimpl/sockfs/BUILD17
-rw-r--r--pkg/sentry/fsimpl/sockfs/sockfs.go102
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go17
-rw-r--r--pkg/sentry/fsimpl/testutil/BUILD2
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go27
-rw-r--r--pkg/sentry/fsimpl/testutil/testutil.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/device_file.go10
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go9
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go115
-rw-r--r--pkg/sentry/fsimpl/tmpfs/named_pipe.go25
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go17
-rw-r--r--pkg/sentry/fsimpl/tmpfs/socket_file.go (renamed from pkg/tcpip/packet_buffer_state.go)27
-rw-r--r--pkg/sentry/fsimpl/tmpfs/stat_test.go12
-rw-r--r--pkg/sentry/fsimpl/tmpfs/symlink.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go234
-rw-r--r--pkg/sentry/inet/namespace.go5
-rw-r--r--pkg/sentry/kernel/BUILD3
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go2
-rw-r--r--pkg/sentry/kernel/epoll/epoll_state.go13
-rw-r--r--pkg/sentry/kernel/fd_table.go79
-rw-r--r--pkg/sentry/kernel/kernel.go136
-rw-r--r--pkg/sentry/kernel/pipe/BUILD18
-rw-r--r--pkg/sentry/kernel/pipe/buffer.go115
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go127
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go25
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go162
-rw-r--r--pkg/sentry/kernel/ptrace.go1
-rw-r--r--pkg/sentry/kernel/rseq.go2
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go1
-rw-r--r--pkg/sentry/kernel/sessions.go2
-rw-r--r--pkg/sentry/kernel/shm/shm.go2
-rw-r--r--pkg/sentry/kernel/syscalls.go43
-rw-r--r--pkg/sentry/kernel/syscalls_state.go36
-rw-r--r--pkg/sentry/kernel/task.go37
-rw-r--r--pkg/sentry/kernel/task_clone.go3
-rw-r--r--pkg/sentry/kernel/task_context.go3
-rw-r--r--pkg/sentry/kernel/task_identity.go2
-rw-r--r--pkg/sentry/kernel/task_run.go43
-rw-r--r--pkg/sentry/kernel/task_signals.go4
-rw-r--r--pkg/sentry/kernel/task_syscall.go26
-rw-r--r--pkg/sentry/kernel/thread_group.go7
-rw-r--r--pkg/sentry/kernel/time/time.go10
-rw-r--r--pkg/sentry/mm/address_space.go6
-rw-r--r--pkg/sentry/mm/aio_context.go101
-rw-r--r--pkg/sentry/mm/aio_context_state.go2
-rw-r--r--pkg/sentry/mm/lifecycle.go2
-rw-r--r--pkg/sentry/platform/kvm/bluepill.go6
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go6
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.s8
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go25
-rw-r--r--pkg/sentry/platform/kvm/kvm.go32
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64.go32
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64.go36
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go87
-rw-r--r--pkg/sentry/platform/ptrace/BUILD1
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_amd64.go14
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go62
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go12
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go80
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_arm64.go11
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go65
-rw-r--r--pkg/sentry/platform/ring0/BUILD2
-rw-r--r--pkg/sentry/platform/ring0/aarch64.go27
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s18
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go7
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64_unsafe.go108
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD2
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_x86.go2
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids.go (renamed from pkg/sentry/platform/ring0/pagetables/pcids_x86.go)2
-rw-r--r--pkg/sentry/platform/ring0/x86.go2
-rw-r--r--pkg/sentry/sighandling/sighandling.go5
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/netfilter/BUILD2
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go14
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go208
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go128
-rw-r--r--pkg/sentry/socket/netfilter/targets.go11
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go11
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go13
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go239
-rw-r--r--pkg/sentry/socket/netstack/provider.go14
-rw-r--r--pkg/sentry/socket/netstack/stack.go76
-rw-r--r--pkg/sentry/socket/socket.go89
-rw-r--r--pkg/sentry/socket/unix/BUILD4
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go48
-rw-r--r--pkg/sentry/socket/unix/unix.go89
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go348
-rw-r--r--pkg/sentry/strace/strace.go5
-rw-r--r--pkg/sentry/syscalls/linux/BUILD2
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go36
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_pipe.go14
-rw-r--r--pkg/sentry/syscalls/linux/sys_prctl.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go29
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat_amd64.go45
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat_arm64.go45
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD13
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go17
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/getdents.go24
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go50
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/pipe.go63
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go14
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go16
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go17
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go1139
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go118
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat_amd64.go46
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat_arm64.go46
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sys_timerfd.go123
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/xattr.go13
-rw-r--r--pkg/sentry/vfs/BUILD8
-rw-r--r--pkg/sentry/vfs/anonfs.go51
-rw-r--r--pkg/sentry/vfs/epoll.go2
-rw-r--r--pkg/sentry/vfs/file_description.go65
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go13
-rw-r--r--pkg/sentry/vfs/filesystem.go51
-rw-r--r--pkg/sentry/vfs/filesystem_type.go3
-rw-r--r--pkg/sentry/vfs/memxattr/BUILD15
-rw-r--r--pkg/sentry/vfs/memxattr/xattr.go102
-rw-r--r--pkg/sentry/vfs/mount.go254
-rw-r--r--pkg/sentry/vfs/mount_test.go2
-rw-r--r--pkg/sentry/vfs/options.go35
-rw-r--r--pkg/sentry/vfs/pathname.go43
-rw-r--r--pkg/sentry/vfs/permissions.go54
-rw-r--r--pkg/sentry/vfs/resolving_path.go16
-rw-r--r--pkg/sentry/vfs/timerfd.go142
-rw-r--r--pkg/sentry/vfs/vfs.go85
-rw-r--r--pkg/sentry/watchdog/watchdog.go6
-rw-r--r--pkg/state/state.go5
-rw-r--r--pkg/sync/BUILD10
-rw-r--r--pkg/sync/aliases.go5
-rw-r--r--pkg/sync/mutex_test.go (renamed from pkg/sync/tmutex_test.go)0
-rw-r--r--pkg/sync/mutex_unsafe.go (renamed from pkg/sync/tmutex_unsafe.go)0
-rw-r--r--pkg/sync/rwmutex_test.go (renamed from pkg/sync/downgradable_rwmutex_test.go)0
-rw-r--r--pkg/sync/rwmutex_unsafe.go (renamed from pkg/sync/downgradable_rwmutex_unsafe.go)0
-rw-r--r--pkg/sync/sync.go (renamed from pkg/sync/syncutil.go)0
-rw-r--r--pkg/syserror/syserror.go1
-rw-r--r--pkg/tcpip/BUILD2
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD1
-rw-r--r--pkg/tcpip/buffer/view.go73
-rw-r--r--pkg/tcpip/buffer/view_test.go137
-rw-r--r--pkg/tcpip/checker/checker.go198
-rw-r--r--pkg/tcpip/header/BUILD4
-rw-r--r--pkg/tcpip/header/eth_test.go2
-rw-r--r--pkg/tcpip/header/ipv4.go5
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go544
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers_test.go992
-rw-r--r--pkg/tcpip/header/ndp_options.go196
-rw-r--r--pkg/tcpip/header/ndp_test.go120
-rw-r--r--pkg/tcpip/header/ndpoptionidentifier_string.go50
-rw-r--r--pkg/tcpip/header/tcp.go42
-rw-r--r--pkg/tcpip/header/udp.go5
-rw-r--r--pkg/tcpip/iptables/BUILD18
-rw-r--r--pkg/tcpip/iptables/targets.go65
-rw-r--r--pkg/tcpip/link/channel/channel.go58
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go178
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go86
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_unsafe.go9
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go3
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go4
-rw-r--r--pkg/tcpip/link/loopback/loopback.go8
-rw-r--r--pkg/tcpip/link/muxed/injectable.go6
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go4
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go6
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go26
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go170
-rw-r--r--pkg/tcpip/link/tun/device.go12
-rw-r--r--pkg/tcpip/link/waitable/waitable.go8
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go22
-rw-r--r--pkg/tcpip/network/arp/arp.go12
-rw-r--r--pkg/tcpip/network/arp/arp_test.go5
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go8
-rw-r--r--pkg/tcpip/network/hash/hash.go4
-rw-r--r--pkg/tcpip/network/ip_test.go24
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go9
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go67
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go18
-rw-r--r--pkg/tcpip/network/ipv6/BUILD4
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go253
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go17
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go261
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go1051
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go431
-rw-r--r--pkg/tcpip/stack/BUILD24
-rw-r--r--pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go39
-rw-r--r--pkg/tcpip/stack/forwarder.go131
-rw-r--r--pkg/tcpip/stack/forwarder_test.go635
-rw-r--r--pkg/tcpip/stack/iptables.go (renamed from pkg/tcpip/iptables/iptables.go)83
-rw-r--r--pkg/tcpip/stack/iptables_targets.go144
-rw-r--r--pkg/tcpip/stack/iptables_types.go (renamed from pkg/tcpip/iptables/types.go)18
-rw-r--r--pkg/tcpip/stack/ndp.go449
-rw-r--r--pkg/tcpip/stack/ndp_test.go814
-rw-r--r--pkg/tcpip/stack/nic.go253
-rw-r--r--pkg/tcpip/stack/nic_test.go3
-rw-r--r--pkg/tcpip/stack/packet_buffer.go (renamed from pkg/tcpip/packet_buffer.go)29
-rw-r--r--pkg/tcpip/stack/rand.go40
-rw-r--r--pkg/tcpip/stack/registration.go35
-rw-r--r--pkg/tcpip/stack/route.go23
-rw-r--r--pkg/tcpip/stack/stack.go72
-rw-r--r--pkg/tcpip/stack/stack_test.go580
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go438
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go243
-rw-r--r--pkg/tcpip/stack/transport_test.go29
-rw-r--r--pkg/tcpip/tcpip.go164
-rw-r--r--pkg/tcpip/tcpip_test.go2
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go89
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go2
-rw-r--r--pkg/tcpip/transport/packet/BUILD1
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go32
-rw-r--r--pkg/tcpip/transport/raw/BUILD1
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go36
-rw-r--r--pkg/tcpip/transport/tcp/BUILD8
-rw-r--r--pkg/tcpip/transport/tcp/accept.go273
-rw-r--r--pkg/tcpip/transport/tcp/connect.go299
-rw-r--r--pkg/tcpip/transport/tcp/connect_unsafe.go30
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go11
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go9
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go1039
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go8
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go6
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go150
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go10
-rw-r--r--pkg/tcpip/transport/tcp/segment.go15
-rw-r--r--pkg/tcpip/transport/tcp/segment_heap.go17
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go8
-rw-r--r--pkg/tcpip/transport/tcp/snd.go26
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go83
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go41
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go293
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go30
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go35
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go278
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go3
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/udp/protocol.go6
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go81
-rw-r--r--pkg/usermem/usermem.go3
-rw-r--r--pkg/usermem/usermem_x86.go2
-rw-r--r--runsc/boot/BUILD12
-rw-r--r--runsc/boot/compat.go2
-rw-r--r--runsc/boot/config.go9
-rw-r--r--runsc/boot/controller.go13
-rw-r--r--runsc/boot/fds.go38
-rw-r--r--runsc/boot/filter/config.go11
-rw-r--r--runsc/boot/fs.go24
-rw-r--r--runsc/boot/loader.go76
-rw-r--r--runsc/boot/loader_amd64.go5
-rw-r--r--runsc/boot/loader_arm64.go5
-rw-r--r--runsc/boot/loader_test.go49
-rw-r--r--runsc/boot/user.go64
-rw-r--r--runsc/boot/vfs.go310
-rw-r--r--runsc/cmd/boot.go75
-rw-r--r--runsc/cmd/chroot.go2
-rw-r--r--runsc/cmd/debug.go64
-rw-r--r--runsc/cmd/gofer.go7
-rw-r--r--runsc/container/console_test.go4
-rw-r--r--runsc/container/container.go43
-rw-r--r--runsc/container/container_test.go33
-rw-r--r--runsc/container/test_app/test_app.go2
-rw-r--r--runsc/fsgofer/fsgofer.go4
-rw-r--r--runsc/main.go46
-rw-r--r--runsc/sandbox/sandbox.go185
-rw-r--r--runsc/specutils/namespace.go3
-rw-r--r--runsc/specutils/specutils.go11
-rw-r--r--runsc/testutil/testutil.go23
-rwxr-xr-x[-rw-r--r--]scripts/benchmark.sh30
-rwxr-xr-xscripts/build.sh2
-rwxr-xr-xscripts/common.sh22
-rwxr-xr-xscripts/dev.sh1
-rwxr-xr-xscripts/iptables_tests.sh13
-rwxr-xr-xscripts/packetimpact_tests.sh20
-rwxr-xr-xscripts/release.sh13
-rw-r--r--test/e2e/exec_test.go22
-rw-r--r--test/e2e/integration_test.go6
-rw-r--r--test/iptables/filter_input.go86
-rw-r--r--test/iptables/filter_output.go209
-rw-r--r--test/iptables/iptables_test.go134
-rw-r--r--test/iptables/iptables_util.go38
-rw-r--r--test/iptables/nat.go253
-rw-r--r--test/packetdrill/BUILD44
-rw-r--r--test/packetdrill/Dockerfile8
-rw-r--r--test/packetdrill/accept_ack_drop.pkt27
-rw-r--r--test/packetdrill/defs.bzl10
-rw-r--r--test/packetdrill/fin_wait2_timeout.pkt2
-rw-r--r--test/packetdrill/linux/tcp_user_timeout.pkt39
-rw-r--r--test/packetdrill/listen_close_before_handshake_complete.pkt31
-rw-r--r--test/packetdrill/netstack/tcp_user_timeout.pkt38
-rw-r--r--test/packetdrill/no_rst_to_rst.pkt36
-rwxr-xr-xtest/packetdrill/packetdrill_test.sh4
-rw-r--r--test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt9
-rw-r--r--test/packetdrill/sanity_test.pkt7
-rw-r--r--test/packetdrill/tcp_defer_accept.pkt48
-rw-r--r--test/packetdrill/tcp_defer_accept_timeout.pkt48
-rw-r--r--test/packetimpact/README.md531
-rw-r--r--test/packetimpact/dut/BUILD18
-rw-r--r--test/packetimpact/dut/posix_server.cc260
-rw-r--r--test/packetimpact/proto/BUILD12
-rw-r--r--test/packetimpact/proto/posix_server.proto193
-rw-r--r--test/packetimpact/testbench/BUILD41
-rw-r--r--test/packetimpact/testbench/connections.go670
-rw-r--r--test/packetimpact/testbench/dut.go473
-rw-r--r--test/packetimpact/testbench/dut_client.go28
-rw-r--r--test/packetimpact/testbench/layers.go710
-rw-r--r--test/packetimpact/testbench/layers_test.go156
-rw-r--r--test/packetimpact/testbench/rawsockets.go183
-rw-r--r--test/packetimpact/tests/BUILD94
-rw-r--r--test/packetimpact/tests/Dockerfile17
-rw-r--r--test/packetimpact/tests/defs.bzl118
-rw-r--r--test/packetimpact/tests/fin_wait2_timeout_test.go70
-rw-r--r--test/packetimpact/tests/tcp_close_wait_ack_test.go102
-rw-r--r--test/packetimpact/tests/tcp_noaccept_close_rst_test.go37
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_test.go88
-rw-r--r--test/packetimpact/tests/tcp_should_piggyback_test.go59
-rw-r--r--test/packetimpact/tests/tcp_window_shrink_test.go68
-rwxr-xr-xtest/packetimpact/tests/test_runner.sh268
-rw-r--r--test/packetimpact/tests/udp_recv_multicast_test.go37
-rw-r--r--test/perf/BUILD3
-rw-r--r--test/perf/linux/futex_benchmark.cc144
-rw-r--r--test/perf/linux/getdents_benchmark.cc2
-rw-r--r--test/perf/linux/signal_benchmark.cc10
-rw-r--r--test/root/BUILD3
-rw-r--r--test/root/cgroup_test.go4
-rw-r--r--test/root/runsc_test.go151
-rw-r--r--test/runner/gtest/gtest.go7
-rw-r--r--test/runner/runner.go27
-rw-r--r--test/runtimes/blacklist_test.go2
-rw-r--r--test/runtimes/runner.go2
-rw-r--r--test/syscalls/BUILD13
-rw-r--r--test/syscalls/linux/BUILD57
-rw-r--r--test/syscalls/linux/aio.cc12
-rw-r--r--test/syscalls/linux/bad.cc2
-rw-r--r--test/syscalls/linux/dev.cc7
-rw-r--r--test/syscalls/linux/epoll.cc4
-rw-r--r--test/syscalls/linux/exec.cc10
-rw-r--r--test/syscalls/linux/exec_binary.cc164
-rw-r--r--test/syscalls/linux/file_base.h19
-rw-r--r--test/syscalls/linux/fork.cc17
-rw-r--r--test/syscalls/linux/getrandom.cc2
-rw-r--r--test/syscalls/linux/ip_socket_test_util.cc22
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h6
-rw-r--r--test/syscalls/linux/itimer.cc2
-rw-r--r--test/syscalls/linux/lseek.cc2
-rw-r--r--test/syscalls/linux/memfd.cc1
-rw-r--r--test/syscalls/linux/mkdir.cc22
-rw-r--r--test/syscalls/linux/mlock.cc4
-rw-r--r--test/syscalls/linux/mmap.cc10
-rw-r--r--test/syscalls/linux/network_namespace.cc87
-rw-r--r--test/syscalls/linux/open.cc22
-rw-r--r--test/syscalls/linux/open_create.cc2
-rw-r--r--test/syscalls/linux/packet_socket.cc138
-rw-r--r--test/syscalls/linux/pipe.cc2
-rw-r--r--test/syscalls/linux/poll.cc2
-rw-r--r--test/syscalls/linux/pread64.cc16
-rw-r--r--test/syscalls/linux/proc.cc54
-rw-r--r--test/syscalls/linux/proc_net.cc80
-rw-r--r--test/syscalls/linux/proc_net_unix.cc6
-rw-r--r--test/syscalls/linux/proc_pid_oomscore.cc72
-rw-r--r--test/syscalls/linux/proc_pid_smaps.cc4
-rw-r--r--test/syscalls/linux/ptrace.cc33
-rw-r--r--test/syscalls/linux/pty.cc2
-rw-r--r--test/syscalls/linux/pwrite64.cc21
-rw-r--r--test/syscalls/linux/seccomp.cc42
-rw-r--r--test/syscalls/linux/sendfile.cc51
-rw-r--r--test/syscalls/linux/sendfile_socket.cc107
-rw-r--r--test/syscalls/linux/socket_generic.cc96
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc219
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc49
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h6
-rw-r--r--test/syscalls/linux/socket_netlink_route.cc77
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.cc7
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.h4
-rw-r--r--test/syscalls/linux/socket_test_util.cc5
-rw-r--r--test/syscalls/linux/socket_unix.cc2
-rw-r--r--test/syscalls/linux/splice.cc1
-rw-r--r--test/syscalls/linux/stat.cc69
-rw-r--r--test/syscalls/linux/sticky.cc16
-rw-r--r--test/syscalls/linux/sysret.cc35
-rw-r--r--test/syscalls/linux/tcp_socket.cc29
-rw-r--r--test/syscalls/linux/time.cc2
-rw-r--r--test/syscalls/linux/tuntap.cc118
-rw-r--r--test/syscalls/linux/tuntap_hostinet.cc38
-rw-r--r--test/syscalls/linux/utimes.cc18
-rw-r--r--test/syscalls/linux/vsyscall.cc2
-rw-r--r--test/syscalls/linux/write.cc10
-rw-r--r--test/syscalls/linux/xattr.cc8
-rw-r--r--test/util/BUILD6
-rw-r--r--test/util/capability_util.cc4
-rw-r--r--test/util/temp_umask.h (renamed from test/syscalls/linux/temp_umask.h)6
-rw-r--r--tools/BUILD2
-rw-r--r--tools/bazeldefs/defs.bzl88
-rw-r--r--tools/bazeldefs/platforms.bzl7
-rw-r--r--tools/bigquery/BUILD10
-rw-r--r--tools/bigquery/bigquery.go121
-rw-r--r--tools/defs.bzl58
-rw-r--r--tools/go_generics/defs.bzl1
-rw-r--r--tools/go_marshal/analysis/analysis_unsafe.go4
-rw-r--r--tools/go_marshal/defs.bzl3
-rw-r--r--tools/go_marshal/gomarshal/BUILD3
-rw-r--r--tools/go_marshal/gomarshal/generator.go143
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go748
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go141
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go282
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go614
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go54
-rw-r--r--tools/go_marshal/gomarshal/util.go176
-rw-r--r--tools/go_marshal/marshal/marshal.go103
-rw-r--r--tools/go_marshal/primitive/BUILD18
-rw-r--r--tools/go_marshal/primitive/primitive.go175
-rw-r--r--tools/go_marshal/test/BUILD14
-rw-r--r--tools/go_marshal/test/benchmark_test.go42
-rw-r--r--tools/go_marshal/test/external/external.go8
-rw-r--r--tools/go_marshal/test/marshal_test.go515
-rw-r--r--tools/go_marshal/test/test.go69
-rwxr-xr-xtools/go_mod.sh29
-rw-r--r--tools/go_stateify/main.go2
-rwxr-xr-xtools/image_build.sh98
-rw-r--r--tools/images/BUILD9
-rw-r--r--tools/images/README.md42
-rwxr-xr-xtools/images/build.sh36
-rw-r--r--tools/images/defs.bzl136
-rwxr-xr-xtools/images/ubuntu1604/10_core.sh15
-rwxr-xr-xtools/images/ubuntu1604/20_bazel.sh12
-rwxr-xr-xtools/images/ubuntu1604/25_docker.sh33
-rwxr-xr-xtools/images/ubuntu1604/30_containerd.sh12
-rwxr-xr-xtools/images/ubuntu1604/40_kokoro.sh17
-rwxr-xr-xtools/images/zone.sh17
-rwxr-xr-xtools/installers/master.sh7
-rw-r--r--tools/nogo.js7
-rw-r--r--tools/nogo.json39
-rwxr-xr-xtools/tag_release.sh12
-rw-r--r--vdso/vdso.cc4
595 files changed, 30478 insertions, 8618 deletions
diff --git a/.travis.yml b/.travis.yml
index a2a260538..acbd3d61b 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -17,3 +17,7 @@ matrix:
script:
- uname -a
- make DOCKER_RUN_OPTIONS="" BAZEL_OPTIONS="build runsc:runsc" bazel && $RUNSC_PATH --alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do ls
+branches:
+ except:
+ # Skip copybara branches.
+ - /^test\/cl.*$/
diff --git a/BUILD b/BUILD
index 5fd929378..a709a9816 100644
--- a/BUILD
+++ b/BUILD
@@ -49,10 +49,31 @@ gazelle(name = "gazelle")
# live in the tools subdirectory (unless they are standard).
nogo(
name = "nogo",
- config = "//tools:nogo.js",
+ config = "//tools:nogo.json",
visibility = ["//visibility:public"],
deps = [
"//tools/checkunsafe",
+ "@org_golang_x_tools//go/analysis/passes/asmdecl:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/assign:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/atomic:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/atomicalign:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/bools:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/buildtag:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/cgocall:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/copylock:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/deepequalerrors:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/loopclosure:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/lostcancel:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/nilfunc:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/nilness:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/printf:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/shift:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/stdmethods:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/structtag:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/tests:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unmarshal:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unsafeptr:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unusedresult:go_tool_library",
],
)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 71650a4b8..ad8e710da 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -32,6 +32,9 @@ will need to be added to the appropriate `BUILD` files, and the `:gopath` target
will need to be re-run to generate appropriate symlinks in the `GOPATH`
directory tree.
+Dependencies can be added by using `go mod get`. In order to keep the
+`WORKSPACE` file in sync, run `tools/go_mod.sh` in place of `go mod`.
+
### Coding Guidelines
All Go code should conform to the [Go style guidelines][gostyle]. C++ code
diff --git a/Dockerfile b/Dockerfile
index 2bfdfec6c..0fac71710 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,7 +2,7 @@ FROM fedora:31
RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel
-RUN dnf install -y bazel2 git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static
+RUN dnf install -y bazel2 git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch
RUN pip install pycparser
diff --git a/WORKSPACE b/WORKSPACE
index a15238a2e..c40e03ad2 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -4,10 +4,10 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
# Load go bazel rules and gazelle.
http_archive(
name = "io_bazel_rules_go",
- sha256 = "f99a9d76e972e0c8f935b2fe6d0d9d778f67c760c6d2400e23fc2e469016e2bd",
+ sha256 = "db2b2d35293f405430f553bc7a865a8749a8ef60c30287e90d2b278c32771afe",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.21.2/rules_go-v0.21.2.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/v0.21.2/rules_go-v0.21.2.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz",
],
)
@@ -20,12 +20,12 @@ http_archive(
],
)
-load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_toolchains")
+load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies")
go_rules_dependencies()
go_register_toolchains(
- go_version = "1.13.7",
+ go_version = "1.14.2",
nogo = "@//:nogo",
)
@@ -43,8 +43,8 @@ gazelle_dependencies()
go_repository(
name = "org_golang_x_sys",
importpath = "golang.org/x/sys",
- sum = "h1:72l8qCJ1nGxMGH26QVBVIxKd/D34cfGt0OvrPtpemyY=",
- version = "v0.0.0-20191220220014-0732a990476f",
+ sum = "h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=",
+ version = "v0.0.0-20200302150141-5c8b2ff67527",
)
# Load C++ rules.
@@ -68,16 +68,19 @@ http_archive(
"https://github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz",
],
)
+
load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
+
rules_proto_dependencies()
+
rules_proto_toolchains()
# Load python dependencies.
git_repository(
name = "rules_python",
- commit = "94677401bc56ed5d756f50b441a6a5c7f735a6d4",
+ commit = "abc4869e02fe9b3866942e89f07b7341f830e805",
remote = "https://github.com/bazelbuild/rules_python.git",
- shallow_since = "1573842889 -0500",
+ shallow_since = "1583341286 -0500",
)
load("@rules_python//python:pip.bzl", "pip_import")
@@ -96,11 +99,11 @@ pip_install()
# See releases at https://releases.bazel.build/bazel-toolchains.html
http_archive(
name = "bazel_toolchains",
- sha256 = "a653c9d318e42b14c0ccd7ac50c4a2a276c0db1e39743ab88b5aa2f0bc9cf607",
- strip_prefix = "bazel-toolchains-2.0.2",
+ sha256 = "239a1a673861eabf988e9804f45da3b94da28d1aff05c373b013193c315d9d9e",
+ strip_prefix = "bazel-toolchains-3.0.1",
urls = [
- "https://github.com/bazelbuild/bazel-toolchains/releases/download/2.0.2/bazel-toolchains-2.0.2.tar.gz",
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/2.0.2.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz",
],
)
@@ -146,9 +149,9 @@ load(
# This container is built from the Dockerfile in test/iptables/runner.
container_pull(
name = "iptables-test",
+ digest = "sha256:a137d692a2eb9fc7bf95c5f4a568da090e2c31098e93634421ed88f3a3f1db65",
registry = "gcr.io",
repository = "gvisor-presubmit/iptables-test",
- digest = "sha256:a137d692a2eb9fc7bf95c5f4a568da090e2c31098e93634421ed88f3a3f1db65",
)
load(
@@ -158,6 +161,20 @@ load(
_go_image_repos()
+# Load C++ grpc rules.
+http_archive(
+ name = "com_github_grpc_grpc",
+ sha256 = "2fcb7f1ab160d6fd3aaade64520be3e5446fc4c6fa7ba6581afdc4e26094bd81",
+ strip_prefix = "grpc-1.26.0",
+ urls = [
+ "https://github.com/grpc/grpc/archive/v1.26.0.tar.gz",
+ ],
+)
+load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
+grpc_deps()
+load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps")
+grpc_extra_deps()
+
# External repositories, in sorted order.
go_repository(
name = "com_github_cenkalti_backoff",
@@ -202,6 +219,20 @@ go_repository(
)
go_repository(
+ name = "com_github_imdario_mergo",
+ importpath = "github.com/imdario/mergo",
+ version = "v0.3.8",
+ sum = "h1:CGgOkSJeqMRmt0D9XLWExdT4m4F1vd3FV3VPt+0VxkQ=",
+)
+
+go_repository(
+ name = "com_github_kr_pretty",
+ importpath = "github.com/kr/pretty",
+ sum = "h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=",
+ version = "v0.2.0",
+)
+
+go_repository(
name = "com_github_kr_pty",
importpath = "github.com/kr/pty",
sum = "h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=",
@@ -209,6 +240,19 @@ go_repository(
)
go_repository(
+ name = "com_github_kr_text",
+ importpath = "github.com/kr/text",
+ sum = "h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=",
+ version = "v0.1.0",
+)
+
+go_repository(
+ name = "com_github_mohae_deepcopy",
+ importpath = "github.com/mohae/deepcopy",
+ commit = "c48cc78d482608239f6c4c92a4abd87eb8761c90",
+)
+
+go_repository(
name = "com_github_opencontainers_runtime-spec",
importpath = "github.com/opencontainers/runtime-spec",
sum = "h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI=",
@@ -237,30 +281,39 @@ go_repository(
)
go_repository(
- name = "org_golang_x_crypto",
- importpath = "golang.org/x/crypto",
- sum = "h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=",
- version = "v0.0.0-20190308221718-c2843e01d9a2",
+ name = "org_golang_google_grpc",
+ build_file_proto_mode = "disable",
+ importpath = "google.golang.org/grpc",
+ sum = "h1:zvIju4sqAGvwKspUQOhwnpcqSbzi7/H6QomNNjTL4sk=",
+ version = "v1.27.1",
)
go_repository(
- name = "org_golang_x_net",
- importpath = "golang.org/x/net",
- sum = "h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628=",
- version = "v0.0.0-20190311183353-d8887717615a",
+ name = "in_gopkg_check_v1",
+ importpath = "gopkg.in/check.v1",
+ sum = "h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=",
+ version = "v1.0.0-20190902080502-41f04d3bba15",
)
go_repository(
- name = "org_golang_x_text",
- importpath = "golang.org/x/text",
- sum = "h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=",
- version = "v0.3.0",
+ name = "org_golang_x_crypto",
+ importpath = "golang.org/x/crypto",
+ sum = "h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8=",
+ version = "v0.0.0-20191011191535-87dc89f01550",
)
go_repository(
- name = "org_golang_x_tools",
- commit = "36563e24a262",
- importpath = "golang.org/x/tools",
+ name = "org_golang_x_mod",
+ importpath = "golang.org/x/mod",
+ sum = "h1:p1YOIz9H/mGN8k1XkaV5VFAq9+zhN9Obefv439UwRhI=",
+ version = "v0.2.1-0.20200224194123-e5e73c1b9c72",
+)
+
+go_repository(
+ name = "org_golang_x_net",
+ importpath = "golang.org/x/net",
+ sum = "h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI=",
+ version = "v0.0.0-20190620200207-3b0461eec859",
)
go_repository(
@@ -271,15 +324,31 @@ go_repository(
)
go_repository(
+ name = "org_golang_x_text",
+ importpath = "golang.org/x/text",
+ sum = "h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=",
+ version = "v0.3.0",
+)
+
+go_repository(
name = "org_golang_x_time",
- commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e",
importpath = "golang.org/x/time",
+ sum = "h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=",
+ version = "v0.0.0-20191024005414-555d28b269f0",
)
go_repository(
name = "org_golang_x_tools",
- commit = "aa82965741a9fecd12b026fbb3d3c6ed3231b8f8",
importpath = "golang.org/x/tools",
+ sum = "h1:aZzprAO9/8oim3qStq3wc1Xuxx4QmAGriC4VU4ojemQ=",
+ version = "v0.0.0-20191119224855-298f0cb1881e",
+)
+
+go_repository(
+ name = "org_golang_x_xerrors",
+ importpath = "golang.org/x/xerrors",
+ sum = "h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=",
+ version = "v0.0.0-20191204190536-9bdfabe68543",
)
go_repository(
@@ -317,6 +386,63 @@ go_repository(
version = "v1.0.0",
)
+go_repository(
+ name = "com_google_cloud_go_bigquery",
+ importpath = "cloud.google.com/go/bigquery",
+ sum = "h1:K2NyuHRuv15ku6eUpe0DQk5ZykPMnSOnvuVf6IHcjaE=",
+ version = "v1.5.0",
+)
+
+go_repository(
+ name = "org_golang_google_api",
+ importpath = "google.golang.org/api",
+ sum = "h1:jz2KixHX7EcCPiQrySzPdnYT7DbINAypCqKZ1Z7GM40=",
+ version = "v0.20.0",
+)
+
+go_repository(
+ name = "org_uber_go_atomic",
+ importpath = "go.uber.org/atomic",
+ version = "v1.6.0",
+ sum = "h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=",
+)
+
+go_repository(
+ name = "org_uber_go_multierr",
+ importpath = "go.uber.org/multierr",
+ version = "v1.5.0",
+ sum = "h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A=",
+)
+
+# BigQuery Dependencies for Benchmarks
+go_repository(
+ name = "com_google_cloud_go",
+ importpath = "cloud.google.com/go",
+ sum = "h1:eoz/lYxKSL4CNAiaUJ0ZfD1J3bfMYbU5B3rwM1C1EIU=",
+ version = "v0.55.0",
+)
+
+go_repository(
+ name = "com_github_googleapis_gax_go_v2",
+ importpath = "github.com/googleapis/gax-go/v2",
+ sum = "h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM=",
+ version = "v2.0.5",
+)
+
+go_repository(
+ name = "io_opencensus_go",
+ importpath = "go.opencensus.io",
+ sum = "h1:8sGtKOrtQqkN1bp2AtX+misvLIlOmsEsNd+9NIcPEm8=",
+ version = "v0.22.3",
+)
+
+go_repository(
+ name = "com_github_golang_groupcache",
+ importpath = "github.com/golang/groupcache",
+ sum = "h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=",
+ version = "v0.0.0-20200121045136-8c9f03a8e57e",
+)
+
# System Call test dependencies.
http_archive(
name = "com_google_absl",
diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py
index 513d16e4f..44d72f575 100644
--- a/benchmarks/harness/machine_producers/gcloud_producer.py
+++ b/benchmarks/harness/machine_producers/gcloud_producer.py
@@ -53,6 +53,8 @@ class GCloudProducer(machine_producer.MachineProducer):
ssh_key_file: path to a valid ssh private key. See README on vaild ssh keys.
ssh_user: string of user name for ssh_key
ssh_password: string of password for ssh key
+ internal: if true, use internal IPs of instances. Used if bm-tools is
+ running on a GCP vm when a firewall is set for external IPs.
mock: a mock printer which will print mock data if required. Mock data is
recorded output from subprocess calls (returncode, stdout, args).
condition: mutex for this class around machine creation and deleteion.
@@ -66,6 +68,7 @@ class GCloudProducer(machine_producer.MachineProducer):
ssh_key_file: str,
ssh_user: str,
ssh_password: str,
+ internal: bool,
mock: gcloud_mock_recorder.MockPrinter = None):
self.image = image
self.zone = zone
@@ -74,6 +77,7 @@ class GCloudProducer(machine_producer.MachineProducer):
self.ssh_key_file = ssh_key_file
self.ssh_user = ssh_user
self.ssh_password = ssh_password
+ self.internal = internal
self.mock = mock
self.condition = threading.Condition()
@@ -129,15 +133,13 @@ class GCloudProducer(machine_producer.MachineProducer):
machines = []
for instance in instances:
name = instance["name"]
+ external = instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"]
+ internal = instance["networkInterfaces"][0]["networkIP"]
kwargs = {
- "hostname":
- instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"],
- "key_path":
- self.ssh_key_file,
- "username":
- self.ssh_user,
- "key_password":
- self.ssh_password
+ "hostname": internal if self.internal else external,
+ "key_path": self.ssh_key_file,
+ "username": self.ssh_user,
+ "key_password": self.ssh_password
}
machines.append(machine.RemoteMachine(name=name, **kwargs))
return machines
@@ -168,7 +170,9 @@ class GCloudProducer(machine_producer.MachineProducer):
cmd.append("--zone=" + self.zone)
cmd.append("--machine-type=" + self.machine_type)
res = self._run_command(cmd)
- return json.loads(res.stdout)
+ data = res.stdout
+ data = str(data, "utf-8") if isinstance(data, (bytes, bytearray)) else data
+ return json.loads(data)
def _add_ssh_key_to_instances(self, names: List[str]) -> None:
"""Adds ssh key to instances by calling gcloud ssh command.
@@ -186,11 +190,13 @@ class GCloudProducer(machine_producer.MachineProducer):
TimeoutError: when 3 unsuccessful tries to ssh into the host return 255.
"""
for name in names:
- cmd = "gcloud compute ssh {name}".format(name=name).split(" ")
+ cmd = "gcloud compute ssh {user}@{name}".format(
+ user=self.ssh_user, name=name).split(" ")
+ if self.internal:
+ cmd.append("--internal-ip")
cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_file))
cmd.append("--zone={zone}".format(zone=self.zone))
cmd.append("--command=uname")
- cmd.append("--ssh-key-expire-after=60m")
timeout = datetime.timedelta(seconds=5 * 60)
start = datetime.datetime.now()
while datetime.datetime.now() <= timeout + start:
diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py
index ba27dc69f..fc59cf505 100644
--- a/benchmarks/runner/__init__.py
+++ b/benchmarks/runner/__init__.py
@@ -19,6 +19,7 @@ import logging
import pkgutil
import pydoc
import re
+import subprocess
import sys
import types
from typing import List
@@ -120,14 +121,13 @@ def run_mock(ctx, **kwargs):
@runner.command("run-gcp", commands.GCPCommand)
@click.pass_context
-def run_gcp(ctx, image_file: str, zone_file: str, machine_type: str,
- installers: List[str], **kwargs):
+def run_gcp(ctx, image_file: str, zone_file: str, internal: bool,
+ machine_type: str, installers: List[str], **kwargs):
"""Runs all benchmarks on GCP instances."""
# Resolve all files.
- image = open(image_file).read().rstrip()
- zone = open(zone_file).read().rstrip()
-
+ image = subprocess.check_output([image_file]).rstrip()
+ zone = subprocess.check_output([zone_file]).rstrip()
key_file = harness.make_key()
producer = gcloud_producer.GCloudProducer(
@@ -137,7 +137,8 @@ def run_gcp(ctx, image_file: str, zone_file: str, machine_type: str,
installers,
ssh_key_file=key_file,
ssh_user=harness.DEFAULT_USER,
- ssh_password="")
+ ssh_password="",
+ internal=internal)
try:
run(ctx, producer, **kwargs)
diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py
index 0fccb2fad..e8289f6c5 100644
--- a/benchmarks/runner/commands.py
+++ b/benchmarks/runner/commands.py
@@ -101,15 +101,21 @@ class GCPCommand(RunCommand):
image_file = click.core.Option(
("--image_file",),
- help="The file containing the image for VMs.",
+ help="The binary that emits the GCP image.",
default=os.path.join(
- os.path.dirname(__file__), "../../tools/images/ubuntu1604.txt"),
+ os.path.dirname(__file__), "../../tools/images/ubuntu1604"),
)
zone_file = click.core.Option(
("--zone_file",),
- help="The file containing the GCP zone.",
+ help="The binary that emits the GCP zone.",
default=os.path.join(
- os.path.dirname(__file__), "../../tools/images/zone.txt"),
+ os.path.dirname(__file__), "../../tools/images/zone"),
+ )
+ internal = click.core.Option(
+ ("--internal/--no-internal",),
+ help="""Use instance internal IPs. Used if bm-tools runner is running on
+ GCP instance with firewall rules blocking external IPs.""",
+ default=False,
)
installers = click.core.Option(
("--installers",),
@@ -124,6 +130,7 @@ class GCPCommand(RunCommand):
self.params.extend([
image_file,
zone_file,
+ internal,
machine_type,
installers,
])
diff --git a/benchmarks/workloads/ruby/Gemfile.lock b/benchmarks/workloads/ruby/Gemfile.lock
index b44817bd3..ea9f0ea85 100644
--- a/benchmarks/workloads/ruby/Gemfile.lock
+++ b/benchmarks/workloads/ruby/Gemfile.lock
@@ -1,28 +1,41 @@
GEM
remote: https://rubygems.org/
specs:
+ activemerchant (1.105.0)
+ activesupport (>= 4.2)
+ builder (>= 2.1.2, < 4.0.0)
+ i18n (>= 0.6.9)
+ nokogiri (~> 1.4)
activesupport (5.2.3)
concurrent-ruby (~> 1.0, >= 1.0.2)
i18n (>= 0.7, < 2)
minitest (~> 5.1)
tzinfo (~> 1.1)
+ bcrypt (3.1.13)
+ builder (3.2.4)
cassandra-driver (3.2.3)
ione (~> 1.2)
concurrent-ruby (1.1.5)
+ ffi (1.12.2)
i18n (1.6.0)
concurrent-ruby (~> 1.0)
ione (1.2.4)
+ mini_portile2 (2.4.0)
minitest (5.11.3)
mustermann (1.0.3)
+ nokogiri (1.10.8)
+ mini_portile2 (~> 2.4.0)
pdf-core (0.7.0)
prawn (2.2.2)
pdf-core (~> 0.7.0)
ttfunk (~> 1.5)
- puma (3.12.1)
- rack (2.0.7)
+ puma (3.12.4)
+ rack (2.2.2)
rack-protection (2.0.5)
rack
- rake (12.3.2)
+ rake (12.3.3)
+ rbnacl (7.1.1)
+ ffi
redis (4.1.1)
ruby-fann (1.2.6)
sinatra (2.0.5)
@@ -43,9 +56,12 @@ PLATFORMS
ruby
DEPENDENCIES
+ activemerchant
+ bcrypt
cassandra-driver
puma
rake
+ rbnacl
redis
ruby-fann
sinatra
diff --git a/benchmarks/workloads/ruby_template/Gemfile.lock b/benchmarks/workloads/ruby_template/Gemfile.lock
index dd8d56fb7..f637b6081 100644
--- a/benchmarks/workloads/ruby_template/Gemfile.lock
+++ b/benchmarks/workloads/ruby_template/Gemfile.lock
@@ -2,25 +2,25 @@ GEM
remote: https://rubygems.org/
specs:
mustermann (1.0.3)
- puma (3.12.0)
+ puma (3.12.4)
rack (2.0.6)
rack-protection (2.0.5)
rack
+ redis (4.1.0)
sinatra (2.0.5)
mustermann (~> 1.0)
rack (~> 2.0)
rack-protection (= 2.0.5)
tilt (~> 2.0)
tilt (2.0.9)
- redis (4.1.0)
PLATFORMS
ruby
DEPENDENCIES
puma
- sinatra
redis
+ sinatra
BUNDLED WITH
1.17.1 \ No newline at end of file
diff --git a/go.mod b/go.mod
index c4687ed02..434fa713f 100644
--- a/go.mod
+++ b/go.mod
@@ -1,23 +1,20 @@
module gvisor.dev/gvisor
-go 1.13
+go 1.14
require (
- github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422
- github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079
- github.com/golang/mock v1.3.1
- github.com/golang/protobuf v1.3.1
- github.com/google/btree v1.0.0
- github.com/google/go-cmp v0.2.0
- github.com/google/go-github/v28 v28.1.1
- github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8
- github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d
- github.com/kr/pty v1.1.1
- github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78
- github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2
- github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e
- github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936
- golang.org/x/net v0.0.0-20190311183353-d8887717615a
- golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6
- golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a
+ github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422
+ github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079
+ github.com/golang/protobuf v1.3.1
+ github.com/google/btree v1.0.0
+ github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8
+ github.com/kr/pretty v0.2.0 // indirect
+ github.com/kr/pty v1.1.1
+ github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78
+ github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2
+ github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e
+ github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 // indirect
+ golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527
+ golang.org/x/time v0.0.0-20191024005414-555d28b269f0
+ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
)
diff --git a/go.sum b/go.sum
index 434770beb..c44a17c71 100644
--- a/go.sum
+++ b/go.sum
@@ -1,21 +1,32 @@
+github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8=
github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM=
+github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs=
github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
-github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
+github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
-github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
-github.com/google/go-github/v28 v28.1.1/go.mod h1:bsqJWQX05omyWVmc00nEUql9mhQyv38lDZ8kPZcQVoM=
+github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI=
github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
-github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=
+github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI=
github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
+github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8=
github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
+github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q=
github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
+github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk=
github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
-golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=
+golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
+gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/kokoro/benchmark_tests.cfg b/kokoro/benchmark_tests.cfg
index c48518a05..f85cc9681 100644
--- a/kokoro/benchmark_tests.cfg
+++ b/kokoro/benchmark_tests.cfg
@@ -5,14 +5,14 @@ before_action {
fetch_keystore {
keystore_resource {
keystore_config_id : 73898
- keyname : 'kokoro-rbe-service-account'
+ keyname : 'gvisor-benchmarks-service-account'
},
}
}
env_vars {
key : 'PROJECT'
- value : 'gvisor-kokoro-testing'
+ value : 'gvisor-benchmarks'
}
env_vars {
@@ -21,6 +21,6 @@ env_vars {
}
env_vars {
- key : 'KOKORO_SERVICE_ACCOUNT'
- value : '73898_kokoro-rbe-service-account'
+ key : 'GCLOUD_CREDENTIALS'
+ value : '73898_gvisor-benchmarks-service-account'
}
diff --git a/kokoro/iptables_tests.cfg b/kokoro/iptables_tests.cfg
index 7af20629a..a30d82591 100644
--- a/kokoro/iptables_tests.cfg
+++ b/kokoro/iptables_tests.cfg
@@ -1,4 +1,4 @@
-build_file: "repo/scripts/iptables_test.sh"
+build_file: "repo/scripts/iptables_tests.sh"
action {
define_artifacts {
diff --git a/kokoro/packetimpact_tests.cfg b/kokoro/packetimpact_tests.cfg
new file mode 100644
index 000000000..db86b52d5
--- /dev/null
+++ b/kokoro/packetimpact_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/packetimpact_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go
index 40f0459a0..7c9a02f20 100644
--- a/pkg/abi/linux/elf.go
+++ b/pkg/abi/linux/elf.go
@@ -102,4 +102,7 @@ const (
// NT_X86_XSTATE is for x86 extended state using xsave.
NT_X86_XSTATE = 0x202
+
+ // NT_ARM_TLS is for ARM TLS register.
+ NT_ARM_TLS = 0x401
)
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index e229ac21c..055ac1d7c 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -266,6 +266,9 @@ type Statx struct {
DevMinor uint32
}
+// SizeOfStatx is the size of a Statx struct.
+var SizeOfStatx = binary.Size(Statx{})
+
// FileMode represents a mode_t.
type FileMode uint16
@@ -284,6 +287,11 @@ func (m FileMode) ExtraBits() FileMode {
return m &^ (PermissionsMask | FileTypeMask)
}
+// IsDir returns true if file type represents a directory.
+func (m FileMode) IsDir() bool {
+ return m.FileType() == S_IFDIR
+}
+
// String returns a string representation of m.
func (m FileMode) String() string {
var s []string
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index bd2e13ba1..a8d4f9d69 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -158,10 +158,32 @@ type IPTIP struct {
// Flags define matching behavior for the IP header.
Flags uint8
- // InverseFlags invert the meaning of fields in struct IPTIP.
+ // InverseFlags invert the meaning of fields in struct IPTIP. See the
+ // IPT_INV_* flags.
InverseFlags uint8
}
+// Flags in IPTIP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+const (
+ // Invert the meaning of InputInterface.
+ IPT_INV_VIA_IN = 0x01
+ // Invert the meaning of OutputInterface.
+ IPT_INV_VIA_OUT = 0x02
+ // Unclear what this is, as no references to it exist in the kernel.
+ IPT_INV_TOS = 0x04
+ // Invert the meaning of Src.
+ IPT_INV_SRCIP = 0x08
+ // Invert the meaning of Dst.
+ IPT_INV_DSTIP = 0x10
+ // Invert the meaning of the IPT_F_FRAG flag.
+ IPT_INV_FRAG = 0x20
+ // Invert the meaning of the Protocol field.
+ IPT_INV_PROTO = 0x40
+ // Enable all flags.
+ IPT_INV_MASK = 0x7F
+)
+
// SizeOfIPTIP is the size of an IPTIP.
const SizeOfIPTIP = 84
@@ -253,6 +275,50 @@ type XTErrorTarget struct {
// SizeOfXTErrorTarget is the size of an XTErrorTarget.
const SizeOfXTErrorTarget = 64
+// Flag values for NfNATIPV4Range. The values indicate whether to map
+// protocol specific part(ports) or IPs. It corresponds to values in
+// include/uapi/linux/netfilter/nf_nat.h.
+const (
+ NF_NAT_RANGE_MAP_IPS = 1 << 0
+ NF_NAT_RANGE_PROTO_SPECIFIED = 1 << 1
+ NF_NAT_RANGE_PROTO_RANDOM = 1 << 2
+ NF_NAT_RANGE_PERSISTENT = 1 << 3
+ NF_NAT_RANGE_PROTO_RANDOM_FULLY = 1 << 4
+ NF_NAT_RANGE_PROTO_RANDOM_ALL = (NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
+ NF_NAT_RANGE_MASK = (NF_NAT_RANGE_MAP_IPS |
+ NF_NAT_RANGE_PROTO_SPECIFIED | NF_NAT_RANGE_PROTO_RANDOM |
+ NF_NAT_RANGE_PERSISTENT | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
+)
+
+// NfNATIPV4Range corresponds to struct nf_nat_ipv4_range
+// in include/uapi/linux/netfilter/nf_nat.h. The fields are in
+// network byte order.
+type NfNATIPV4Range struct {
+ Flags uint32
+ MinIP [4]byte
+ MaxIP [4]byte
+ MinPort uint16
+ MaxPort uint16
+}
+
+// NfNATIPV4MultiRangeCompat corresponds to struct
+// nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h.
+type NfNATIPV4MultiRangeCompat struct {
+ RangeSize uint32
+ RangeIPV4 NfNATIPV4Range
+}
+
+// XTRedirectTarget triggers a redirect when reached.
+// Adding 4 bytes of padding to make the struct 8 byte aligned.
+type XTRedirectTarget struct {
+ Target XTEntryTarget
+ NfRange NfNATIPV4MultiRangeCompat
+ _ [4]byte
+}
+
+// SizeOfXTRedirectTarget is the size of an XTRedirectTarget.
+const SizeOfXTRedirectTarget = 56
+
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
type IPTGetinfo struct {
@@ -443,3 +509,44 @@ const (
// Enable all flags.
XT_UDP_INV_MASK = 0x03
)
+
+// IPTOwnerInfo holds data for matching packets with owner. It corresponds
+// to struct ipt_owner_info in libxt_owner.c of iptables binary.
+type IPTOwnerInfo struct {
+ // UID is user id which created the packet.
+ UID uint32
+
+ // GID is group id which created the packet.
+ GID uint32
+
+ // PID is process id of the process which created the packet.
+ PID uint32
+
+ // SID is session id which created the packet.
+ SID uint32
+
+ // Comm is the command name which created the packet.
+ Comm [16]byte
+
+ // Match is used to match UID/GID of the socket. See the
+ // XT_OWNER_* flags below.
+ Match uint8
+
+ // Invert flips the meaning of Match field.
+ Invert uint8
+}
+
+// SizeOfIPTOwnerInfo is the size of an XTOwnerMatchInfo.
+const SizeOfIPTOwnerInfo = 34
+
+// Flags in IPTOwnerInfo.Match. Corresponding constants are in
+// include/uapi/linux/netfilter/xt_owner.h.
+const (
+ // Match the UID of the packet.
+ XT_OWNER_UID = 1 << 0
+ // Match the GID of the packet.
+ XT_OWNER_GID = 1 << 1
+ // Match if the socket exists for the packet. Forwarded
+ // packets do not have an associated socket.
+ XT_OWNER_SOCKET = 1 << 2
+)
diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go
index 21e237f92..565dd550e 100644
--- a/pkg/abi/linux/netfilter_test.go
+++ b/pkg/abi/linux/netfilter_test.go
@@ -29,6 +29,7 @@ func TestSizes(t *testing.T) {
{IPTGetEntries{}, SizeOfIPTGetEntries},
{IPTGetinfo{}, SizeOfIPTGetinfo},
{IPTIP{}, SizeOfIPTIP},
+ {IPTOwnerInfo{}, SizeOfIPTOwnerInfo},
{IPTReplace{}, SizeOfIPTReplace},
{XTCounters{}, SizeOfXTCounters},
{XTEntryMatch{}, SizeOfXTEntryMatch},
diff --git a/pkg/bits/bits_template.go b/pkg/bits/bits_template.go
index 93a435b80..998645388 100644
--- a/pkg/bits/bits_template.go
+++ b/pkg/bits/bits_template.go
@@ -42,3 +42,11 @@ func Mask(is ...int) T {
func MaskOf(i int) T {
return T(1) << T(i)
}
+
+// IsPowerOfTwo returns true if v is power of 2.
+func IsPowerOfTwo(v T) bool {
+ if v == 0 {
+ return false
+ }
+ return v&(v-1) == 0
+}
diff --git a/pkg/bits/uint64_test.go b/pkg/bits/uint64_test.go
index 1b018d808..193d1ebcd 100644
--- a/pkg/bits/uint64_test.go
+++ b/pkg/bits/uint64_test.go
@@ -114,3 +114,21 @@ func TestIsOn(t *testing.T) {
}
}
}
+
+func TestIsPowerOfTwo(t *testing.T) {
+ for _, tc := range []struct {
+ v uint64
+ want bool
+ }{
+ {v: 0, want: false},
+ {v: 1, want: true},
+ {v: 2, want: true},
+ {v: 3, want: false},
+ {v: 4, want: true},
+ {v: 5, want: false},
+ } {
+ if got := IsPowerOfTwo64(tc.v); got != tc.want {
+ t.Errorf("IsPowerOfTwo(%d) = %t, want: %t", tc.v, got, tc.want)
+ }
+ }
+}
diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD
new file mode 100644
index 000000000..dcd086298
--- /dev/null
+++ b/pkg/buffer/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "buffer_list",
+ out = "buffer_list.go",
+ package = "buffer",
+ prefix = "buffer",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*buffer",
+ "Linker": "*buffer",
+ },
+)
+
+go_library(
+ name = "buffer",
+ srcs = [
+ "buffer.go",
+ "buffer_list.go",
+ "safemem.go",
+ "view.go",
+ "view_unsafe.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/safemem",
+ ],
+)
+
+go_test(
+ name = "buffer_test",
+ size = "small",
+ srcs = [
+ "safemem_test.go",
+ "view_test.go",
+ ],
+ library = ":buffer",
+ deps = ["//pkg/safemem"],
+)
diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go
new file mode 100644
index 000000000..c6d089fd9
--- /dev/null
+++ b/pkg/buffer/buffer.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package buffer provides the implementation of a buffer view.
+//
+// A view is an flexible buffer, backed by a pool, supporting the safecopy
+// operations natively as well as the ability to grow via either prepend or
+// append, as well as shrink.
+package buffer
+
+import (
+ "sync"
+)
+
+const bufferSize = 8144 // See below.
+
+// buffer encapsulates a queueable byte buffer.
+//
+// Note that the total size is slightly less than two pages. This is done
+// intentionally to ensure that the buffer object aligns with runtime
+// internals. We have no hard size or alignment requirements. This two page
+// size will effectively minimize internal fragmentation, but still have a
+// large enough chunk to limit excessive segmentation.
+//
+// +stateify savable
+type buffer struct {
+ data [bufferSize]byte
+ read int
+ write int
+ bufferEntry
+}
+
+// reset resets internal data.
+//
+// This must be called before returning the buffer to the pool.
+func (b *buffer) Reset() {
+ b.read = 0
+ b.write = 0
+}
+
+// Full indicates the buffer is full.
+//
+// This indicates there is no capacity left to write.
+func (b *buffer) Full() bool {
+ return b.write == len(b.data)
+}
+
+// ReadSize returns the number of bytes available for reading.
+func (b *buffer) ReadSize() int {
+ return b.write - b.read
+}
+
+// ReadMove advances the read index by the given amount.
+func (b *buffer) ReadMove(n int) {
+ b.read += n
+}
+
+// ReadSlice returns the read slice for this buffer.
+func (b *buffer) ReadSlice() []byte {
+ return b.data[b.read:b.write]
+}
+
+// WriteSize returns the number of bytes available for writing.
+func (b *buffer) WriteSize() int {
+ return len(b.data) - b.write
+}
+
+// WriteMove advances the write index by the given amount.
+func (b *buffer) WriteMove(n int) {
+ b.write += n
+}
+
+// WriteSlice returns the write slice for this buffer.
+func (b *buffer) WriteSlice() []byte {
+ return b.data[b.write:]
+}
+
+// bufferPool is a pool for buffers.
+var bufferPool = sync.Pool{
+ New: func() interface{} {
+ return new(buffer)
+ },
+}
diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go
new file mode 100644
index 000000000..0e5b86344
--- /dev/null
+++ b/pkg/buffer/safemem.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+// WriteBlock returns this buffer as a write Block.
+func (b *buffer) WriteBlock() safemem.Block {
+ return safemem.BlockFromSafeSlice(b.WriteSlice())
+}
+
+// ReadBlock returns this buffer as a read Block.
+func (b *buffer) ReadBlock() safemem.Block {
+ return safemem.BlockFromSafeSlice(b.ReadSlice())
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// This will advance the write index.
+func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ need := int(srcs.NumBytes())
+ if need == 0 {
+ return 0, nil
+ }
+
+ var (
+ dst safemem.BlockSeq
+ blocks []safemem.Block
+ )
+
+ // Need at least one buffer.
+ firstBuf := v.data.Back()
+ if firstBuf == nil {
+ firstBuf = bufferPool.Get().(*buffer)
+ v.data.PushBack(firstBuf)
+ }
+
+ // Does the last block have sufficient capacity alone?
+ if l := firstBuf.WriteSize(); l >= need {
+ dst = safemem.BlockSeqOf(firstBuf.WriteBlock())
+ } else {
+ // Append blocks until sufficient.
+ need -= l
+ blocks = append(blocks, firstBuf.WriteBlock())
+ for need > 0 {
+ emptyBuf := bufferPool.Get().(*buffer)
+ v.data.PushBack(emptyBuf)
+ need -= emptyBuf.WriteSize()
+ blocks = append(blocks, emptyBuf.WriteBlock())
+ }
+ dst = safemem.BlockSeqFromSlice(blocks)
+ }
+
+ // Perform the copy.
+ n, err := safemem.CopySeq(dst, srcs)
+ v.size += int64(n)
+
+ // Update all indices.
+ for left := int(n); left > 0; firstBuf = firstBuf.Next() {
+ if l := firstBuf.WriteSize(); left >= l {
+ firstBuf.WriteMove(l) // Whole block.
+ left -= l
+ } else {
+ firstBuf.WriteMove(left) // Partial block.
+ left = 0
+ }
+ }
+
+ return n, err
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+//
+// This will not advance the read index; the caller should follow
+// this call with a call to TrimFront in order to remove the read
+// data from the buffer. This is done to support pipe sematics.
+func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ need := int(dsts.NumBytes())
+ if need == 0 {
+ return 0, nil
+ }
+
+ var (
+ src safemem.BlockSeq
+ blocks []safemem.Block
+ )
+
+ firstBuf := v.data.Front()
+ if firstBuf == nil {
+ return 0, nil // No EOF.
+ }
+
+ // Is all the data in a single block?
+ if l := firstBuf.ReadSize(); l >= need {
+ src = safemem.BlockSeqOf(firstBuf.ReadBlock())
+ } else {
+ // Build a list of all the buffers.
+ need -= l
+ blocks = append(blocks, firstBuf.ReadBlock())
+ for buf := firstBuf.Next(); buf != nil && need > 0; buf = buf.Next() {
+ need -= buf.ReadSize()
+ blocks = append(blocks, buf.ReadBlock())
+ }
+ src = safemem.BlockSeqFromSlice(blocks)
+ }
+
+ // Perform the copy.
+ n, err := safemem.CopySeq(dsts, src)
+
+ // See above: we would normally advance the read index here, but we
+ // don't do that in order to support pipe semantics. We rely on a
+ // separate call to TrimFront() in this case.
+
+ return n, err
+}
diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go
new file mode 100644
index 000000000..47f357e0c
--- /dev/null
+++ b/pkg/buffer/safemem_test.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "bytes"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+func TestSafemem(t *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ output string
+ readLen int
+ op func(*View)
+ }{
+ // Basic coverage.
+ {
+ name: "short",
+ input: "010",
+ output: "010",
+ },
+ {
+ name: "long",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize) + "0",
+ },
+ {
+ name: "short-read",
+ input: "0",
+ readLen: 100, // > size.
+ output: "0",
+ },
+ {
+ name: "zero-read",
+ input: "0",
+ output: "",
+ },
+ {
+ name: "read-empty",
+ input: "",
+ readLen: 1, // > size.
+ output: "",
+ },
+
+ // Ensure offsets work.
+ {
+ name: "offsets-short",
+ input: "012",
+ output: "2",
+ op: func(v *View) {
+ v.TrimFront(2)
+ },
+ },
+ {
+ name: "offsets-long0",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: strings.Repeat("1", bufferSize) + "0",
+ op: func(v *View) {
+ v.TrimFront(1)
+ },
+ },
+ {
+ name: "offsets-long1",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: strings.Repeat("1", bufferSize-1) + "0",
+ op: func(v *View) {
+ v.TrimFront(2)
+ },
+ },
+ {
+ name: "offsets-long2",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "10",
+ op: func(v *View) {
+ v.TrimFront(bufferSize)
+ },
+ },
+
+ // Ensure truncation works.
+ {
+ name: "truncate-short",
+ input: "012",
+ output: "01",
+ op: func(v *View) {
+ v.Truncate(2)
+ },
+ },
+ {
+ name: "truncate-long0",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize),
+ op: func(v *View) {
+ v.Truncate(bufferSize + 1)
+ },
+ },
+ {
+ name: "truncate-long1",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize-1),
+ op: func(v *View) {
+ v.Truncate(bufferSize)
+ },
+ },
+ {
+ name: "truncate-long2",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "01",
+ op: func(v *View) {
+ v.Truncate(2)
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Construct the new view.
+ var view View
+ bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice([]byte(tc.input)))
+ n, err := view.WriteFromBlocks(bs)
+ if err != nil {
+ t.Errorf("expected err nil, got %v", err)
+ }
+ if n != uint64(len(tc.input)) {
+ t.Errorf("expected %d bytes, got %d", len(tc.input), n)
+ }
+
+ // Run the operation.
+ if tc.op != nil {
+ tc.op(&view)
+ }
+
+ // Read and validate.
+ readLen := tc.readLen
+ if readLen == 0 {
+ readLen = len(tc.output) // Default.
+ }
+ out := make([]byte, readLen)
+ bs = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(out))
+ n, err = view.ReadToBlocks(bs)
+ if err != nil {
+ t.Errorf("expected nil, got %v", err)
+ }
+ if n != uint64(len(tc.output)) {
+ t.Errorf("expected %d bytes, got %d", len(tc.output), n)
+ }
+
+ // Ensure the contents are correct.
+ if !bytes.Equal(out[:n], []byte(tc.output[:n])) {
+ t.Errorf("contents are wrong: expected %q, got %q", tc.output, string(out))
+ }
+ })
+ }
+}
diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go
new file mode 100644
index 000000000..e6901eadb
--- /dev/null
+++ b/pkg/buffer/view.go
@@ -0,0 +1,390 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "fmt"
+ "io"
+)
+
+// View is a non-linear buffer.
+//
+// All methods are thread compatible.
+//
+// +stateify savable
+type View struct {
+ data bufferList
+ size int64
+}
+
+// TrimFront removes the first count bytes from the buffer.
+func (v *View) TrimFront(count int64) {
+ if count >= v.size {
+ v.advanceRead(v.size)
+ } else {
+ v.advanceRead(count)
+ }
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (v *View) ReadAt(p []byte, offset int64) (int, error) {
+ var (
+ skipped int64
+ done int64
+ )
+ for buf := v.data.Front(); buf != nil && done < int64(len(p)); buf = buf.Next() {
+ needToSkip := int(offset - skipped)
+ if sz := buf.ReadSize(); sz <= needToSkip {
+ skipped += int64(sz)
+ continue
+ }
+
+ // Actually read data.
+ n := copy(p[done:], buf.ReadSlice()[needToSkip:])
+ skipped += int64(needToSkip)
+ done += int64(n)
+ }
+ if int(done) < len(p) || offset+done == v.size {
+ return int(done), io.EOF
+ }
+ return int(done), nil
+}
+
+// advanceRead advances the view's read index.
+//
+// Precondition: there must be sufficient bytes in the buffer.
+func (v *View) advanceRead(count int64) {
+ for buf := v.data.Front(); buf != nil && count > 0; {
+ sz := int64(buf.ReadSize())
+ if sz > count {
+ // There is still data for reading.
+ buf.ReadMove(int(count))
+ v.size -= count
+ count = 0
+ break
+ }
+
+ // Consume the whole buffer.
+ oldBuf := buf
+ buf = buf.Next() // Iterate.
+ v.data.Remove(oldBuf)
+ oldBuf.Reset()
+ bufferPool.Put(oldBuf)
+
+ // Update counts.
+ count -= sz
+ v.size -= sz
+ }
+ if count > 0 {
+ panic(fmt.Sprintf("advanceRead still has %d bytes remaining", count))
+ }
+}
+
+// Truncate truncates the view to the given bytes.
+//
+// This will not grow the view, only shrink it. If a length is passed that is
+// greater than the current size of the view, then nothing will happen.
+//
+// Precondition: length must be >= 0.
+func (v *View) Truncate(length int64) {
+ if length < 0 {
+ panic("negative length provided")
+ }
+ if length >= v.size {
+ return // Nothing to do.
+ }
+ for buf := v.data.Back(); buf != nil && v.size > length; buf = v.data.Back() {
+ sz := int64(buf.ReadSize())
+ if after := v.size - sz; after < length {
+ // Truncate the buffer locally.
+ left := (length - after)
+ buf.write = buf.read + int(left)
+ v.size = length
+ break
+ }
+
+ // Drop the buffer completely; see above.
+ v.data.Remove(buf)
+ buf.Reset()
+ bufferPool.Put(buf)
+ v.size -= sz
+ }
+}
+
+// Grow grows the given view to the number of bytes, which will be appended. If
+// zero is true, all these bytes will be zero. If zero is false, then this is
+// the caller's responsibility.
+//
+// Precondition: length must be >= 0.
+func (v *View) Grow(length int64, zero bool) {
+ if length < 0 {
+ panic("negative length provided")
+ }
+ for v.size < length {
+ buf := v.data.Back()
+
+ // Is there some space in the last buffer?
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Write up to length bytes.
+ sz := buf.WriteSize()
+ if int64(sz) > length-v.size {
+ sz = int(length - v.size)
+ }
+
+ // Zero the written section; note that this pattern is
+ // specifically recognized and optimized by the compiler.
+ if zero {
+ for i := buf.write; i < buf.write+sz; i++ {
+ buf.data[i] = 0
+ }
+ }
+
+ // Advance the index.
+ buf.WriteMove(sz)
+ v.size += int64(sz)
+ }
+}
+
+// Prepend prepends the given data.
+func (v *View) Prepend(data []byte) {
+ // Is there any space in the first buffer?
+ if buf := v.data.Front(); buf != nil && buf.read > 0 {
+ // Fill up before the first write.
+ avail := buf.read
+ bStart := 0
+ dStart := len(data) - avail
+ if avail > len(data) {
+ bStart = avail - len(data)
+ dStart = 0
+ }
+ n := copy(buf.data[bStart:], data[dStart:])
+ data = data[:dStart]
+ v.size += int64(n)
+ buf.read -= n
+ }
+
+ for len(data) > 0 {
+ // Do we need an empty buffer?
+ buf := bufferPool.Get().(*buffer)
+ v.data.PushFront(buf)
+
+ // The buffer is empty; copy last chunk.
+ avail := len(buf.data)
+ bStart := 0
+ dStart := len(data) - avail
+ if avail > len(data) {
+ bStart = avail - len(data)
+ dStart = 0
+ }
+
+ // We have to put the data at the end of the current
+ // buffer in order to ensure that the next prepend will
+ // correctly fill up the beginning of this buffer.
+ n := copy(buf.data[bStart:], data[dStart:])
+ data = data[:dStart]
+ v.size += int64(n)
+ buf.read = len(buf.data) - n
+ buf.write = len(buf.data)
+ }
+}
+
+// Append appends the given data.
+func (v *View) Append(data []byte) {
+ for done := 0; done < len(data); {
+ buf := v.data.Back()
+
+ // Ensure there's a buffer with space.
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Copy in to the given buffer.
+ n := copy(buf.WriteSlice(), data[done:])
+ done += n
+ buf.WriteMove(n)
+ v.size += int64(n)
+ }
+}
+
+// Flatten returns a flattened copy of this data.
+//
+// This method should not be used in any performance-sensitive paths. It may
+// allocate a fresh byte slice sufficiently large to contain all the data in
+// the buffer. This is principally for debugging.
+//
+// N.B. Tee data still belongs to this view, as if there is a single buffer
+// present, then it will be returned directly. This should be used for
+// temporary use only, and a reference to the given slice should not be held.
+func (v *View) Flatten() []byte {
+ if buf := v.data.Front(); buf == nil {
+ return nil // No data at all.
+ } else if buf.Next() == nil {
+ return buf.ReadSlice() // Only one buffer.
+ }
+ data := make([]byte, 0, v.size) // Need to flatten.
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ // Copy to the allocated slice.
+ data = append(data, buf.ReadSlice()...)
+ }
+ return data
+}
+
+// Size indicates the total amount of data available in this view.
+func (v *View) Size() int64 {
+ return v.size
+}
+
+// Copy makes a strict copy of this view.
+func (v *View) Copy() (other View) {
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ other.Append(buf.ReadSlice())
+ }
+ return
+}
+
+// Apply applies the given function across all valid data.
+func (v *View) Apply(fn func([]byte)) {
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ fn(buf.ReadSlice())
+ }
+}
+
+// Merge merges the provided View with this one.
+//
+// The other view will be appended to v, and other will be empty after this
+// operation completes.
+func (v *View) Merge(other *View) {
+ // Copy over all buffers.
+ for buf := other.data.Front(); buf != nil; buf = other.data.Front() {
+ other.data.Remove(buf)
+ v.data.PushBack(buf)
+ }
+
+ // Adjust sizes.
+ v.size += other.size
+ other.size = 0
+}
+
+// WriteFromReader writes to the buffer from an io.Reader.
+//
+// A minimum read size equal to unsafe.Sizeof(unintptr) is enforced,
+// provided that count is greater than or equal to unsafe.Sizeof(uintptr).
+func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) {
+ var (
+ done int64
+ n int
+ err error
+ )
+ for done < count {
+ buf := v.data.Back()
+
+ // Ensure we have an empty buffer.
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Is this less than the minimum batch?
+ if buf.WriteSize() < minBatch && (count-done) >= int64(minBatch) {
+ tmp := make([]byte, minBatch)
+ n, err = r.Read(tmp)
+ v.Append(tmp[:n])
+ done += int64(n)
+ if err != nil {
+ break
+ }
+ continue
+ }
+
+ // Limit the read, if necessary.
+ sz := buf.WriteSize()
+ if left := count - done; int64(sz) > left {
+ sz = int(left)
+ }
+
+ // Pass the relevant portion of the buffer.
+ n, err = r.Read(buf.WriteSlice()[:sz])
+ buf.WriteMove(n)
+ done += int64(n)
+ v.size += int64(n)
+ if err == io.EOF {
+ err = nil // Short write allowed.
+ break
+ } else if err != nil {
+ break
+ }
+ }
+ return done, err
+}
+
+// ReadToWriter reads from the buffer into an io.Writer.
+//
+// N.B. This does not consume the bytes read. TrimFront should
+// be called appropriately after this call in order to do so.
+//
+// A minimum write size equal to unsafe.Sizeof(unintptr) is enforced,
+// provided that count is greater than or equal to unsafe.Sizeof(uintptr).
+func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) {
+ var (
+ done int64
+ n int
+ err error
+ )
+ offset := 0 // Spill-over for batching.
+ for buf := v.data.Front(); buf != nil && done < count; buf = buf.Next() {
+ // Has this been consumed? Skip it.
+ sz := buf.ReadSize()
+ if sz <= offset {
+ offset -= sz
+ continue
+ }
+ sz -= offset
+
+ // Is this less than the minimum batch?
+ left := count - done
+ if sz < minBatch && left >= int64(minBatch) && (v.size-done) >= int64(minBatch) {
+ tmp := make([]byte, minBatch)
+ n, err = v.ReadAt(tmp, done)
+ w.Write(tmp[:n])
+ done += int64(n)
+ offset = n - sz // Reset below.
+ if err != nil {
+ break
+ }
+ continue
+ }
+
+ // Limit the write if necessary.
+ if int64(sz) >= left {
+ sz = int(left)
+ }
+
+ // Perform the actual write.
+ n, err = w.Write(buf.ReadSlice()[offset : offset+sz])
+ done += int64(n)
+ if err != nil {
+ break
+ }
+
+ // Reset spill-over.
+ offset = 0
+ }
+ return done, err
+}
diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go
new file mode 100644
index 000000000..3db1bc6ee
--- /dev/null
+++ b/pkg/buffer/view_test.go
@@ -0,0 +1,467 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "bytes"
+ "io"
+ "strings"
+ "testing"
+)
+
+func fillAppend(v *View, data []byte) {
+ v.Append(data)
+}
+
+func fillAppendEnd(v *View, data []byte) {
+ v.Grow(bufferSize-1, false)
+ v.Append(data)
+ v.TrimFront(bufferSize - 1)
+}
+
+func fillWriteFromReader(v *View, data []byte) {
+ b := bytes.NewBuffer(data)
+ v.WriteFromReader(b, int64(len(data)))
+}
+
+func fillWriteFromReaderEnd(v *View, data []byte) {
+ v.Grow(bufferSize-1, false)
+ b := bytes.NewBuffer(data)
+ v.WriteFromReader(b, int64(len(data)))
+ v.TrimFront(bufferSize - 1)
+}
+
+var fillFuncs = map[string]func(*View, []byte){
+ "append": fillAppend,
+ "appendEnd": fillAppendEnd,
+ "writeFromReader": fillWriteFromReader,
+ "writeFromReaderEnd": fillWriteFromReaderEnd,
+}
+
+func testReadAt(t *testing.T, v *View, offset int64, n int, wantStr string, wantErr error) {
+ t.Helper()
+ d := make([]byte, n)
+ n, err := v.ReadAt(d, offset)
+ if n != len(wantStr) {
+ t.Errorf("got %d, want %d", n, len(wantStr))
+ }
+ if err != wantErr {
+ t.Errorf("got err %v, want %v", err, wantErr)
+ }
+ if !bytes.Equal(d[:n], []byte(wantStr)) {
+ t.Errorf("got %q, want %q", string(d[:n]), wantStr)
+ }
+}
+
+func TestView(t *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ output string
+ op func(*testing.T, *View)
+ }{
+ // Preconditions.
+ {
+ name: "truncate-check",
+ input: "hello",
+ output: "hello", // Not touched.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Truncate(-1) did not panic")
+ }
+ }()
+ v.Truncate(-1)
+ },
+ },
+ {
+ name: "grow-check",
+ input: "hello",
+ output: "hello", // Not touched.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Grow(-1) did not panic")
+ }
+ }()
+ v.Grow(-1, false)
+ },
+ },
+ {
+ name: "advance-check",
+ input: "hello",
+ output: "", // Consumed.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("advanceRead(Size()+1) did not panic")
+ }
+ }()
+ v.advanceRead(v.Size() + 1)
+ },
+ },
+
+ // Prepend.
+ {
+ name: "prepend",
+ input: "world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("hello "))
+ },
+ },
+ {
+ name: "prepend-backfill-full",
+ input: "hello world",
+ output: "jello world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ v.Prepend([]byte("j"))
+ },
+ },
+ {
+ name: "prepend-backfill-under",
+ input: "hello world",
+ output: "hola world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(5)
+ v.Prepend([]byte("hola"))
+ },
+ },
+ {
+ name: "prepend-backfill-over",
+ input: "hello world",
+ output: "smello world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ v.Prepend([]byte("sm"))
+ },
+ },
+ {
+ name: "prepend-fill",
+ input: strings.Repeat("1", bufferSize-1),
+ output: "0" + strings.Repeat("1", bufferSize-1),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("0"))
+ },
+ },
+ {
+ name: "prepend-overflow",
+ input: strings.Repeat("1", bufferSize),
+ output: "0" + strings.Repeat("1", bufferSize),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("0"))
+ },
+ },
+ {
+ name: "prepend-multiple-buffers",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("0", bufferSize*3) + strings.Repeat("1", bufferSize-1),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte(strings.Repeat("0", bufferSize*3)))
+ },
+ },
+
+ // Append and write.
+ {
+ name: "append",
+ input: "hello",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte(" world"))
+ },
+ },
+ {
+ name: "append-fill",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("1", bufferSize-1) + "0",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte("0"))
+ },
+ },
+ {
+ name: "append-overflow",
+ input: strings.Repeat("1", bufferSize),
+ output: strings.Repeat("1", bufferSize) + "0",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte("0"))
+ },
+ },
+ {
+ name: "append-multiple-buffers",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("1", bufferSize-1) + strings.Repeat("0", bufferSize*3),
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte(strings.Repeat("0", bufferSize*3)))
+ },
+ },
+
+ // Truncate.
+ {
+ name: "truncate",
+ input: "hello world",
+ output: "hello",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(5)
+ },
+ },
+ {
+ name: "truncate-noop",
+ input: "hello world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(v.Size() + 1)
+ },
+ },
+ {
+ name: "truncate-multiple-buffers",
+ input: strings.Repeat("1", bufferSize*2),
+ output: strings.Repeat("1", bufferSize*2-1),
+ op: func(t *testing.T, v *View) {
+ v.Truncate(bufferSize*2 - 1)
+ },
+ },
+ {
+ name: "truncate-multiple-buffers-to-one",
+ input: strings.Repeat("1", bufferSize*2),
+ output: "11111",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(5)
+ },
+ },
+
+ // TrimFront.
+ {
+ name: "trim",
+ input: "hello world",
+ output: "world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(6)
+ },
+ },
+ {
+ name: "trim-too-large",
+ input: "hello world",
+ output: "",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(v.Size() + 1)
+ },
+ },
+ {
+ name: "trim-multiple-buffers",
+ input: strings.Repeat("1", bufferSize*2),
+ output: strings.Repeat("1", bufferSize*2-1),
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ },
+ },
+ {
+ name: "trim-multiple-buffers-to-one-buffer",
+ input: strings.Repeat("1", bufferSize*2),
+ output: "1",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(bufferSize*2 - 1)
+ },
+ },
+
+ // Grow.
+ {
+ name: "grow",
+ input: "hello world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Grow(1, true)
+ },
+ },
+ {
+ name: "grow-from-zero",
+ output: strings.Repeat("\x00", 1024),
+ op: func(t *testing.T, v *View) {
+ v.Grow(1024, true)
+ },
+ },
+ {
+ name: "grow-from-non-zero",
+ input: strings.Repeat("1", bufferSize),
+ output: strings.Repeat("1", bufferSize) + strings.Repeat("\x00", bufferSize),
+ op: func(t *testing.T, v *View) {
+ v.Grow(bufferSize*2, true)
+ },
+ },
+
+ // Copy.
+ {
+ name: "copy",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) {
+ other := v.Copy()
+ bs := other.Flatten()
+ want := []byte("hello")
+ if !bytes.Equal(bs, want) {
+ t.Errorf("expected %v, got %v", want, bs)
+ }
+ },
+ },
+ {
+ name: "copy-large",
+ input: strings.Repeat("1", bufferSize+1),
+ output: strings.Repeat("1", bufferSize+1),
+ op: func(t *testing.T, v *View) {
+ other := v.Copy()
+ bs := other.Flatten()
+ want := []byte(strings.Repeat("1", bufferSize+1))
+ if !bytes.Equal(bs, want) {
+ t.Errorf("expected %v, got %v", want, bs)
+ }
+ },
+ },
+
+ // Merge.
+ {
+ name: "merge",
+ input: "hello",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ var other View
+ other.Append([]byte(" world"))
+ v.Merge(&other)
+ if sz := other.Size(); sz != 0 {
+ t.Errorf("expected 0, got %d", sz)
+ }
+ },
+ },
+ {
+ name: "merge-large",
+ input: strings.Repeat("1", bufferSize+1),
+ output: strings.Repeat("1", bufferSize+1) + strings.Repeat("0", bufferSize+1),
+ op: func(t *testing.T, v *View) {
+ var other View
+ other.Append([]byte(strings.Repeat("0", bufferSize+1)))
+ v.Merge(&other)
+ if sz := other.Size(); sz != 0 {
+ t.Errorf("expected 0, got %d", sz)
+ }
+ },
+ },
+
+ // ReadAt.
+ {
+ name: "readat",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 6, "hello", io.EOF) },
+ },
+ {
+ name: "readat-long",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 8, "hello", io.EOF) },
+ },
+ {
+ name: "readat-short",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 3, "hel", nil) },
+ },
+ {
+ name: "readat-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 3, "llo", io.EOF) },
+ },
+ {
+ name: "readat-long-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 8, "llo", io.EOF) },
+ },
+ {
+ name: "readat-short-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 2, "ll", nil) },
+ },
+ {
+ name: "readat-skip-all",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "", io.EOF) },
+ },
+ {
+ name: "readat-second-buffer",
+ input: strings.Repeat("0", bufferSize+1) + "12",
+ output: strings.Repeat("0", bufferSize+1) + "12",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "1", nil) },
+ },
+ {
+ name: "readat-second-buffer-end",
+ input: strings.Repeat("0", bufferSize+1) + "12",
+ output: strings.Repeat("0", bufferSize+1) + "12",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 2, "12", io.EOF) },
+ },
+ }
+
+ for _, tc := range testCases {
+ for fillName, fn := range fillFuncs {
+ t.Run(fillName+"/"+tc.name, func(t *testing.T) {
+ // Construct & fill the view.
+ var view View
+ fn(&view, []byte(tc.input))
+
+ // Run the operation.
+ if tc.op != nil {
+ tc.op(t, &view)
+ }
+
+ // Flatten and validate.
+ out := view.Flatten()
+ if !bytes.Equal([]byte(tc.output), out) {
+ t.Errorf("expected %q, got %q", tc.output, string(out))
+ }
+
+ // Ensure the size is correct.
+ if len(out) != int(view.Size()) {
+ t.Errorf("size is wrong: expected %d, got %d", len(out), view.Size())
+ }
+
+ // Calculate contents via apply.
+ var appliedOut []byte
+ view.Apply(func(b []byte) {
+ appliedOut = append(appliedOut, b...)
+ })
+ if len(appliedOut) != len(out) {
+ t.Errorf("expected %d, got %d", len(out), len(appliedOut))
+ }
+ if !bytes.Equal(appliedOut, out) {
+ t.Errorf("expected %v, got %v", out, appliedOut)
+ }
+
+ // Calculate contents via ReadToWriter.
+ var b bytes.Buffer
+ n, err := view.ReadToWriter(&b, int64(len(out)))
+ if n != int64(len(out)) {
+ t.Errorf("expected %d, got %d", len(out), n)
+ }
+ if err != nil {
+ t.Errorf("expected nil, got %v", err)
+ }
+ if !bytes.Equal(b.Bytes(), out) {
+ t.Errorf("expected %v, got %v", out, b.Bytes())
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/buffer/view_unsafe.go b/pkg/buffer/view_unsafe.go
new file mode 100644
index 000000000..d1ef39b26
--- /dev/null
+++ b/pkg/buffer/view_unsafe.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "unsafe"
+)
+
+// minBatch is the smallest Read or Write operation that the
+// WriteFromReader and ReadToWriter functions will use.
+//
+// This is defined as the size of a native pointer.
+const minBatch = int(unsafe.Sizeof(uintptr(0)))
diff --git a/pkg/context/context.go b/pkg/context/context.go
index 23e009ef3..5319b6d8d 100644
--- a/pkg/context/context.go
+++ b/pkg/context/context.go
@@ -127,10 +127,6 @@ func (logContext) Value(key interface{}) interface{} {
var bgContext = &logContext{Logger: log.Log()}
// Background returns an empty context using the default logger.
-//
-// Users should be wary of using a Background context. Please tag any use with
-// FIXME(b/38173783) and a note to remove this use.
-//
// Generally, one should use the Task as their context when available, or avoid
// having to use a context in places where a Task is unavailable.
//
diff --git a/pkg/cpuid/cpuid_parse_x86_test.go b/pkg/cpuid/cpuid_parse_x86_test.go
index d48418e69..c9bd40e1b 100644
--- a/pkg/cpuid/cpuid_parse_x86_test.go
+++ b/pkg/cpuid/cpuid_parse_x86_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package cpuid
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
index a0bc55ea1..562f8f405 100644
--- a/pkg/cpuid/cpuid_x86.go
+++ b/pkg/cpuid/cpuid_x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package cpuid
@@ -235,7 +235,9 @@ const (
X86FeaturePERFCTR_TSC
X86FeaturePERFCTR_LLC
X86FeatureMWAITX
- // ECX[31:30] are reserved.
+ // TODO(b/152776797): Some CPUs set this but it is not documented anywhere.
+ X86FeatureBlock5Bit30
+ _ // ecx bit 31 is reserved.
)
// Block 6 constants are the extended feature bits in
@@ -438,6 +440,9 @@ var x86FeatureParseOnlyStrings = map[Feature]string{
// Block 3.
X86FeaturePREFETCHWT1: "prefetchwt1",
+
+ // Block 5.
+ X86FeatureBlock5Bit30: "block5_bit30",
}
// intelCacheDescriptors describe the caches and TLBs on the system. They are
diff --git a/pkg/cpuid/cpuid_x86_test.go b/pkg/cpuid/cpuid_x86_test.go
index 0fe20c213..bacf345c8 100644
--- a/pkg/cpuid/cpuid_x86_test.go
+++ b/pkg/cpuid/cpuid_x86_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package cpuid
diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go
index 7f41b4a27..43750360b 100644
--- a/pkg/eventchannel/event_test.go
+++ b/pkg/eventchannel/event_test.go
@@ -78,7 +78,7 @@ func TestMultiEmitter(t *testing.T) {
for _, name := range names {
m := testMessage{name: name}
if _, err := me.Emit(m); err != nil {
- t.Fatal("me.Emit(%v) failed: %v", m, err)
+ t.Fatalf("me.Emit(%v) failed: %v", m, err)
}
}
@@ -96,7 +96,7 @@ func TestMultiEmitter(t *testing.T) {
// Close multiEmitter.
if err := me.Close(); err != nil {
- t.Fatal("me.Close() failed: %v", err)
+ t.Fatalf("me.Close() failed: %v", err)
}
// All testEmitters should be closed.
diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go
index 850693df8..316015e06 100644
--- a/pkg/gate/gate_test.go
+++ b/pkg/gate/gate_test.go
@@ -15,6 +15,7 @@
package gate_test
import (
+ "runtime"
"testing"
"time"
@@ -165,6 +166,8 @@ func worker(g *gate.Gate, done *sync.WaitGroup) {
if !g.Enter() {
break
}
+ // Golang before v1.14 doesn't preempt busyloops.
+ runtime.Gosched()
g.Leave()
}
done.Done()
diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go
index 019caadca..0d07da3b1 100644
--- a/pkg/ilist/list.go
+++ b/pkg/ilist/list.go
@@ -86,11 +86,21 @@ func (l *List) Back() Element {
return l.tail
}
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+func (l *List) Len() (count int) {
+ for e := l.Front(); e != nil; e = e.Next() {
+ count++
+ }
+ return count
+}
+
// PushFront inserts the element e at the front of list l.
func (l *List) PushFront(e Element) {
- ElementMapper{}.linkerFor(e).SetNext(l.head)
- ElementMapper{}.linkerFor(e).SetPrev(nil)
-
+ linker := ElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
if l.head != nil {
ElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
@@ -102,9 +112,9 @@ func (l *List) PushFront(e Element) {
// PushBack inserts the element e at the back of list l.
func (l *List) PushBack(e Element) {
- ElementMapper{}.linkerFor(e).SetNext(nil)
- ElementMapper{}.linkerFor(e).SetPrev(l.tail)
-
+ linker := ElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
if l.tail != nil {
ElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
@@ -125,17 +135,20 @@ func (l *List) PushBackList(m *List) {
l.tail = m.tail
}
-
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
func (l *List) InsertAfter(b, e Element) {
- a := ElementMapper{}.linkerFor(b).Next()
- ElementMapper{}.linkerFor(e).SetNext(a)
- ElementMapper{}.linkerFor(e).SetPrev(b)
- ElementMapper{}.linkerFor(b).SetNext(e)
+ bLinker := ElementMapper{}.linkerFor(b)
+ eLinker := ElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
if a != nil {
ElementMapper{}.linkerFor(a).SetPrev(e)
@@ -146,10 +159,13 @@ func (l *List) InsertAfter(b, e Element) {
// InsertBefore inserts e before a.
func (l *List) InsertBefore(a, e Element) {
- b := ElementMapper{}.linkerFor(a).Prev()
- ElementMapper{}.linkerFor(e).SetNext(a)
- ElementMapper{}.linkerFor(e).SetPrev(b)
- ElementMapper{}.linkerFor(a).SetPrev(e)
+ aLinker := ElementMapper{}.linkerFor(a)
+ eLinker := ElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
if b != nil {
ElementMapper{}.linkerFor(b).SetNext(e)
@@ -160,8 +176,9 @@ func (l *List) InsertBefore(a, e Element) {
// Remove removes e from l.
func (l *List) Remove(e Element) {
- prev := ElementMapper{}.linkerFor(e).Prev()
- next := ElementMapper{}.linkerFor(e).Next()
+ linker := ElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
if prev != nil {
ElementMapper{}.linkerFor(prev).SetNext(next)
@@ -174,6 +191,9 @@ func (l *List) Remove(e Element) {
} else {
l.tail = prev
}
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
}
// Entry is a default implementation of Linker. Users can add anonymous fields
diff --git a/pkg/log/glog.go b/pkg/log/glog.go
index b4f7bb5a4..f57c4427b 100644
--- a/pkg/log/glog.go
+++ b/pkg/log/glog.go
@@ -25,7 +25,7 @@ import (
// GoogleEmitter is a wrapper that emits logs in a format compatible with
// package github.com/golang/glog.
type GoogleEmitter struct {
- Writer
+ *Writer
}
// pid is used for the threadid component of the header.
@@ -46,7 +46,7 @@ var pid = os.Getpid()
// line The line number
// msg The user-supplied message
//
-func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) {
+func (g GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) {
// Log level.
prefix := byte('?')
switch level {
@@ -81,5 +81,5 @@ func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format
message := fmt.Sprintf(format, args...)
// Emit the formatted result.
- fmt.Fprintf(&g.Writer, "%c%02d%02d %02d:%02d:%02d.%06d % 7d %s:%d] %s\n", prefix, int(month), day, hour, minute, second, microsecond, pid, file, line, message)
+ fmt.Fprintf(g.Writer, "%c%02d%02d %02d:%02d:%02d.%06d % 7d %s:%d] %s\n", prefix, int(month), day, hour, minute, second, microsecond, pid, file, line, message)
}
diff --git a/pkg/log/json.go b/pkg/log/json.go
index 0943db1cc..bdf9d691e 100644
--- a/pkg/log/json.go
+++ b/pkg/log/json.go
@@ -58,7 +58,7 @@ func (lv *Level) UnmarshalJSON(b []byte) error {
// JSONEmitter logs messages in json format.
type JSONEmitter struct {
- Writer
+ *Writer
}
// Emit implements Emitter.Emit.
diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go
index 6c6fc8b6f..5883e95e1 100644
--- a/pkg/log/json_k8s.go
+++ b/pkg/log/json_k8s.go
@@ -29,11 +29,11 @@ type k8sJSONLog struct {
// K8sJSONEmitter logs messages in json format that is compatible with
// Kubernetes fluent configuration.
type K8sJSONEmitter struct {
- Writer
+ *Writer
}
// Emit implements Emitter.Emit.
-func (e *K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := k8sJSONLog{
Log: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/log.go b/pkg/log/log.go
index a794da1aa..37e0605ad 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -374,5 +374,5 @@ func CopyStandardLogTo(l Level) error {
func init() {
// Store the initial value for the log.
- log.Store(&BasicLogger{Level: Info, Emitter: &GoogleEmitter{Writer{Next: os.Stderr}}})
+ log.Store(&BasicLogger{Level: Info, Emitter: GoogleEmitter{&Writer{Next: os.Stderr}}})
}
diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go
index 402cc29ae..9ff18559b 100644
--- a/pkg/log/log_test.go
+++ b/pkg/log/log_test.go
@@ -52,7 +52,7 @@ func TestDropMessages(t *testing.T) {
t.Fatalf("Write should have failed")
}
- fmt.Printf("writer: %+v\n", w)
+ fmt.Printf("writer: %#v\n", &w)
tw.fail = false
if _, err := w.Write([]byte("line 2\n")); err != nil {
@@ -76,7 +76,7 @@ func TestDropMessages(t *testing.T) {
func TestCaller(t *testing.T) {
tw := &testWriter{}
- e := &GoogleEmitter{Writer: Writer{Next: tw}}
+ e := GoogleEmitter{Writer: &Writer{Next: tw}}
bl := &BasicLogger{
Emitter: e,
Level: Debug,
@@ -94,7 +94,7 @@ func BenchmarkGoogleLogging(b *testing.B) {
tw := &testWriter{
limit: 1, // Only record one message.
}
- e := &GoogleEmitter{Writer: Writer{Next: tw}}
+ e := GoogleEmitter{Writer: &Writer{Next: tw}}
bl := &BasicLogger{
Emitter: e,
Level: Debug,
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index a6f493b82..71e944c30 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -174,7 +174,7 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
// our sendRecv function to use that functionality. Otherwise,
// we stick to sendRecvLegacy.
rversion := Rversion{}
- err := c.sendRecvLegacy(&Tversion{
+ _, err := c.sendRecvLegacy(&Tversion{
Version: versionString(requested),
MSize: messageSize,
}, &rversion)
@@ -219,11 +219,11 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
c.sendRecv = c.sendRecvChannel
} else {
// Channel setup failed; fallback.
- c.sendRecv = c.sendRecvLegacy
+ c.sendRecv = c.sendRecvLegacySyscallErr
}
} else {
// No channels available: use the legacy mechanism.
- c.sendRecv = c.sendRecvLegacy
+ c.sendRecv = c.sendRecvLegacySyscallErr
}
// Ensure that the socket and channels are closed when the socket is shut
@@ -305,7 +305,7 @@ func (c *Client) openChannel(id int) error {
)
// Open the data channel.
- if err := c.sendRecvLegacy(&Tchannel{
+ if _, err := c.sendRecvLegacy(&Tchannel{
ID: uint32(id),
Control: 0,
}, &rchannel0); err != nil {
@@ -319,7 +319,7 @@ func (c *Client) openChannel(id int) error {
defer rchannel0.FilePayload().Close()
// Open the channel for file descriptors.
- if err := c.sendRecvLegacy(&Tchannel{
+ if _, err := c.sendRecvLegacy(&Tchannel{
ID: uint32(id),
Control: 1,
}, &rchannel1); err != nil {
@@ -431,13 +431,28 @@ func (c *Client) waitAndRecv(done chan error) error {
}
}
+// sendRecvLegacySyscallErr is a wrapper for sendRecvLegacy that converts all
+// non-syscall errors to EIO.
+func (c *Client) sendRecvLegacySyscallErr(t message, r message) error {
+ received, err := c.sendRecvLegacy(t, r)
+ if !received {
+ log.Warningf("p9.Client.sendRecvChannel: %v", err)
+ return syscall.EIO
+ }
+ return err
+}
+
// sendRecvLegacy performs a roundtrip message exchange.
//
+// sendRecvLegacy returns true if a message was received. This allows us to
+// differentiate between failed receives and successful receives where the
+// response was an error message.
+//
// This is called by internal functions.
-func (c *Client) sendRecvLegacy(t message, r message) error {
+func (c *Client) sendRecvLegacy(t message, r message) (bool, error) {
tag, ok := c.tagPool.Get()
if !ok {
- return ErrOutOfTags
+ return false, ErrOutOfTags
}
defer c.tagPool.Put(tag)
@@ -457,12 +472,12 @@ func (c *Client) sendRecvLegacy(t message, r message) error {
err := send(c.socket, Tag(tag), t)
c.sendMu.Unlock()
if err != nil {
- return err
+ return false, err
}
// Co-ordinate with other receivers.
if err := c.waitAndRecv(resp.done); err != nil {
- return err
+ return false, err
}
// Is it an error message?
@@ -470,14 +485,14 @@ func (c *Client) sendRecvLegacy(t message, r message) error {
// For convenience, we transform these directly
// into errors. Handlers need not handle this case.
if rlerr, ok := resp.r.(*Rlerror); ok {
- return syscall.Errno(rlerr.Error)
+ return true, syscall.Errno(rlerr.Error)
}
// At this point, we know it matches.
//
// Per recv call above, we will only allow a type
// match (and give our r) or an instance of Rlerror.
- return nil
+ return true, nil
}
// sendRecvChannel uses channels to send a message.
@@ -486,7 +501,7 @@ func (c *Client) sendRecvChannel(t message, r message) error {
c.channelsMu.Lock()
if len(c.availableChannels) == 0 {
c.channelsMu.Unlock()
- return c.sendRecvLegacy(t, r)
+ return c.sendRecvLegacySyscallErr(t, r)
}
idx := len(c.availableChannels) - 1
ch := c.availableChannels[idx]
@@ -526,7 +541,11 @@ func (c *Client) sendRecvChannel(t message, r message) error {
}
// Parse the server's response.
- _, retErr := ch.recv(r, rsz)
+ resp, retErr := ch.recv(r, rsz)
+ if resp == nil {
+ log.Warningf("p9.Client.sendRecvChannel: p9.channel.recv: %v", retErr)
+ retErr = syscall.EIO
+ }
// Release the channel.
c.channelsMu.Lock()
diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go
index 29a0afadf..c757583e0 100644
--- a/pkg/p9/client_test.go
+++ b/pkg/p9/client_test.go
@@ -96,7 +96,12 @@ func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) e
}
func BenchmarkSendRecvLegacy(b *testing.B) {
- benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy })
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error {
+ return func(t message, r message) error {
+ _, err := c.sendRecvLegacy(t, r)
+ return err
+ }
+ })
}
func BenchmarkSendRecvChannel(b *testing.B) {
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index d4ffbc8e3..cab35896f 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -97,12 +97,12 @@ type File interface {
// free to ignore the hint entirely (i.e. the value returned may be larger
// than size). All size checking is done independently at the syscall layer.
//
- // TODO(b/127675828): Determine concurrency guarantees once implemented.
+ // On the server, GetXattr has a read concurrency guarantee.
GetXattr(name string, size uint64) (string, error)
// SetXattr sets extended attributes on this node.
//
- // TODO(b/127675828): Determine concurrency guarantees once implemented.
+ // On the server, SetXattr has a write concurrency guarantee.
SetXattr(name, value string, flags uint32) error
// ListXattr lists the names of the extended attributes on this node.
@@ -113,12 +113,12 @@ type File interface {
// free to ignore the hint entirely (i.e. the value returned may be larger
// than size). All size checking is done independently at the syscall layer.
//
- // TODO(b/148303075): Determine concurrency guarantees once implemented.
+ // On the server, ListXattr has a read concurrency guarantee.
ListXattr(size uint64) (map[string]struct{}, error)
// RemoveXattr removes extended attributes on this node.
//
- // TODO(b/148303075): Determine concurrency guarantees once implemented.
+ // On the server, RemoveXattr has a write concurrency guarantee.
RemoveXattr(name string) error
// Allocate allows the caller to directly manipulate the allocated disk space
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 2ac45eb80..1db5797dd 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -48,6 +48,8 @@ func ExtractErrno(err error) syscall.Errno {
return ExtractErrno(e.Err)
case *os.SyscallError:
return ExtractErrno(e.Err)
+ case *os.LinkError:
+ return ExtractErrno(e.Err)
}
// Default case.
@@ -920,8 +922,15 @@ func (t *Tgetxattr) handle(cs *connState) message {
}
defer ref.DecRef()
- val, err := ref.file.GetXattr(t.Name, t.Size)
- if err != nil {
+ var val string
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow getxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ val, err = ref.file.GetXattr(t.Name, t.Size)
+ return err
+ }); err != nil {
return newErr(err)
}
return &Rgetxattr{Value: val}
@@ -935,7 +944,13 @@ func (t *Tsetxattr) handle(cs *connState) message {
}
defer ref.DecRef()
- if err := ref.file.SetXattr(t.Name, t.Value, t.Flags); err != nil {
+ if err := ref.safelyWrite(func() error {
+ // Don't allow setxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ return ref.file.SetXattr(t.Name, t.Value, t.Flags)
+ }); err != nil {
return newErr(err)
}
return &Rsetxattr{}
@@ -949,10 +964,18 @@ func (t *Tlistxattr) handle(cs *connState) message {
}
defer ref.DecRef()
- xattrs, err := ref.file.ListXattr(t.Size)
- if err != nil {
+ var xattrs map[string]struct{}
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow listxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ xattrs, err = ref.file.ListXattr(t.Size)
+ return err
+ }); err != nil {
return newErr(err)
}
+
xattrList := make([]string, 0, len(xattrs))
for x := range xattrs {
xattrList = append(xattrList, x)
@@ -968,7 +991,13 @@ func (t *Tremovexattr) handle(cs *connState) message {
}
defer ref.DecRef()
- if err := ref.file.RemoveXattr(t.Name); err != nil {
+ if err := ref.safelyWrite(func() error {
+ // Don't allow removexattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ return ref.file.RemoveXattr(t.Name)
+ }); err != nil {
return newErr(err)
}
return &Rremovexattr{}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index 3863ad1f5..57b89ad7d 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -1926,19 +1926,17 @@ func (r *Rreaddir) decode(b *buffer) {
// encode implements encoder.encode.
func (r *Rreaddir) encode(b *buffer) {
entriesBuf := buffer{}
+ payloadSize := 0
for _, d := range r.Entries {
d.encode(&entriesBuf)
- if len(entriesBuf.data) >= int(r.Count) {
+ if len(entriesBuf.data) > int(r.Count) {
break
}
+ payloadSize = len(entriesBuf.data)
}
- if len(entriesBuf.data) < int(r.Count) {
- r.Count = uint32(len(entriesBuf.data))
- r.payload = entriesBuf.data
- } else {
- r.payload = entriesBuf.data[:r.Count]
- }
- b.Write32(uint32(r.Count))
+ r.Count = uint32(payloadSize)
+ r.payload = entriesBuf.data[:payloadSize]
+ b.Write32(r.Count)
}
// Type implements message.Type.
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
index c20324404..7facc9f5e 100644
--- a/pkg/p9/messages_test.go
+++ b/pkg/p9/messages_test.go
@@ -216,7 +216,7 @@ func TestEncodeDecode(t *testing.T) {
},
&Rreaddir{
// Count must be sufficient to encode a dirent.
- Count: 0x18,
+ Count: 0x1a,
Entries: []Dirent{{QID: QID{Type: 2}}},
},
&Tfsync{
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
index a0d274f3b..38038abdf 100644
--- a/pkg/p9/transport_flipcall.go
+++ b/pkg/p9/transport_flipcall.go
@@ -236,7 +236,7 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) {
// Convert errors appropriately; see above.
if rlerr, ok := r.(*Rlerror); ok {
- return nil, syscall.Errno(rlerr.Error)
+ return r, syscall.Errno(rlerr.Error)
}
return r, nil
diff --git a/pkg/rand/rand_linux.go b/pkg/rand/rand_linux.go
index 0bdad5fad..fa6a21026 100644
--- a/pkg/rand/rand_linux.go
+++ b/pkg/rand/rand_linux.go
@@ -17,6 +17,7 @@
package rand
import (
+ "bufio"
"crypto/rand"
"io"
@@ -45,8 +46,22 @@ func (r *reader) Read(p []byte) (int, error) {
return rand.Read(p)
}
+// bufferedReader implements a threadsafe buffered io.Reader.
+type bufferedReader struct {
+ mu sync.Mutex
+ r *bufio.Reader
+}
+
+// Read implements io.Reader.Read.
+func (b *bufferedReader) Read(p []byte) (int, error) {
+ b.mu.Lock()
+ n, err := b.r.Read(p)
+ b.mu.Unlock()
+ return n, err
+}
+
// Reader is the default reader.
-var Reader io.Reader = &reader{}
+var Reader io.Reader = &bufferedReader{r: bufio.NewReader(&reader{})}
// Read reads from the default reader.
func Read(b []byte) (int, error) {
diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s
index 129691d68..00b46c18f 100644
--- a/pkg/safecopy/memcpy_amd64.s
+++ b/pkg/safecopy/memcpy_amd64.s
@@ -55,15 +55,9 @@ TEXT ·memcpy(SB), NOSPLIT, $0-36
MOVQ from+8(FP), SI
MOVQ n+16(FP), BX
- // REP instructions have a high startup cost, so we handle small sizes
- // with some straightline code. The REP MOVSQ instruction is really fast
- // for large sizes. The cutover is approximately 2K.
tail:
- // move_129through256 or smaller work whether or not the source and the
- // destination memory regions overlap because they load all data into
- // registers before writing it back. move_256through2048 on the other
- // hand can be used only when the memory regions don't overlap or the copy
- // direction is forward.
+ // BSR+branch table make almost all memmove/memclr benchmarks worse. Not
+ // worth doing.
TESTQ BX, BX
JEQ move_0
CMPQ BX, $2
@@ -83,31 +77,45 @@ tail:
JBE move_65through128
CMPQ BX, $256
JBE move_129through256
- // TODO: use branch table and BSR to make this just a single dispatch
-/*
- * forward copy loop
- */
- CMPQ BX, $2048
- JLS move_256through2048
-
- // Check alignment
- MOVL SI, AX
- ORL DI, AX
- TESTL $7, AX
- JEQ fwdBy8
-
- // Do 1 byte at a time
- MOVQ BX, CX
- REP; MOVSB
- RET
-
-fwdBy8:
- // Do 8 bytes at a time
- MOVQ BX, CX
- SHRQ $3, CX
- ANDQ $7, BX
- REP; MOVSQ
+move_257plus:
+ SUBQ $256, BX
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU 16(SI), X1
+ MOVOU X1, 16(DI)
+ MOVOU 32(SI), X2
+ MOVOU X2, 32(DI)
+ MOVOU 48(SI), X3
+ MOVOU X3, 48(DI)
+ MOVOU 64(SI), X4
+ MOVOU X4, 64(DI)
+ MOVOU 80(SI), X5
+ MOVOU X5, 80(DI)
+ MOVOU 96(SI), X6
+ MOVOU X6, 96(DI)
+ MOVOU 112(SI), X7
+ MOVOU X7, 112(DI)
+ MOVOU 128(SI), X8
+ MOVOU X8, 128(DI)
+ MOVOU 144(SI), X9
+ MOVOU X9, 144(DI)
+ MOVOU 160(SI), X10
+ MOVOU X10, 160(DI)
+ MOVOU 176(SI), X11
+ MOVOU X11, 176(DI)
+ MOVOU 192(SI), X12
+ MOVOU X12, 192(DI)
+ MOVOU 208(SI), X13
+ MOVOU X13, 208(DI)
+ MOVOU 224(SI), X14
+ MOVOU X14, 224(DI)
+ MOVOU 240(SI), X15
+ MOVOU X15, 240(DI)
+ CMPQ BX, $256
+ LEAQ 256(SI), SI
+ LEAQ 256(DI), DI
+ JGE move_257plus
JMP tail
move_1or2:
@@ -209,42 +217,3 @@ move_129through256:
MOVOU -16(SI)(BX*1), X15
MOVOU X15, -16(DI)(BX*1)
RET
-move_256through2048:
- SUBQ $256, BX
- MOVOU (SI), X0
- MOVOU X0, (DI)
- MOVOU 16(SI), X1
- MOVOU X1, 16(DI)
- MOVOU 32(SI), X2
- MOVOU X2, 32(DI)
- MOVOU 48(SI), X3
- MOVOU X3, 48(DI)
- MOVOU 64(SI), X4
- MOVOU X4, 64(DI)
- MOVOU 80(SI), X5
- MOVOU X5, 80(DI)
- MOVOU 96(SI), X6
- MOVOU X6, 96(DI)
- MOVOU 112(SI), X7
- MOVOU X7, 112(DI)
- MOVOU 128(SI), X8
- MOVOU X8, 128(DI)
- MOVOU 144(SI), X9
- MOVOU X9, 144(DI)
- MOVOU 160(SI), X10
- MOVOU X10, 160(DI)
- MOVOU 176(SI), X11
- MOVOU X11, 176(DI)
- MOVOU 192(SI), X12
- MOVOU X12, 192(DI)
- MOVOU 208(SI), X13
- MOVOU X13, 208(DI)
- MOVOU 224(SI), X14
- MOVOU X14, 224(DI)
- MOVOU 240(SI), X15
- MOVOU X15, 240(DI)
- CMPQ BX, $256
- LEAQ 256(SI), SI
- LEAQ 256(DI), DI
- JGE move_256through2048
- JMP tail
diff --git a/pkg/safemem/seq_test.go b/pkg/safemem/seq_test.go
index eba4bb535..de34005e9 100644
--- a/pkg/safemem/seq_test.go
+++ b/pkg/safemem/seq_test.go
@@ -20,6 +20,27 @@ import (
"testing"
)
+func TestBlockSeqOfEmptyBlock(t *testing.T) {
+ bs := BlockSeqOf(Block{})
+ if !bs.IsEmpty() {
+ t.Errorf("BlockSeqOf(Block{}).IsEmpty(): got false, wanted true; BlockSeq is %v", bs)
+ }
+}
+
+func TestBlockSeqOfNonemptyBlock(t *testing.T) {
+ b := BlockFromSafeSlice(make([]byte, 1))
+ bs := BlockSeqOf(b)
+ if bs.IsEmpty() {
+ t.Fatalf("BlockSeqOf(non-empty Block).IsEmpty(): got true, wanted false; BlockSeq is %v", bs)
+ }
+ if head := bs.Head(); head != b {
+ t.Fatalf("BlockSeqOf(non-empty Block).Head(): got %v, wanted %v", head, b)
+ }
+ if tail := bs.Tail(); !tail.IsEmpty() {
+ t.Fatalf("BlockSeqOf(non-empty Block).Tail().IsEmpty(): got false, wanted true: tail is %v", tail)
+ }
+}
+
type blockSeqTest struct {
desc string
diff --git a/pkg/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go
index dcdfc9600..f5f0574f8 100644
--- a/pkg/safemem/seq_unsafe.go
+++ b/pkg/safemem/seq_unsafe.go
@@ -56,6 +56,9 @@ type BlockSeq struct {
// BlockSeqOf returns a BlockSeq representing the single Block b.
func BlockSeqOf(b Block) BlockSeq {
+ if b.length == 0 {
+ return BlockSeq{}
+ }
bs := BlockSeq{
data: b.start,
length: -1,
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index da5a5e4b2..88766f33b 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -451,7 +451,7 @@ func TestRandom(t *testing.T) {
}
}
- fmt.Printf("Testing filters: %v", syscallRules)
+ t.Logf("Testing filters: %v", syscallRules)
instrs, err := BuildProgram([]RuleSet{
RuleSet{
Rules: syscallRules,
diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go
index f19a005f3..97b16c158 100644
--- a/pkg/segment/test/segment_test.go
+++ b/pkg/segment/test/segment_test.go
@@ -63,7 +63,7 @@ func checkSet(s *Set, expectedSegments int) error {
return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
}
if got, want := seg.Value(), seg.Start()+valueOffset; got != want {
- return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start, got, want)
+ return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start(), got, want)
}
prev = next
havePrev = true
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index 1d11cc472..a903d031c 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -88,6 +88,9 @@ type Context interface {
// SyscallNo returns the syscall number.
SyscallNo() uintptr
+ // SyscallSaveOrig save orignal register value.
+ SyscallSaveOrig()
+
// SyscallArgs returns the syscall arguments in an array.
SyscallArgs() SyscallArguments
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index d7794db63..c29e1b841 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -32,6 +32,12 @@ import (
const (
// SyscallWidth is the width of insturctions.
SyscallWidth = 4
+
+ // fpsimdMagic is the magic number which is used in fpsimd_context.
+ fpsimdMagic = 0x46508001
+
+ // fpsimdContextSize is the size of fpsimd_context.
+ fpsimdContextSize = 0x210
)
// ARMTrapFlag is the mask for the trap flag.
@@ -40,24 +46,24 @@ const ARMTrapFlag = uint64(1) << 21
// aarch64FPState is aarch64 floating point state.
type aarch64FPState []byte
-// initAarch64FPState (defined in asm files) sets up initial state.
-func initAarch64FPState(data *FloatingPointData) {
- // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+// initAarch64FPState sets up initial state.
+func initAarch64FPState(data aarch64FPState) {
+ binary.LittleEndian.PutUint32(data, fpsimdMagic)
+ binary.LittleEndian.PutUint32(data[4:], fpsimdContextSize)
}
func newAarch64FPStateSlice() []byte {
- return alignedBytes(4096, 32)[:4096]
+ return alignedBytes(4096, 16)[:fpsimdContextSize]
}
// newAarch64FPState returns an initialized floating point state.
//
// The returned state is large enough to store all floating point state
// supported by host, even if the app won't use much of it due to a restricted
-// FeatureSet. Since they may still be able to see state not advertised by
-// CPUID we must ensure it does not contain any sentry state.
+// FeatureSet.
func newAarch64FPState() aarch64FPState {
f := aarch64FPState(newAarch64FPStateSlice())
- initAarch64FPState(f.FloatingPointData())
+ initAarch64FPState(f)
return f
}
@@ -89,8 +95,14 @@ type State struct {
// Our floating point state.
aarch64FPState `state:"wait"`
+ // TLS pointer
+ TPValue uint64
+
// FeatureSet is a pointer to the currently active feature set.
FeatureSet *cpuid.FeatureSet
+
+ // OrigR0 stores the value of register R0.
+ OrigR0 uint64
}
// Proto returns a protobuf representation of the system registers in State.
@@ -136,10 +148,12 @@ func (s State) Proto() *rpb.Registers {
// Fork creates and returns an identical copy of the state.
func (s *State) Fork() State {
- // TODO(gvisor.dev/issue/1238): floating-point is not supported.
return State{
- Regs: s.Regs,
- FeatureSet: s.FeatureSet,
+ Regs: s.Regs,
+ aarch64FPState: s.aarch64FPState.fork(),
+ TPValue: s.TPValue,
+ FeatureSet: s.FeatureSet,
+ OrigR0: s.OrigR0,
}
}
@@ -249,6 +263,7 @@ func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) {
const (
_NT_PRSTATUS = 1
_NT_PRFPREG = 2
+ _NT_ARM_TLS = 0x401
)
// PtraceGetRegSet implements Context.PtraceGetRegSet.
@@ -288,8 +303,10 @@ func New(arch Arch, fs *cpuid.FeatureSet) Context {
case ARM64:
return &context64{
State{
- FeatureSet: fs,
+ aarch64FPState: newAarch64FPState(),
+ FeatureSet: fs,
},
+ []aarch64FPState(nil),
}
}
panic(fmt.Sprintf("unknown architecture %v", arch))
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index ac98897b5..db99c5acb 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -53,6 +53,11 @@ const (
preferredPIELoadAddr usermem.Addr = maxAddr64 / 6 * 5
)
+var (
+ // CPUIDInstruction doesn't exist on ARM64.
+ CPUIDInstruction = []byte{}
+)
+
// These constants are selected as heuristics to help make the Platform's
// potentially limited address space conform as closely to Linux as possible.
const (
@@ -68,6 +73,7 @@ const (
// context64 represents an ARM64 context.
type context64 struct {
State
+ sigFPState []aarch64FPState // fpstate to be restored on sigreturn.
}
// Arch implements Context.Arch.
@@ -75,10 +81,19 @@ func (c *context64) Arch() Arch {
return ARM64
}
+func (c *context64) copySigFPState() []aarch64FPState {
+ var sigfps []aarch64FPState
+ for _, s := range c.sigFPState {
+ sigfps = append(sigfps, s.fork())
+ }
+ return sigfps
+}
+
// Fork returns an exact copy of this context.
func (c *context64) Fork() Context {
return &context64{
- State: c.State.Fork(),
+ State: c.State.Fork(),
+ sigFPState: c.copySigFPState(),
}
}
@@ -125,16 +140,17 @@ func (c *context64) SetStack(value uintptr) {
// TLS returns the current TLS pointer.
func (c *context64) TLS() uintptr {
- // TODO(gvisor.dev/issue/1238): TLS is not supported.
- // MRS_TPIDR_EL0
- return 0
+ return uintptr(c.TPValue)
}
// SetTLS sets the current TLS pointer. Returns false if value is invalid.
func (c *context64) SetTLS(value uintptr) bool {
- // TODO(gvisor.dev/issue/1238): TLS is not supported.
- // MSR_TPIDR_EL0
- return false
+ if value >= uintptr(maxAddr64) {
+ return false
+ }
+
+ c.TPValue = uint64(value)
+ return true
}
// SetOldRSeqInterruptedIP implements Context.SetOldRSeqInterruptedIP.
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index e35c9214a..aa31169e0 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 i386
+// +build amd64 386
package arch
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index 88b40a9d1..7fc4c0473 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 i386
+// +build amd64 386
package arch
diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go
index 04ac283c6..3edf40764 100644
--- a/pkg/sentry/arch/arch_x86_impl.go
+++ b/pkg/sentry/arch/arch_x86_impl.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 i386
+// +build amd64 386
package arch
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
index 4f4cc46a8..0c1db4b13 100644
--- a/pkg/sentry/arch/signal_arm64.go
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -30,14 +30,29 @@ type SignalContext64 struct {
Sp uint64
Pc uint64
Pstate uint64
- _pad [8]byte // __attribute__((__aligned__(16)))
- Reserved [4096]uint8
+ _pad [8]byte // __attribute__((__aligned__(16)))
+ Fpsimd64 FpsimdContext // size = 528
+ Reserved [3568]uint8
+}
+
+type aarch64Ctx struct {
+ Magic uint32
+ Size uint32
+}
+
+// FpsimdContext is equivalent to struct fpsimd_context on arm64
+// (arch/arm64/include/uapi/asm/sigcontext.h).
+type FpsimdContext struct {
+ Head aarch64Ctx
+ Fpsr uint32
+ Fpcr uint32
+ Vregs [64]uint64 // actually [32]uint128
}
// UContext64 is equivalent to ucontext on arm64(arch/arm64/include/uapi/asm/ucontext.h).
type UContext64 struct {
Flags uint64
- Link *UContext64
+ Link uint64
Stack SignalStack
Sigset linux.SignalSet
// glibc uses a 1024-bit sigset_t
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
index 1a6056171..e58f055c7 100644
--- a/pkg/sentry/arch/signal_stack.go
+++ b/pkg/sentry/arch/signal_stack.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64 arm64
+// +build 386 amd64 arm64
package arch
diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go
index 09bceabc9..1108fa0bd 100644
--- a/pkg/sentry/arch/stack.go
+++ b/pkg/sentry/arch/stack.go
@@ -97,7 +97,6 @@ func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) {
if c < 0 {
return 0, fmt.Errorf("bad binary.Size for %T", v)
}
- // TODO(b/38173783): Use a real context.Context.
n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{})
if err != nil || c != n {
return 0, err
@@ -121,11 +120,9 @@ func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) {
var err error
if isVaddr {
value := s.Arch.Native(uintptr(0))
- // TODO(b/38173783): Use a real context.Context.
n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{})
*vaddr = usermem.Addr(s.Arch.Value(value))
} else {
- // TODO(b/38173783): Use a real context.Context.
n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{})
}
if err != nil {
diff --git a/pkg/sentry/arch/syscalls_amd64.go b/pkg/sentry/arch/syscalls_amd64.go
index 8b4f23007..3859f41ee 100644
--- a/pkg/sentry/arch/syscalls_amd64.go
+++ b/pkg/sentry/arch/syscalls_amd64.go
@@ -18,6 +18,13 @@ package arch
const restartSyscallNr = uintptr(219)
+// SyscallSaveOrig save the value of the register which is clobbered in
+// syscall handler(doSyscall()).
+//
+// Noop on x86.
+func (c *context64) SyscallSaveOrig() {
+}
+
// SyscallNo returns the syscall number according to the 64-bit convention.
func (c *context64) SyscallNo() uintptr {
return uintptr(c.Regs.Orig_rax)
diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go
index 00d5ef461..92d062513 100644
--- a/pkg/sentry/arch/syscalls_arm64.go
+++ b/pkg/sentry/arch/syscalls_arm64.go
@@ -18,6 +18,17 @@ package arch
const restartSyscallNr = uintptr(128)
+// SyscallSaveOrig save the value of the register R0 which is clobbered in
+// syscall handler(doSyscall()).
+//
+// In linux, at the entry of the syscall handler(el0_svc_common()), value of R0
+// is saved to the pt_regs.orig_x0 in kernel code. But currently, the orig_x0
+// was not accessible to the user space application, so we have to do the same
+// operation in the sentry code to save the R0 value into the App context.
+func (c *context64) SyscallSaveOrig() {
+ c.OrigR0 = c.Regs.Regs[0]
+}
+
// SyscallNo returns the syscall number according to the 64-bit convention.
func (c *context64) SyscallNo() uintptr {
return uintptr(c.Regs.Regs[8])
@@ -40,7 +51,7 @@ func (c *context64) SyscallNo() uintptr {
// R30: the link register.
func (c *context64) SyscallArgs() SyscallArguments {
return SyscallArguments{
- SyscallArgument{Value: uintptr(c.Regs.Regs[0])},
+ SyscallArgument{Value: uintptr(c.OrigR0)},
SyscallArgument{Value: uintptr(c.Regs.Regs[1])},
SyscallArgument{Value: uintptr(c.Regs.Regs[2])},
SyscallArgument{Value: uintptr(c.Regs.Regs[3])},
@@ -50,13 +61,21 @@ func (c *context64) SyscallArgs() SyscallArguments {
}
// RestartSyscall implements Context.RestartSyscall.
+// Prepare for system call restart, OrigR0 will be restored to R0.
+// Please see the linux code as reference:
+// arch/arm64/kernel/signal.c:do_signal()
func (c *context64) RestartSyscall() {
c.Regs.Pc -= SyscallWidth
- c.Regs.Regs[8] = uint64(restartSyscallNr)
+ // R0 will be backed up into OrigR0 when entering doSyscall().
+ // Please see the linux code as reference:
+ // arch/arm64/kernel/syscall.c:el0_svc_common().
+ // Here we restore it back.
+ c.Regs.Regs[0] = uint64(c.OrigR0)
}
// RestartSyscallWithRestartBlock implements Context.RestartSyscallWithRestartBlock.
func (c *context64) RestartSyscallWithRestartBlock() {
c.Regs.Pc -= SyscallWidth
+ c.Regs.Regs[0] = uint64(c.OrigR0)
c.Regs.Regs[8] = uint64(restartSyscallNr)
}
diff --git a/pkg/sentry/contexttest/contexttest.go b/pkg/sentry/contexttest/contexttest.go
index 031fc64ec..8e5658c7a 100644
--- a/pkg/sentry/contexttest/contexttest.go
+++ b/pkg/sentry/contexttest/contexttest.go
@@ -97,7 +97,7 @@ type hostClock struct {
}
// Now implements ktime.Clock.Now.
-func (hostClock) Now() ktime.Time {
+func (*hostClock) Now() ktime.Time {
return ktime.FromNanoseconds(time.Now().UnixNano())
}
@@ -127,7 +127,7 @@ func (t *TestContext) Value(key interface{}) interface{} {
case uniqueid.CtxInotifyCookie:
return atomic.AddUint32(&lastInotifyCookie, 1)
case ktime.CtxRealtimeClock:
- return hostClock{}
+ return &hostClock{}
default:
if val, ok := t.otherValues[key]; ok {
return val
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 151808911..663e51989 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -117,9 +117,9 @@ func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error {
return nil
}
-// Goroutine is an RPC stub which dumps out the stack trace for all running
-// goroutines.
-func (p *Profile) Goroutine(o *ProfileOpts, _ *struct{}) error {
+// GoroutineProfile is an RPC stub which dumps out the stack trace for all
+// running goroutines.
+func (p *Profile) GoroutineProfile(o *ProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
return errNoOutput
}
@@ -131,6 +131,34 @@ func (p *Profile) Goroutine(o *ProfileOpts, _ *struct{}) error {
return nil
}
+// BlockProfile is an RPC stub which dumps out the stack trace that led to
+// blocking on synchronization primitives.
+func (p *Profile) BlockProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ if err := pprof.Lookup("block").WriteTo(output, 0); err != nil {
+ return err
+ }
+ return nil
+}
+
+// MutexProfile is an RPC stub which dumps out the stack trace of holders of
+// contended mutexes.
+func (p *Profile) MutexProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ if err := pprof.Lookup("mutex").WriteTo(output, 0); err != nil {
+ return err
+ }
+ return nil
+}
+
// StartTrace is an RPC stub which starts collection of an execution trace.
func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index 5457ba5e7..b51fb3959 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -224,8 +224,6 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
}
}
- mounter := fs.FileOwnerFromContext(ctx)
-
// TODO(gvisor.dev/issue/1623): Use host FD when supported in VFS2.
var ttyFile *fs.File
for appFD, hostFile := range args.FilePayload.Files {
@@ -235,7 +233,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
// Import the file as a host TTY file.
if ttyFile == nil {
var err error
- appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), mounter, true /* isTTY */)
+ appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), true /* isTTY */)
if err != nil {
return nil, 0, nil, err
}
@@ -254,7 +252,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
} else {
// Import the file as a regular host file.
var err error
- appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), mounter, false /* isTTY */)
+ appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), false /* isTTY */)
if err != nil {
return nil, 0, nil, err
}
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index b060a12ff..ab1424c95 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -222,8 +222,8 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
}
childUpper, err := parentUpper.Lookup(ctx, next.name)
if err != nil {
- log.Warningf("copy up failed to lookup directory: %v", err)
- cleanupUpper(ctx, parentUpper, next.name)
+ werr := fmt.Errorf("copy up failed to lookup directory: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
defer childUpper.DecRef()
@@ -242,8 +242,8 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
}
childUpper, err := parentUpper.Lookup(ctx, next.name)
if err != nil {
- log.Warningf("copy up failed to lookup symlink: %v", err)
- cleanupUpper(ctx, parentUpper, next.name)
+ werr := fmt.Errorf("copy up failed to lookup symlink: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
defer childUpper.DecRef()
@@ -256,23 +256,23 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
// Bring file attributes up to date. This does not include size, which will be
// brought up to date with copyContentsLocked.
if err := copyAttributesLocked(ctx, childUpperInode, next.Inode.overlay.lower); err != nil {
- log.Warningf("copy up failed to copy up attributes: %v", err)
- cleanupUpper(ctx, parentUpper, next.name)
+ werr := fmt.Errorf("copy up failed to copy up attributes: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
// Copy the entire file.
if err := copyContentsLocked(ctx, childUpperInode, next.Inode.overlay.lower, attrs.Size); err != nil {
- log.Warningf("copy up failed to copy up contents: %v", err)
- cleanupUpper(ctx, parentUpper, next.name)
+ werr := fmt.Errorf("copy up failed to copy up contents: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
lowerMappable := next.Inode.overlay.lower.Mappable()
upperMappable := childUpperInode.Mappable()
if lowerMappable != nil && upperMappable == nil {
- log.Warningf("copy up failed: cannot ensure memory mapping coherence")
- cleanupUpper(ctx, parentUpper, next.name)
+ werr := fmt.Errorf("copy up failed: cannot ensure memory mapping coherence")
+ cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
@@ -324,12 +324,14 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
return nil
}
-// cleanupUpper removes name from parent, and panics if it is unsuccessful.
-func cleanupUpper(ctx context.Context, parent *Inode, name string) {
+// cleanupUpper is called when copy-up fails. It logs the copy-up error and
+// attempts to remove name from parent. If that fails, then it panics.
+func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr error) {
+ log.Warningf(copyUpErr.Error())
if err := parent.InodeOperations.Remove(ctx, parent, name); err != nil {
// Unfortunately we don't have much choice. We shouldn't
// willingly give the caller access to a nonsense filesystem.
- panic(fmt.Sprintf("overlay filesystem is in an inconsistent state: failed to remove %q from upper filesystem: %v", name, err))
+ panic(fmt.Sprintf("overlay filesystem is in an inconsistent state: copyUp got error: %v; then cleanup failed to remove %q from upper filesystem: %v.", copyUpErr, name, err))
}
}
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 9b6bb26d0..9379a4d7b 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -26,6 +26,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index 7e66c29b0..acbd401a0 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -124,10 +125,12 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
"ptmx": newSymlink(ctx, "pts/ptmx", msrc),
"tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor),
+ }
- "net": newDirectory(ctx, map[string]*fs.Inode{
+ if isNetTunSupported(inet.StackFromContext(ctx)) {
+ contents["net"] = newDirectory(ctx, map[string]*fs.Inode{
"tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor),
- }, msrc),
+ }, msrc)
}
iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
index 755644488..dc7ad075a 100644
--- a/pkg/sentry/fs/dev/net_tun.go
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/syserror"
@@ -168,3 +169,9 @@ func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.Eve
func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
fops.device.EventUnregister(e)
}
+
+// isNetTunSupported returns whether /dev/net/tun device is supported for s.
+func isNetTunSupported(s inet.Stack) bool {
+ _, ok := s.(*netstack.Stack)
+ return ok
+}
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index e0b32e1c1..65be12175 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -17,7 +17,6 @@ package fs
import (
"fmt"
"path"
- "sort"
"sync/atomic"
"syscall"
@@ -121,9 +120,6 @@ type Dirent struct {
// deleted may be set atomically when removed.
deleted int32
- // frozen indicates this entry can't walk to unknown nodes.
- frozen bool
-
// mounted is true if Dirent is a mount point, similar to include/linux/dcache.h:DCACHE_MOUNTED.
mounted bool
@@ -253,8 +249,7 @@ func (d *Dirent) IsNegative() bool {
return d.Inode == nil
}
-// hashChild will hash child into the children list of its new parent d, carrying over
-// any "frozen" state from d.
+// hashChild will hash child into the children list of its new parent d.
//
// Returns (*WeakRef, true) if hashing child caused a Dirent to be unhashed. The caller must
// validate the returned unhashed weak reference. Common cases:
@@ -282,9 +277,6 @@ func (d *Dirent) hashChild(child *Dirent) (*refs.WeakRef, bool) {
d.IncRef()
}
- // Carry over parent's frozen state.
- child.frozen = d.frozen
-
return d.hashChildParentSet(child)
}
@@ -320,9 +312,9 @@ func (d *Dirent) SyncAll(ctx context.Context) {
// There is nothing to sync for a read-only filesystem.
if !d.Inode.MountSource.Flags.ReadOnly {
- // FIXME(b/34856369): This should be a mount traversal, not a
- // Dirent traversal, because some Inodes that need to be synced
- // may no longer be reachable by name (after sys_unlink).
+ // NOTE(b/34856369): This should be a mount traversal, not a Dirent
+ // traversal, because some Inodes that need to be synced may no longer
+ // be reachable by name (after sys_unlink).
//
// Write out metadata, dirty page cached pages, and sync disk/remote
// caches.
@@ -400,38 +392,6 @@ func (d *Dirent) MountRoot() *Dirent {
return mountRoot
}
-// Freeze prevents this dirent from walking to more nodes. Freeze is applied
-// recursively to all children.
-//
-// If this particular Dirent represents a Virtual node, then Walks and Creates
-// may proceed as before.
-//
-// Freeze can only be called before the application starts running, otherwise
-// the root it might be out of sync with the application root if modified by
-// sys_chroot.
-func (d *Dirent) Freeze() {
- d.mu.Lock()
- defer d.mu.Unlock()
- if d.frozen {
- // Already frozen.
- return
- }
- d.frozen = true
-
- // Take a reference when freezing.
- for _, w := range d.children {
- if child := w.Get(); child != nil {
- // NOTE: We would normally drop the reference here. But
- // instead we're hanging on to it.
- ch := child.(*Dirent)
- ch.Freeze()
- }
- }
-
- // Drop all expired weak references.
- d.flush()
-}
-
// descendantOf returns true if the receiver dirent is equal to, or a
// descendant of, the argument dirent.
//
@@ -524,11 +484,6 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl
w.Drop()
}
- // Are we allowed to do the lookup?
- if d.frozen && !d.Inode.IsVirtual() {
- return nil, syscall.ENOENT
- }
-
// Slow path: load the InodeOperations into memory. Since this is a hot path and the lookup may be
// expensive, if possible release the lock and re-acquire it.
if walkMayUnlock {
@@ -659,11 +614,6 @@ func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags Fi
return nil, syscall.EEXIST
}
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return nil, syscall.ENOENT
- }
-
// Try the create. We need to trust the file system to return EEXIST (or something
// that will translate to EEXIST) if name already exists.
file, err := d.Inode.Create(ctx, d, name, flags, perms)
@@ -727,11 +677,6 @@ func (d *Dirent) genericCreate(ctx context.Context, root *Dirent, name string, c
return syscall.EEXIST
}
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Remove any negative Dirent. We've already asserted above with d.exists
// that the only thing remaining here can be a negative Dirent.
if w, ok := d.children[name]; ok {
@@ -862,49 +807,6 @@ func (d *Dirent) GetDotAttrs(root *Dirent) (DentAttr, DentAttr) {
return dot, dot
}
-// readdirFrozen returns readdir results based solely on the frozen children.
-func (d *Dirent) readdirFrozen(root *Dirent, offset int64, dirCtx *DirCtx) (int64, error) {
- // Collect attrs for "." and "..".
- attrs := make(map[string]DentAttr)
- names := []string{".", ".."}
- attrs["."], attrs[".."] = d.GetDotAttrs(root)
-
- // Get info from all children.
- d.mu.Lock()
- defer d.mu.Unlock()
- for name, w := range d.children {
- if child := w.Get(); child != nil {
- defer child.DecRef()
-
- // Skip negative children.
- if child.(*Dirent).IsNegative() {
- continue
- }
-
- sattr := child.(*Dirent).Inode.StableAttr
- attrs[name] = DentAttr{
- Type: sattr.Type,
- InodeID: sattr.InodeID,
- }
- names = append(names, name)
- }
- }
-
- sort.Strings(names)
-
- if int(offset) >= len(names) {
- return offset, nil
- }
- names = names[int(offset):]
- for _, name := range names {
- if err := dirCtx.DirEmit(name, attrs[name]); err != nil {
- return offset, err
- }
- offset++
- }
- return offset, nil
-}
-
// DirIterator is an open directory containing directory entries that can be read.
type DirIterator interface {
// IterateDir emits directory entries by calling dirCtx.EmitDir, beginning
@@ -964,10 +866,6 @@ func direntReaddir(ctx context.Context, d *Dirent, it DirIterator, root *Dirent,
return offset, nil
}
- if d.frozen {
- return d.readdirFrozen(root, offset, dirCtx)
- }
-
// Collect attrs for "." and "..".
dot, dotdot := d.GetDotAttrs(root)
@@ -1068,11 +966,6 @@ func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err
return nil, syserror.EINVAL
}
- // Are we frozen?
- if d.parent.frozen && !d.parent.Inode.IsVirtual() {
- return nil, syserror.ENOENT
- }
-
// Dirent that'll replace d.
//
// Note that NewDirent returns with one reference taken; the reference
@@ -1101,11 +994,6 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
return syserror.ENOENT
}
- // Are we frozen?
- if d.parent.frozen && !d.parent.Inode.IsVirtual() {
- return syserror.ENOENT
- }
-
// Remount our former child in its place.
//
// As replacement used to be our child, it must already have the right
@@ -1135,11 +1023,6 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath
unlock := d.lockDirectory()
defer unlock()
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Try to walk to the node.
child, err := d.walk(ctx, root, name, false /* may unlock */)
if err != nil {
@@ -1201,11 +1084,6 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string)
unlock := d.lockDirectory()
defer unlock()
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Check for dots.
if name == "." {
// Rejected as the last component by rmdir(2).
@@ -1519,15 +1397,6 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
return err
}
- // Are we frozen?
- // TODO(jamieliu): Is this the right errno?
- if oldParent.frozen && !oldParent.Inode.IsVirtual() {
- return syscall.ENOENT
- }
- if newParent.frozen && !newParent.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Do we have general permission to remove from oldParent and
// create/replace in newParent?
if err := oldParent.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go
index 25514ace4..33de32c69 100644
--- a/pkg/sentry/fs/dirent_cache.go
+++ b/pkg/sentry/fs/dirent_cache.go
@@ -101,8 +101,6 @@ func (c *DirentCache) remove(d *Dirent) {
panic(fmt.Sprintf("trying to remove %v, which is not in the dirent cache", d))
}
c.list.Remove(d)
- d.SetPrev(nil)
- d.SetNext(nil)
d.DecRef()
c.currentSize--
if c.limit != nil {
diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go
index 5aff0cc95..a0082ecca 100644
--- a/pkg/sentry/fs/fdpipe/pipe_test.go
+++ b/pkg/sentry/fs/fdpipe/pipe_test.go
@@ -119,7 +119,7 @@ func TestNewPipe(t *testing.T) {
continue
}
if flags := p.flags; test.flags != flags {
- t.Errorf("%s: got file flags %s, want %s", test.desc, flags, test.flags)
+ t.Errorf("%s: got file flags %v, want %v", test.desc, flags, test.flags)
continue
}
if len(test.readAheadBuffer) != len(p.readAheadBuffer) {
@@ -136,7 +136,7 @@ func TestNewPipe(t *testing.T) {
continue
}
if !fdnotifier.HasFD(int32(f.FD())) {
- t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD)
+ t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD())
}
}
}
diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go
index a76d87e3a..1971cc680 100644
--- a/pkg/sentry/fs/file_overlay_test.go
+++ b/pkg/sentry/fs/file_overlay_test.go
@@ -175,90 +175,6 @@ func TestReaddirRevalidation(t *testing.T) {
}
}
-// TestReaddirOverlayFrozen tests that calling Readdir on an overlay file with
-// a frozen dirent tree does not make Readdir calls to the underlying files.
-// This is a regression test for b/114808269.
-func TestReaddirOverlayFrozen(t *testing.T) {
- ctx := contexttest.Context(t)
-
- // Create an overlay with two directories, each with two files.
- upper := newTestRamfsDir(ctx, []dirContent{{name: "upper-file1"}, {name: "upper-file2"}}, nil)
- lower := newTestRamfsDir(ctx, []dirContent{{name: "lower-file1"}, {name: "lower-file2"}}, nil)
- overlayInode := fs.NewTestOverlayDir(ctx, upper, lower, false)
-
- // Set that overlay as the root.
- root := fs.NewDirent(ctx, overlayInode, "root")
- ctx = &rootContext{
- Context: ctx,
- root: root,
- }
-
- // Check that calling Readdir on the root now returns all 4 files (2
- // from each layer in the overlay).
- rootFile, err := root.Inode.GetFile(ctx, root, fs.FileFlags{Read: true})
- if err != nil {
- t.Fatalf("root.Inode.GetFile failed: %v", err)
- }
- defer rootFile.DecRef()
- ser := &fs.CollectEntriesSerializer{}
- if err := rootFile.Readdir(ctx, ser); err != nil {
- t.Fatalf("rootFile.Readdir failed: %v", err)
- }
- if got, want := ser.Order, []string{".", "..", "lower-file1", "lower-file2", "upper-file1", "upper-file2"}; !reflect.DeepEqual(got, want) {
- t.Errorf("Readdir got names %v, want %v", got, want)
- }
-
- // Readdir should have been called on upper and lower.
- upperDir := upper.InodeOperations.(*dir)
- lowerDir := lower.InodeOperations.(*dir)
- if !upperDir.ReaddirCalled {
- t.Errorf("upperDir.ReaddirCalled got %v, want true", upperDir.ReaddirCalled)
- }
- if !lowerDir.ReaddirCalled {
- t.Errorf("lowerDir.ReaddirCalled got %v, want true", lowerDir.ReaddirCalled)
- }
-
- // Reset.
- upperDir.ReaddirCalled = false
- lowerDir.ReaddirCalled = false
-
- // Take references on "upper-file1" and "lower-file1", pinning them in
- // the dirent tree.
- for _, name := range []string{"upper-file1", "lower-file1"} {
- if _, err := root.Walk(ctx, root, name); err != nil {
- t.Fatalf("root.Walk(%q) failed: %v", name, err)
- }
- // Don't drop a reference on the returned dirent so that it
- // will stay in the tree.
- }
-
- // Freeze the dirent tree.
- root.Freeze()
-
- // Seek back to the beginning of the file.
- if _, err := rootFile.Seek(ctx, fs.SeekSet, 0); err != nil {
- t.Fatalf("error seeking to beginning of directory: %v", err)
- }
-
- // Calling Readdir on the root now will return only the pinned
- // children.
- ser = &fs.CollectEntriesSerializer{}
- if err := rootFile.Readdir(ctx, ser); err != nil {
- t.Fatalf("rootFile.Readdir failed: %v", err)
- }
- if got, want := ser.Order, []string{".", "..", "lower-file1", "upper-file1"}; !reflect.DeepEqual(got, want) {
- t.Errorf("Readdir got names %v, want %v", got, want)
- }
-
- // Readdir should NOT have been called on upper or lower.
- if upperDir.ReaddirCalled {
- t.Errorf("upperDir.ReaddirCalled got %v, want false", upperDir.ReaddirCalled)
- }
- if lowerDir.ReaddirCalled {
- t.Errorf("lowerDir.ReaddirCalled got %v, want false", lowerDir.ReaddirCalled)
- }
-}
-
type rootContext struct {
context.Context
root *fs.Dirent
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
index daecc4ffe..1922ff08c 100644
--- a/pkg/sentry/fs/fsutil/inode.go
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -259,8 +259,8 @@ func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode, ui
// RemoveXattr implements fs.InodeOperations.RemoveXattr.
func (i *InodeSimpleExtendedAttributes) RemoveXattr(_ context.Context, _ *fs.Inode, name string) error {
- i.mu.RLock()
- defer i.mu.RUnlock()
+ i.mu.Lock()
+ defer i.mu.Unlock()
if _, ok := i.xattrs[name]; ok {
delete(i.xattrs, name)
return nil
diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go
index ff96b28ba..edd6576aa 100644
--- a/pkg/sentry/fs/gofer/file_state.go
+++ b/pkg/sentry/fs/gofer/file_state.go
@@ -34,7 +34,6 @@ func (f *fileOperations) afterLoad() {
flags := f.flags
flags.Truncate = false
- // TODO(b/38173783): Context is not plumbed to save/restore.
f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), flags, f.inodeOperations.cachingInodeOps)
if err != nil {
return fmt.Errorf("failed to re-open handle: %v", err)
diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go
index 9f7c3e89f..fc14249be 100644
--- a/pkg/sentry/fs/gofer/handles.go
+++ b/pkg/sentry/fs/gofer/handles.go
@@ -57,7 +57,6 @@ func (h *handles) DecRef() {
}
}
}
- // FIXME(b/38173783): Context is not plumbed here.
if err := h.File.close(context.Background()); err != nil {
log.Warningf("error closing p9 file: %v", err)
}
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 1c934981b..a016c896e 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -273,7 +273,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle
// operations on the old will see the new data. Then, make the new handle take
// ownereship of the old FD and mark the old readHandle to not close the FD
// when done.
- if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), 0); err != nil {
+ if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), syscall.O_CLOEXEC); err != nil {
return err
}
@@ -710,13 +710,10 @@ func init() {
}
// AddLink implements InodeOperations.AddLink, but is currently a noop.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (*inodeOperations) AddLink() {}
// DropLink implements InodeOperations.DropLink, but is currently a noop.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (*inodeOperations) DropLink() {}
// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go
index 238f7804c..a3402e343 100644
--- a/pkg/sentry/fs/gofer/inode_state.go
+++ b/pkg/sentry/fs/gofer/inode_state.go
@@ -123,7 +123,6 @@ func (i *inodeFileState) afterLoad() {
// beforeSave.
return fmt.Errorf("failed to find path for inode number %d. Device %s contains %s", i.sattr.InodeID, i.s.connID, fs.InodeMappings(i.s.inodeMappings))
}
- // TODO(b/38173783): Context is not plumbed to save/restore.
ctx := &dummyClockContext{context.Background()}
_, i.file, err = i.s.attach.walk(ctx, splitAbsolutePath(name))
diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go
index 111da59f9..2d398b753 100644
--- a/pkg/sentry/fs/gofer/session_state.go
+++ b/pkg/sentry/fs/gofer/session_state.go
@@ -104,7 +104,6 @@ func (s *session) afterLoad() {
// If private unix sockets are enabled, create and fill the session's endpoint
// maps.
if opts.privateunixsocket {
- // TODO(b/38173783): Context is not plumbed to save/restore.
ctx := &dummyClockContext{context.Background()}
if err = s.restoreEndpointMaps(ctx); err != nil {
diff --git a/pkg/sentry/fs/gofer/util.go b/pkg/sentry/fs/gofer/util.go
index 2d8d3a2ea..47a6c69bf 100644
--- a/pkg/sentry/fs/gofer/util.go
+++ b/pkg/sentry/fs/gofer/util.go
@@ -20,17 +20,29 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
)
func utimes(ctx context.Context, file contextFile, ts fs.TimeSpec) error {
if ts.ATimeOmit && ts.MTimeOmit {
return nil
}
+
+ // Replace requests to use the "system time" with the current time to
+ // ensure that timestamps remain consistent with the remote
+ // filesystem.
+ now := ktime.NowFromContext(ctx)
+ if ts.ATimeSetSystemTime {
+ ts.ATime = now
+ }
+ if ts.MTimeSetSystemTime {
+ ts.MTime = now
+ }
mask := p9.SetAttrMask{
ATime: !ts.ATimeOmit,
- ATimeNotSystemTime: !ts.ATimeSetSystemTime,
+ ATimeNotSystemTime: true,
MTime: !ts.MTimeOmit,
- MTimeNotSystemTime: !ts.MTimeSetSystemTime,
+ MTimeNotSystemTime: true,
}
as, ans := ts.ATime.Unix()
ms, mns := ts.MTime.Unix()
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 21003ea45..aabce6cc9 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -10,7 +10,7 @@ go_library(
"descriptor_state.go",
"device.go",
"file.go",
- "fs.go",
+ "host.go",
"inode.go",
"inode_state.go",
"ioctl_unsafe.go",
@@ -62,18 +62,15 @@ go_test(
size = "small",
srcs = [
"descriptor_test.go",
- "fs_test.go",
"inode_test.go",
"socket_test.go",
"wait_test.go",
],
library = ":host",
deps = [
- "//pkg/context",
"//pkg/fd",
"//pkg/fdnotifier",
"//pkg/sentry/contexttest",
- "//pkg/sentry/fs",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go
index 1658979fc..52c0504b6 100644
--- a/pkg/sentry/fs/host/control.go
+++ b/pkg/sentry/fs/host/control.go
@@ -32,6 +32,8 @@ func newSCMRights(fds []int) control.SCMRights {
}
// Files implements control.SCMRights.Files.
+//
+// TODO(gvisor.dev/issue/2017): Port to VFS2.
func (c *scmRights) Files(ctx context.Context, max int) (control.RightsFiles, bool) {
n := max
var trunc bool
@@ -76,7 +78,7 @@ func fdsToFiles(ctx context.Context, fds []int) []*fs.File {
}
// Create the file backed by hostFD.
- file, err := NewFile(ctx, fd, fs.FileOwnerFromContext(ctx))
+ file, err := NewFile(ctx, fd)
if err != nil {
ctx.Warningf("Error creating file from host FD: %v", err)
break
diff --git a/pkg/sentry/fs/host/descriptor.go b/pkg/sentry/fs/host/descriptor.go
index 2a4d1b291..cfdce6a74 100644
--- a/pkg/sentry/fs/host/descriptor.go
+++ b/pkg/sentry/fs/host/descriptor.go
@@ -16,7 +16,6 @@ package host
import (
"fmt"
- "path"
"syscall"
"gvisor.dev/gvisor/pkg/fdnotifier"
@@ -28,12 +27,9 @@ import (
//
// +stateify savable
type descriptor struct {
- // donated is true if the host fd was donated by another process.
- donated bool
-
// If origFD >= 0, it is the host fd that this file was originally created
// from, which must be available at time of restore. The FD can be closed
- // after descriptor is created. Only set if donated is true.
+ // after descriptor is created.
origFD int
// wouldBlock is true if value (below) points to a file that can
@@ -41,15 +37,13 @@ type descriptor struct {
wouldBlock bool
// value is the wrapped host fd. It is never saved or restored
- // directly. How it is restored depends on whether it was
- // donated and the fs.MountSource it was originally
- // opened/created from.
+ // directly.
value int `state:"nosave"`
}
// newDescriptor returns a wrapped host file descriptor. On success,
// the descriptor is registered for event notifications with queue.
-func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue *waiter.Queue) (*descriptor, error) {
+func newDescriptor(fd int, saveable bool, wouldBlock bool, queue *waiter.Queue) (*descriptor, error) {
ownedFD := fd
origFD := -1
if saveable {
@@ -69,7 +63,6 @@ func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue *
}
}
return &descriptor{
- donated: donated,
origFD: origFD,
wouldBlock: wouldBlock,
value: ownedFD,
@@ -77,25 +70,11 @@ func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue *
}
// initAfterLoad initializes the value of the descriptor after Load.
-func (d *descriptor) initAfterLoad(mo *superOperations, id uint64, queue *waiter.Queue) error {
- if d.donated {
- var err error
- d.value, err = syscall.Dup(d.origFD)
- if err != nil {
- return fmt.Errorf("failed to dup restored fd %d: %v", d.origFD, err)
- }
- } else {
- name, ok := mo.inodeMappings[id]
- if !ok {
- return fmt.Errorf("failed to find path for inode number %d", id)
- }
- fullpath := path.Join(mo.root, name)
-
- var err error
- d.value, err = open(nil, fullpath)
- if err != nil {
- return fmt.Errorf("failed to open %q: %v", fullpath, err)
- }
+func (d *descriptor) initAfterLoad(id uint64, queue *waiter.Queue) error {
+ var err error
+ d.value, err = syscall.Dup(d.origFD)
+ if err != nil {
+ return fmt.Errorf("failed to dup restored fd %d: %v", d.origFD, err)
}
if d.wouldBlock {
if err := syscall.SetNonblock(d.value, true); err != nil {
diff --git a/pkg/sentry/fs/host/descriptor_state.go b/pkg/sentry/fs/host/descriptor_state.go
index 8167390a9..e880582ab 100644
--- a/pkg/sentry/fs/host/descriptor_state.go
+++ b/pkg/sentry/fs/host/descriptor_state.go
@@ -16,7 +16,7 @@ package host
// beforeSave is invoked by stateify.
func (d *descriptor) beforeSave() {
- if d.donated && d.origFD < 0 {
+ if d.origFD < 0 {
panic("donated file descriptor cannot be saved")
}
}
diff --git a/pkg/sentry/fs/host/descriptor_test.go b/pkg/sentry/fs/host/descriptor_test.go
index 4205981f5..d8e4605b6 100644
--- a/pkg/sentry/fs/host/descriptor_test.go
+++ b/pkg/sentry/fs/host/descriptor_test.go
@@ -47,10 +47,10 @@ func TestDescriptorRelease(t *testing.T) {
// FD ownership is transferred to the descritor.
queue := &waiter.Queue{}
- d, err := newDescriptor(fd, false /* donated*/, tc.saveable, tc.wouldBlock, queue)
+ d, err := newDescriptor(fd, tc.saveable, tc.wouldBlock, queue)
if err != nil {
syscall.Close(fd)
- t.Fatalf("newDescriptor(%d, %t, false, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err)
+ t.Fatalf("newDescriptor(%d, %t, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err)
}
if tc.saveable {
if d.origFD < 0 {
diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go
index e08f56d04..3e48b8b2c 100644
--- a/pkg/sentry/fs/host/file.go
+++ b/pkg/sentry/fs/host/file.go
@@ -60,8 +60,8 @@ var _ fs.FileOperations = (*fileOperations)(nil)
// The returned File cannot be saved, since there is no guarantee that the same
// FD will exist or represent the same file at time of restore. If such a
// guarantee does exist, use ImportFile instead.
-func NewFile(ctx context.Context, fd int, mounter fs.FileOwner) (*fs.File, error) {
- return newFileFromDonatedFD(ctx, fd, mounter, false, false)
+func NewFile(ctx context.Context, fd int) (*fs.File, error) {
+ return newFileFromDonatedFD(ctx, fd, false, false)
}
// ImportFile creates a new File backed by the provided host file descriptor.
@@ -71,13 +71,13 @@ func NewFile(ctx context.Context, fd int, mounter fs.FileOwner) (*fs.File, error
// If the returned file is saved, it will be restored by re-importing the FD
// originally passed to ImportFile. It is the restorer's responsibility to
// ensure that the FD represents the same file.
-func ImportFile(ctx context.Context, fd int, mounter fs.FileOwner, isTTY bool) (*fs.File, error) {
- return newFileFromDonatedFD(ctx, fd, mounter, true, isTTY)
+func ImportFile(ctx context.Context, fd int, isTTY bool) (*fs.File, error) {
+ return newFileFromDonatedFD(ctx, fd, true, isTTY)
}
// newFileFromDonatedFD returns an fs.File from a donated FD. If the FD is
// saveable, then saveable is true.
-func newFileFromDonatedFD(ctx context.Context, donated int, mounter fs.FileOwner, saveable, isTTY bool) (*fs.File, error) {
+func newFileFromDonatedFD(ctx context.Context, donated int, saveable, isTTY bool) (*fs.File, error) {
var s syscall.Stat_t
if err := syscall.Fstat(donated, &s); err != nil {
return nil, err
@@ -101,8 +101,8 @@ func newFileFromDonatedFD(ctx context.Context, donated int, mounter fs.FileOwner
})
return s, nil
default:
- msrc := newMountSource(ctx, "/", mounter, &Filesystem{}, fs.MountSourceFlags{}, false /* dontTranslateOwnership */)
- inode, err := newInode(ctx, msrc, donated, saveable, true /* donated */)
+ msrc := fs.NewNonCachingMountSource(ctx, &filesystem{}, fs.MountSourceFlags{})
+ inode, err := newInode(ctx, msrc, donated, saveable)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fs/host/fs.go b/pkg/sentry/fs/host/fs.go
deleted file mode 100644
index d3e8e3a36..000000000
--- a/pkg/sentry/fs/host/fs.go
+++ /dev/null
@@ -1,339 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package host implements an fs.Filesystem for files backed by host
-// file descriptors.
-package host
-
-import (
- "fmt"
- "path"
- "path/filepath"
- "strconv"
- "strings"
-
- "gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/fs"
-)
-
-// FilesystemName is the name under which Filesystem is registered.
-const FilesystemName = "whitelistfs"
-
-const (
- // whitelistKey is the mount option containing a comma-separated list
- // of host paths to whitelist.
- whitelistKey = "whitelist"
-
- // rootPathKey is the mount option containing the root path of the
- // mount.
- rootPathKey = "root"
-
- // dontTranslateOwnershipKey is the key to superOperations.dontTranslateOwnership.
- dontTranslateOwnershipKey = "dont_translate_ownership"
-)
-
-// maxTraversals determines link traversals in building the whitelist.
-const maxTraversals = 10
-
-// Filesystem is a pseudo file system that is only available during the setup
-// to lock down the configurations. This filesystem should only be mounted at root.
-//
-// Think twice before exposing this to applications.
-//
-// +stateify savable
-type Filesystem struct {
- // whitelist is a set of host paths to whitelist.
- paths []string
-}
-
-var _ fs.Filesystem = (*Filesystem)(nil)
-
-// Name is the identifier of this file system.
-func (*Filesystem) Name() string {
- return FilesystemName
-}
-
-// AllowUserMount prohibits users from using mount(2) with this file system.
-func (*Filesystem) AllowUserMount() bool {
- return false
-}
-
-// AllowUserList allows this filesystem to be listed in /proc/filesystems.
-func (*Filesystem) AllowUserList() bool {
- return true
-}
-
-// Flags returns that there is nothing special about this file system.
-func (*Filesystem) Flags() fs.FilesystemFlags {
- return 0
-}
-
-// Mount returns an fs.Inode exposing the host file system. It is intended to be locked
-// down in PreExec below.
-func (f *Filesystem) Mount(ctx context.Context, _ string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
- // Parse generic comma-separated key=value options.
- options := fs.GenericMountSourceOptions(data)
-
- // Grab the whitelist if one was specified.
- // TODO(edahlgren/mpratt/hzy): require another option "testonly" in order to allow
- // no whitelist.
- if wl, ok := options[whitelistKey]; ok {
- f.paths = strings.Split(wl, "|")
- delete(options, whitelistKey)
- }
-
- // If the rootPath was set, use it. Othewise default to the root of the
- // host fs.
- rootPath := "/"
- if rp, ok := options[rootPathKey]; ok {
- rootPath = rp
- delete(options, rootPathKey)
-
- // We must relativize the whitelisted paths to the new root.
- for i, p := range f.paths {
- rel, err := filepath.Rel(rootPath, p)
- if err != nil {
- return nil, fmt.Errorf("whitelist path %q must be a child of root path %q", p, rootPath)
- }
- f.paths[i] = path.Join("/", rel)
- }
- }
- fd, err := open(nil, rootPath)
- if err != nil {
- return nil, fmt.Errorf("failed to find root: %v", err)
- }
-
- var dontTranslateOwnership bool
- if v, ok := options[dontTranslateOwnershipKey]; ok {
- b, err := strconv.ParseBool(v)
- if err != nil {
- return nil, fmt.Errorf("invalid value for %q: %v", dontTranslateOwnershipKey, err)
- }
- dontTranslateOwnership = b
- delete(options, dontTranslateOwnershipKey)
- }
-
- // Fail if the caller passed us more options than we know about.
- if len(options) > 0 {
- return nil, fmt.Errorf("unsupported mount options: %v", options)
- }
-
- // The mounting EUID/EGID will be cached by this file system. This will
- // be used to assign ownership to files that we own.
- owner := fs.FileOwnerFromContext(ctx)
-
- // Construct the host file system mount and inode.
- msrc := newMountSource(ctx, rootPath, owner, f, flags, dontTranslateOwnership)
- return newInode(ctx, msrc, fd, false /* saveable */, false /* donated */)
-}
-
-// InstallWhitelist locks down the MountNamespace to only the currently installed
-// Dirents and the given paths.
-func (f *Filesystem) InstallWhitelist(ctx context.Context, m *fs.MountNamespace) error {
- return installWhitelist(ctx, m, f.paths)
-}
-
-func installWhitelist(ctx context.Context, m *fs.MountNamespace, paths []string) error {
- if len(paths) == 0 || (len(paths) == 1 && paths[0] == "") {
- // Warning will be logged during filter installation if the empty
- // whitelist matters (allows for host file access).
- return nil
- }
-
- // Done tracks entries already added.
- done := make(map[string]bool)
- root := m.Root()
- defer root.DecRef()
-
- for i := 0; i < len(paths); i++ {
- // Make sure the path is absolute. This is a sanity check.
- if !path.IsAbs(paths[i]) {
- return fmt.Errorf("path %q is not absolute", paths[i])
- }
-
- // We need to add all the intermediate paths, in case one of
- // them is a symlink that needs to be resolved.
- for j := 1; j <= len(paths[i]); j++ {
- if j < len(paths[i]) && paths[i][j] != '/' {
- continue
- }
- current := paths[i][:j]
-
- // Lookup the given component in the tree.
- remainingTraversals := uint(maxTraversals)
- d, err := m.FindLink(ctx, root, nil, current, &remainingTraversals)
- if err != nil {
- log.Warningf("populate failed for %q: %v", current, err)
- continue
- }
-
- // It's critical that this DecRef happens after the
- // freeze below. This ensures that the dentry is in
- // place to be frozen. Otherwise, we freeze without
- // these entries.
- defer d.DecRef()
-
- // Expand the last component if necessary.
- if current == paths[i] {
- // Is it a directory or symlink?
- sattr := d.Inode.StableAttr
- if fs.IsDir(sattr) {
- for name := range childDentAttrs(ctx, d) {
- paths = append(paths, path.Join(current, name))
- }
- }
- if fs.IsSymlink(sattr) {
- // Only expand symlinks once. The
- // folder structure may contain
- // recursive symlinks and we don't want
- // to end up infinitely expanding this
- // symlink. This is safe because this
- // is the last component. If a later
- // path wants to symlink something
- // beneath this symlink that will still
- // be handled by the FindLink above.
- if done[current] {
- continue
- }
-
- s, err := d.Inode.Readlink(ctx)
- if err != nil {
- log.Warningf("readlink failed for %q: %v", current, err)
- continue
- }
- if path.IsAbs(s) {
- paths = append(paths, s)
- } else {
- target := path.Join(path.Dir(current), s)
- paths = append(paths, target)
- }
- }
- }
-
- // Only report this one once even though we may look
- // it up more than once. If we whitelist /a/b,/a then
- // /a will be "done" when it is looked up for /a/b,
- // however we still need to expand all of its contents
- // when whitelisting /a.
- if !done[current] {
- log.Debugf("whitelisted: %s", current)
- }
- done[current] = true
- }
- }
-
- // Freeze the mount tree in place. This prevents any new paths from
- // being opened and any old ones from being removed. If we do provide
- // tmpfs mounts, we'll want to freeze/thaw those separately.
- m.Freeze()
- return nil
-}
-
-func childDentAttrs(ctx context.Context, d *fs.Dirent) map[string]fs.DentAttr {
- dirname, _ := d.FullName(nil /* root */)
- dir, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
- if err != nil {
- log.Warningf("failed to open directory %q: %v", dirname, err)
- return nil
- }
- dir.DecRef()
- var stubSerializer fs.CollectEntriesSerializer
- if err := dir.Readdir(ctx, &stubSerializer); err != nil {
- log.Warningf("failed to iterate on host directory %q: %v", dirname, err)
- return nil
- }
- delete(stubSerializer.Entries, ".")
- delete(stubSerializer.Entries, "..")
- return stubSerializer.Entries
-}
-
-// newMountSource constructs a new host fs.MountSource
-// relative to a root path. The root should match the mount point.
-func newMountSource(ctx context.Context, root string, mounter fs.FileOwner, filesystem fs.Filesystem, flags fs.MountSourceFlags, dontTranslateOwnership bool) *fs.MountSource {
- return fs.NewMountSource(ctx, &superOperations{
- root: root,
- inodeMappings: make(map[uint64]string),
- mounter: mounter,
- dontTranslateOwnership: dontTranslateOwnership,
- }, filesystem, flags)
-}
-
-// superOperations implements fs.MountSourceOperations.
-//
-// +stateify savable
-type superOperations struct {
- fs.SimpleMountSourceOperations
-
- // root is the path of the mount point. All inode mappings
- // are relative to this root.
- root string
-
- // inodeMappings contains mappings of fs.Inodes associated
- // with this MountSource to paths under root.
- inodeMappings map[uint64]string
-
- // mounter is the cached EUID/EGID that mounted this file system.
- mounter fs.FileOwner
-
- // dontTranslateOwnership indicates whether to not translate file
- // ownership.
- //
- // By default, files/directories owned by the sandbox uses UID/GID
- // of the mounter. For files/directories that are not owned by the
- // sandbox, file UID/GID is translated to a UID/GID which cannot
- // be mapped in the sandboxed application's user namespace. The
- // UID/GID will look like the nobody UID/GID (65534) but is not
- // strictly owned by the user "nobody".
- //
- // If whitelistfs is a lower filesystem in an overlay, set
- // dont_translate_ownership=true in mount options.
- dontTranslateOwnership bool
-}
-
-var _ fs.MountSourceOperations = (*superOperations)(nil)
-
-// ResetInodeMappings implements fs.MountSourceOperations.ResetInodeMappings.
-func (m *superOperations) ResetInodeMappings() {
- m.inodeMappings = make(map[uint64]string)
-}
-
-// SaveInodeMapping implements fs.MountSourceOperations.SaveInodeMapping.
-func (m *superOperations) SaveInodeMapping(inode *fs.Inode, path string) {
- // This is very unintuitive. We *CANNOT* trust the inode's StableAttrs,
- // because overlay copyUp may have changed them out from under us.
- // So much for "immutable".
- sattr := inode.InodeOperations.(*inodeOperations).fileState.sattr
- m.inodeMappings[sattr.InodeID] = path
-}
-
-// Keep implements fs.MountSourceOperations.Keep.
-//
-// TODO(b/72455313,b/77596690): It is possible to change the permissions on a
-// host file while it is in the dirent cache (say from RO to RW), but it is not
-// possible to re-open the file with more relaxed permissions, since the host
-// FD is already open and stored in the inode.
-//
-// Using the dirent LRU cache increases the odds that this bug is encountered.
-// Since host file access is relatively fast anyways, we disable the LRU cache
-// for host fs files. Once we can properly deal with permissions changes and
-// re-opening host files, we should revisit whether or not to make use of the
-// LRU cache.
-func (*superOperations) Keep(*fs.Dirent) bool {
- return false
-}
-
-func init() {
- fs.RegisterFilesystem(&Filesystem{})
-}
diff --git a/pkg/sentry/fs/host/fs_test.go b/pkg/sentry/fs/host/fs_test.go
deleted file mode 100644
index 3111d2df9..000000000
--- a/pkg/sentry/fs/host/fs_test.go
+++ /dev/null
@@ -1,380 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package host
-
-import (
- "fmt"
- "io/ioutil"
- "os"
- "path"
- "reflect"
- "sort"
- "testing"
-
- "gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/fs"
-)
-
-// newTestMountNamespace creates a MountNamespace with a ramfs root.
-// It returns the host folder created, which should be removed when done.
-func newTestMountNamespace(t *testing.T) (*fs.MountNamespace, string, error) {
- p, err := ioutil.TempDir("", "root")
- if err != nil {
- return nil, "", err
- }
-
- fd, err := open(nil, p)
- if err != nil {
- os.RemoveAll(p)
- return nil, "", err
- }
- ctx := contexttest.Context(t)
- root, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false)
- if err != nil {
- os.RemoveAll(p)
- return nil, "", err
- }
- mm, err := fs.NewMountNamespace(ctx, root)
- if err != nil {
- os.RemoveAll(p)
- return nil, "", err
- }
- return mm, p, nil
-}
-
-// createTestDirs populates the root with some test files and directories.
-// /a/a1.txt
-// /a/a2.txt
-// /b/b1.txt
-// /b/c/c1.txt
-// /symlinks/normal.txt
-// /symlinks/to_normal.txt -> /symlinks/normal.txt
-// /symlinks/recursive -> /symlinks
-func createTestDirs(ctx context.Context, t *testing.T, m *fs.MountNamespace) error {
- r := m.Root()
- defer r.DecRef()
-
- if err := r.CreateDirectory(ctx, r, "a", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- a, err := r.Walk(ctx, r, "a")
- if err != nil {
- return err
- }
- defer a.DecRef()
-
- a1, err := a.Create(ctx, r, "a1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- a1.DecRef()
-
- a2, err := a.Create(ctx, r, "a2.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- a2.DecRef()
-
- if err := r.CreateDirectory(ctx, r, "b", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- b, err := r.Walk(ctx, r, "b")
- if err != nil {
- return err
- }
- defer b.DecRef()
-
- b1, err := b.Create(ctx, r, "b1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- b1.DecRef()
-
- if err := b.CreateDirectory(ctx, r, "c", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- c, err := b.Walk(ctx, r, "c")
- if err != nil {
- return err
- }
- defer c.DecRef()
-
- c1, err := c.Create(ctx, r, "c1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- c1.DecRef()
-
- if err := r.CreateDirectory(ctx, r, "symlinks", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- symlinks, err := r.Walk(ctx, r, "symlinks")
- if err != nil {
- return err
- }
- defer symlinks.DecRef()
-
- normal, err := symlinks.Create(ctx, r, "normal.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- normal.DecRef()
-
- if err := symlinks.CreateLink(ctx, r, "/symlinks/normal.txt", "to_normal.txt"); err != nil {
- return err
- }
-
- return symlinks.CreateLink(ctx, r, "/symlinks", "recursive")
-}
-
-// allPaths returns a slice of all paths of entries visible in the rootfs.
-func allPaths(ctx context.Context, t *testing.T, m *fs.MountNamespace, base string) ([]string, error) {
- var paths []string
- root := m.Root()
- defer root.DecRef()
-
- maxTraversals := uint(1)
- d, err := m.FindLink(ctx, root, nil, base, &maxTraversals)
- if err != nil {
- t.Logf("FindLink failed for %q", base)
- return paths, err
- }
- defer d.DecRef()
-
- if fs.IsDir(d.Inode.StableAttr) {
- dir, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
- if err != nil {
- return nil, fmt.Errorf("failed to open directory %q: %v", base, err)
- }
- iter, ok := dir.FileOperations.(fs.DirIterator)
- if !ok {
- return nil, fmt.Errorf("cannot directly iterate on host directory %q", base)
- }
- dirCtx := &fs.DirCtx{
- Serializer: noopDentrySerializer{},
- }
- if _, err := fs.DirentReaddir(ctx, d, iter, root, dirCtx, 0); err != nil {
- return nil, err
- }
- for name := range dirCtx.DentAttrs() {
- if name == "." || name == ".." {
- continue
- }
-
- fullName := path.Join(base, name)
- paths = append(paths, fullName)
-
- // Recurse.
- subpaths, err := allPaths(ctx, t, m, fullName)
- if err != nil {
- return paths, err
- }
- paths = append(paths, subpaths...)
- }
- }
-
- return paths, nil
-}
-
-type noopDentrySerializer struct{}
-
-func (noopDentrySerializer) CopyOut(string, fs.DentAttr) error {
- return nil
-}
-func (noopDentrySerializer) Written() int {
- return 4096
-}
-
-// pathsEqual returns true if the two string slices contain the same entries.
-func pathsEqual(got, want []string) bool {
- sort.Strings(got)
- sort.Strings(want)
-
- if len(got) != len(want) {
- return false
- }
-
- for i := range got {
- if got[i] != want[i] {
- return false
- }
- }
-
- return true
-}
-
-func TestWhitelist(t *testing.T) {
- for _, test := range []struct {
- // description of the test.
- desc string
- // paths are the paths to whitelist
- paths []string
- // want are all of the directory entries that should be
- // visible (nothing beyond this set should be visible).
- want []string
- }{
- {
- desc: "root",
- paths: []string{"/"},
- want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt", "/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt", "/symlinks/recursive"},
- },
- {
- desc: "top-level directories",
- paths: []string{"/a", "/b"},
- want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "nested directories (1/2)",
- paths: []string{"/b", "/b/c"},
- want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "nested directories (2/2)",
- paths: []string{"/b/c", "/b"},
- want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "single file",
- paths: []string{"/b/c/c1.txt"},
- want: []string{"/b", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "single file and directory",
- paths: []string{"/a/a1.txt", "/b/c"},
- want: []string{"/a", "/a/a1.txt", "/b", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "symlink",
- paths: []string{"/symlinks/to_normal.txt"},
- want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt"},
- },
- {
- desc: "recursive symlink",
- paths: []string{"/symlinks/recursive/normal.txt"},
- want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/recursive"},
- },
- } {
- t.Run(test.desc, func(t *testing.T) {
- m, p, err := newTestMountNamespace(t)
- if err != nil {
- t.Errorf("Failed to create MountNamespace: %v", err)
- }
- defer os.RemoveAll(p)
-
- ctx := withRoot(contexttest.RootContext(t), m.Root())
- if err := createTestDirs(ctx, t, m); err != nil {
- t.Errorf("Failed to create test dirs: %v", err)
- }
-
- if err := installWhitelist(ctx, m, test.paths); err != nil {
- t.Errorf("installWhitelist(%v) err got %v want nil", test.paths, err)
- }
-
- got, err := allPaths(ctx, t, m, "/")
- if err != nil {
- t.Fatalf("Failed to lookup paths (whitelisted: %v): %v", test.paths, err)
- }
-
- if !pathsEqual(got, test.want) {
- t.Errorf("For paths %v got %v want %v", test.paths, got, test.want)
- }
- })
- }
-}
-
-func TestRootPath(t *testing.T) {
- // Create a temp dir, which will be the root of our mounted fs.
- rootPath, err := ioutil.TempDir(os.TempDir(), "root")
- if err != nil {
- t.Fatalf("TempDir failed: %v", err)
- }
- defer os.RemoveAll(rootPath)
-
- // Create two files inside the new root, one which will be whitelisted
- // and one not.
- whitelisted, err := ioutil.TempFile(rootPath, "white")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- if _, err := ioutil.TempFile(rootPath, "black"); err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
-
- // Create a mount with a root path and single whitelisted file.
- hostFS := &Filesystem{}
- ctx := contexttest.Context(t)
- data := fmt.Sprintf("%s=%s,%s=%s", rootPathKey, rootPath, whitelistKey, whitelisted.Name())
- inode, err := hostFS.Mount(ctx, "", fs.MountSourceFlags{}, data, nil)
- if err != nil {
- t.Fatalf("Mount failed: %v", err)
- }
- mm, err := fs.NewMountNamespace(ctx, inode)
- if err != nil {
- t.Fatalf("NewMountNamespace failed: %v", err)
- }
- if err := hostFS.InstallWhitelist(ctx, mm); err != nil {
- t.Fatalf("InstallWhitelist failed: %v", err)
- }
-
- // Get the contents of the root directory.
- rootDir := mm.Root()
- rctx := withRoot(ctx, rootDir)
- f, err := rootDir.Inode.GetFile(rctx, rootDir, fs.FileFlags{})
- if err != nil {
- t.Fatalf("GetFile failed: %v", err)
- }
- c := &fs.CollectEntriesSerializer{}
- if err := f.Readdir(rctx, c); err != nil {
- t.Fatalf("Readdir failed: %v", err)
- }
-
- // We should have only our whitelisted file, plus the dots.
- want := []string{path.Base(whitelisted.Name()), ".", ".."}
- got := c.Order
- sort.Strings(want)
- sort.Strings(got)
- if !reflect.DeepEqual(got, want) {
- t.Errorf("Readdir got %v, wanted %v", got, want)
- }
-}
-
-type rootContext struct {
- context.Context
- root *fs.Dirent
-}
-
-// withRoot returns a copy of ctx with the given root.
-func withRoot(ctx context.Context, root *fs.Dirent) context.Context {
- return &rootContext{
- Context: ctx,
- root: root,
- }
-}
-
-// Value implements Context.Value.
-func (rc rootContext) Value(key interface{}) interface{} {
- switch key {
- case fs.CtxRoot:
- rc.root.IncRef()
- return rc.root
- default:
- return rc.Context.Value(key)
- }
-}
diff --git a/pkg/sentry/fs/host/host.go b/pkg/sentry/fs/host/host.go
new file mode 100644
index 000000000..081ba1dd8
--- /dev/null
+++ b/pkg/sentry/fs/host/host.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package host supports file descriptors imported directly.
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// filesystem is a host filesystem.
+//
+// +stateify savable
+type filesystem struct{}
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name under which the filesystem is registered.
+const FilesystemName = "host"
+
+// Name is the name of the filesystem.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// Mount returns an error. Mounting hostfs is not allowed.
+func (*filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, dataObj interface{}) (*fs.Inode, error) {
+ return nil, syserror.EPERM
+}
+
+// AllowUserMount prohibits users from using mount(2) with this file system.
+func (*filesystem) AllowUserMount() bool {
+ return false
+}
+
+// AllowUserList prohibits this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return false
+}
+
+// Flags returns that there is nothing special about this file system.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 6fa39caab..62f1246aa 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -17,12 +17,10 @@ package host
import (
"syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/secio"
- "gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -69,9 +67,6 @@ type inodeOperations struct {
//
// +stateify savable
type inodeFileState struct {
- // Common file system state.
- mops *superOperations `state:"wait"`
-
// descriptor is the backing host FD.
descriptor *descriptor `state:"wait"`
@@ -160,7 +155,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err
if err := syscall.Fstat(i.FD(), &s); err != nil {
return fs.UnstableAttr{}, err
}
- return unstableAttr(i.mops, &s), nil
+ return unstableAttr(&s), nil
}
// Allocate implements fsutil.CachedFileObject.Allocate.
@@ -172,7 +167,7 @@ func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error
var _ fs.InodeOperations = (*inodeOperations)(nil)
// newInode returns a new fs.Inode backed by the host FD.
-func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool, donated bool) (*fs.Inode, error) {
+func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool) (*fs.Inode, error) {
// Retrieve metadata.
var s syscall.Stat_t
err := syscall.Fstat(fd, &s)
@@ -181,24 +176,17 @@ func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool,
}
fileState := &inodeFileState{
- mops: msrc.MountSourceOperations.(*superOperations),
sattr: stableAttr(&s),
}
// Initialize the wrapped host file descriptor.
- fileState.descriptor, err = newDescriptor(
- fd,
- donated,
- saveable,
- wouldBlock(&s),
- &fileState.queue,
- )
+ fileState.descriptor, err = newDescriptor(fd, saveable, wouldBlock(&s), &fileState.queue)
if err != nil {
return nil, err
}
// Build the fs.InodeOperations.
- uattr := unstableAttr(msrc.MountSourceOperations.(*superOperations), &s)
+ uattr := unstableAttr(&s)
iops := &inodeOperations{
fileState: fileState,
cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
@@ -232,54 +220,23 @@ func (i *inodeOperations) Release(context.Context) {
// Lookup implements fs.InodeOperations.Lookup.
func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
- // Get a new FD relative to i at name.
- fd, err := open(i, name)
- if err != nil {
- if err == syserror.ENOENT {
- return nil, syserror.ENOENT
- }
- return nil, err
- }
-
- inode, err := newInode(ctx, dir.MountSource, fd, false /* saveable */, false /* donated */)
- if err != nil {
- return nil, err
- }
-
- // Return the fs.Dirent.
- return fs.NewDirent(ctx, inode, name), nil
+ return nil, syserror.ENOENT
}
// Create implements fs.InodeOperations.Create.
func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perm fs.FilePermissions) (*fs.File, error) {
- // Create a file relative to i at name.
- //
- // N.B. We always open this file O_RDWR regardless of flags because a
- // future GetFile might want more access. Open allows this regardless
- // of perm.
- fd, err := openAt(i, name, syscall.O_RDWR|syscall.O_CREAT|syscall.O_EXCL, perm.LinuxMode())
- if err != nil {
- return nil, err
- }
-
- inode, err := newInode(ctx, dir.MountSource, fd, false /* saveable */, false /* donated */)
- if err != nil {
- return nil, err
- }
+ return nil, syserror.EPERM
- d := fs.NewDirent(ctx, inode, name)
- defer d.DecRef()
- return inode.GetFile(ctx, d, flags)
}
// CreateDirectory implements fs.InodeOperations.CreateDirectory.
func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
- return syscall.Mkdirat(i.fileState.FD(), name, uint32(perm.LinuxMode()))
+ return syserror.EPERM
}
// CreateLink implements fs.InodeOperations.CreateLink.
func (i *inodeOperations) CreateLink(ctx context.Context, dir *fs.Inode, oldname string, newname string) error {
- return createLink(i.fileState.FD(), oldname, newname)
+ return syserror.EPERM
}
// CreateHardLink implements fs.InodeOperations.CreateHardLink.
@@ -294,25 +251,17 @@ func (*inodeOperations) CreateFifo(context.Context, *fs.Inode, string, fs.FilePe
// Remove implements fs.InodeOperations.Remove.
func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string) error {
- return unlinkAt(i.fileState.FD(), name, false /* dir */)
+ return syserror.EPERM
}
// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
func (i *inodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, name string) error {
- return unlinkAt(i.fileState.FD(), name, true /* dir */)
+ return syserror.EPERM
}
// Rename implements fs.InodeOperations.Rename.
func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
- op, ok := oldParent.InodeOperations.(*inodeOperations)
- if !ok {
- return syscall.EXDEV
- }
- np, ok := newParent.InodeOperations.(*inodeOperations)
- if !ok {
- return syscall.EXDEV
- }
- return syscall.Renameat(op.fileState.FD(), oldName, np.fileState.FD(), newName)
+ return syserror.EPERM
}
// Bind implements fs.InodeOperations.Bind.
@@ -448,82 +397,17 @@ func (i *inodeOperations) StatFS(context.Context) (fs.Info, error) {
}
// AddLink implements fs.InodeOperations.AddLink.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) AddLink() {}
// DropLink implements fs.InodeOperations.DropLink.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) DropLink() {}
// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
// readdirAll returns all of the directory entries in i.
func (i *inodeOperations) readdirAll(d *dirInfo) (map[string]fs.DentAttr, error) {
- i.readdirMu.Lock()
- defer i.readdirMu.Unlock()
-
- fd := i.fileState.FD()
-
- // syscall.ReadDirent will use getdents, which will seek the file past
- // the last directory entry. To read the directory entries a second
- // time, we need to seek back to the beginning.
- if _, err := syscall.Seek(fd, 0, 0); err != nil {
- if err == syscall.ESPIPE {
- // All directories should be seekable. If this file
- // isn't seekable, it is not a directory and we should
- // return that more sane error.
- err = syscall.ENOTDIR
- }
- return nil, err
- }
-
- names := make([]string, 0, 100)
- for {
- // Refill the buffer if necessary
- if d.bufp >= d.nbuf {
- d.bufp = 0
- // ReadDirent will just do a sys_getdents64 to the kernel.
- n, err := syscall.ReadDirent(fd, d.buf)
- if err != nil {
- return nil, err
- }
- if n == 0 {
- break // EOF
- }
- d.nbuf = n
- }
-
- var nb int
- // Parse the dirent buffer we just get and return the directory names along
- // with the number of bytes consumed in the buffer.
- nb, _, names = syscall.ParseDirent(d.buf[d.bufp:d.nbuf], -1, names)
- d.bufp += nb
- }
-
- entries := make(map[string]fs.DentAttr)
- for _, filename := range names {
- // Lookup the type and host device and inode.
- stat, lerr := fstatat(fd, filename, linux.AT_SYMLINK_NOFOLLOW)
- if lerr == syscall.ENOENT {
- // File disappeared between readdir and lstat.
- // Just treat it as if it didn't exist.
- continue
- }
-
- // There was a serious problem, we should probably report it.
- if lerr != nil {
- return nil, lerr
- }
-
- entries[filename] = fs.DentAttr{
- Type: nodeType(&stat),
- InodeID: hostFileDevice.Map(device.MultiDeviceKey{
- Device: stat.Dev,
- Inode: stat.Ino,
- }),
- }
- }
- return entries, nil
+ // We only support non-directory file descriptors that have been
+ // imported, so just claim that this isn't a directory, even if it is.
+ return nil, syscall.ENOTDIR
}
diff --git a/pkg/sentry/fs/host/inode_state.go b/pkg/sentry/fs/host/inode_state.go
index 299e0e0b0..1adbd4562 100644
--- a/pkg/sentry/fs/host/inode_state.go
+++ b/pkg/sentry/fs/host/inode_state.go
@@ -18,29 +18,14 @@ import (
"fmt"
"syscall"
- "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
-// beforeSave is invoked by stateify.
-func (i *inodeFileState) beforeSave() {
- if !i.queue.IsEmpty() {
- panic("event queue must be empty")
- }
- if !i.descriptor.donated && i.sattr.Type == fs.RegularFile {
- uattr, err := i.unstableAttr(context.Background())
- if err != nil {
- panic(fs.ErrSaveRejection{fmt.Errorf("failed to get unstable atttribute of %s: %v", i.mops.inodeMappings[i.sattr.InodeID], err)})
- }
- i.savedUAttr = &uattr
- }
-}
-
// afterLoad is invoked by stateify.
func (i *inodeFileState) afterLoad() {
// Initialize the descriptor value.
- if err := i.descriptor.initAfterLoad(i.mops, i.sattr.InodeID, &i.queue); err != nil {
+ if err := i.descriptor.initAfterLoad(i.sattr.InodeID, &i.queue); err != nil {
panic(fmt.Sprintf("failed to load value of descriptor: %v", err))
}
@@ -61,19 +46,4 @@ func (i *inodeFileState) afterLoad() {
// change across save and restore, error out.
panic(fs.ErrCorruption{fmt.Errorf("host %s conflict in host device mappings: %s", key, hostFileDevice)})
}
-
- if !i.descriptor.donated && i.sattr.Type == fs.RegularFile {
- env, ok := fs.CurrentRestoreEnvironment()
- if !ok {
- panic("missing restore environment")
- }
- uattr := unstableAttr(i.mops, &s)
- if env.ValidateFileSize && uattr.Size != i.savedUAttr.Size {
- panic(fs.ErrCorruption{fmt.Errorf("file size has changed for %s: previously %d, now %d", i.mops.inodeMappings[i.sattr.InodeID], i.savedUAttr.Size, uattr.Size)})
- }
- if env.ValidateFileTimestamp && uattr.ModificationTime != i.savedUAttr.ModificationTime {
- panic(fs.ErrCorruption{fmt.Errorf("file modification time has changed for %s: previously %v, now %v", i.mops.inodeMappings[i.sattr.InodeID], i.savedUAttr.ModificationTime, uattr.ModificationTime)})
- }
- i.savedUAttr = nil
- }
}
diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go
index 7221bc825..c507f57eb 100644
--- a/pkg/sentry/fs/host/inode_test.go
+++ b/pkg/sentry/fs/host/inode_test.go
@@ -15,79 +15,12 @@
package host
import (
- "io/ioutil"
- "os"
- "path"
"syscall"
"testing"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/fs"
)
-// TestMultipleReaddir verifies that multiple Readdir calls return the same
-// thing if they use different dir contexts.
-func TestMultipleReaddir(t *testing.T) {
- p, err := ioutil.TempDir("", "readdir")
- if err != nil {
- t.Fatalf("Failed to create test dir: %v", err)
- }
- defer os.RemoveAll(p)
-
- f, err := os.Create(path.Join(p, "a.txt"))
- if err != nil {
- t.Fatalf("Failed to create a.txt: %v", err)
- }
- f.Close()
-
- f, err = os.Create(path.Join(p, "b.txt"))
- if err != nil {
- t.Fatalf("Failed to create b.txt: %v", err)
- }
- f.Close()
-
- fd, err := open(nil, p)
- if err != nil {
- t.Fatalf("Failed to open %q: %v", p, err)
- }
- ctx := contexttest.Context(t)
- n, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false)
- if err != nil {
- t.Fatalf("Failed to create inode: %v", err)
- }
-
- dirent := fs.NewDirent(ctx, n, "readdir")
- openFile, err := n.GetFile(ctx, dirent, fs.FileFlags{Read: true})
- if err != nil {
- t.Fatalf("Failed to get file: %v", err)
- }
- defer openFile.DecRef()
-
- c1 := &fs.DirCtx{DirCursor: new(string)}
- if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, dirent, c1, 0); err != nil {
- t.Fatalf("First Readdir failed: %v", err)
- }
-
- c2 := &fs.DirCtx{DirCursor: new(string)}
- if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, dirent, c2, 0); err != nil {
- t.Errorf("Second Readdir failed: %v", err)
- }
-
- if _, ok := c1.DentAttrs()["a.txt"]; !ok {
- t.Errorf("want a.txt in first Readdir, got %v", c1.DentAttrs())
- }
- if _, ok := c1.DentAttrs()["b.txt"]; !ok {
- t.Errorf("want b.txt in first Readdir, got %v", c1.DentAttrs())
- }
-
- if _, ok := c2.DentAttrs()["a.txt"]; !ok {
- t.Errorf("want a.txt in second Readdir, got %v", c2.DentAttrs())
- }
- if _, ok := c2.DentAttrs()["b.txt"]; !ok {
- t.Errorf("want b.txt in second Readdir, got %v", c2.DentAttrs())
- }
-}
-
// TestCloseFD verifies fds will be closed.
func TestCloseFD(t *testing.T) {
var p [2]int
@@ -99,7 +32,7 @@ func TestCloseFD(t *testing.T) {
// Use the write-end because we will detect if it's closed on the read end.
ctx := contexttest.Context(t)
- file, err := NewFile(ctx, p[1], fs.RootOwner)
+ file, err := NewFile(ctx, p[1])
if err != nil {
t.Fatalf("Failed to create File: %v", err)
}
diff --git a/pkg/sentry/fs/host/ioctl_unsafe.go b/pkg/sentry/fs/host/ioctl_unsafe.go
index 271582e54..150ac8e19 100644
--- a/pkg/sentry/fs/host/ioctl_unsafe.go
+++ b/pkg/sentry/fs/host/ioctl_unsafe.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
+// LINT.IfChange
+
func ioctlGetTermios(fd int) (*linux.Termios, error) {
var t linux.Termios
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TCGETS, uintptr(unsafe.Pointer(&t)))
@@ -54,3 +56,5 @@ func ioctlSetWinsize(fd int, w *linux.Winsize) error {
}
return nil
}
+
+// LINT.ThenChange(../../fsimpl/host/ioctl_unsafe.go)
diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go
index eb4afe520..affdbcacb 100644
--- a/pkg/sentry/fs/host/socket_test.go
+++ b/pkg/sentry/fs/host/socket_test.go
@@ -199,14 +199,14 @@ func TestListen(t *testing.T) {
}
func TestPasscred(t *testing.T) {
- e := ConnectedEndpoint{}
+ e := &ConnectedEndpoint{}
if got, want := e.Passcred(), false; got != want {
t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want)
}
}
func TestGetLocalAddress(t *testing.T) {
- e := ConnectedEndpoint{path: "foo"}
+ e := &ConnectedEndpoint{path: "foo"}
want := tcpip.FullAddress{Addr: tcpip.Address("foo")}
if got, err := e.GetLocalAddress(); err != nil || got != want {
t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil)
@@ -214,7 +214,7 @@ func TestGetLocalAddress(t *testing.T) {
}
func TestQueuedSize(t *testing.T) {
- e := ConnectedEndpoint{}
+ e := &ConnectedEndpoint{}
tests := []struct {
name string
f func() int64
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 3f218b4a7..cb91355ab 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -26,6 +26,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// TTYFileOperations implements fs.FileOperations for a host file descriptor
// that wraps a TTY FD.
//
@@ -43,6 +45,7 @@ type TTYFileOperations struct {
// connected to this TTY.
fgProcessGroup *kernel.ProcessGroup
+ // termios contains the terminal attributes for this TTY.
termios linux.KernelTermios
}
@@ -357,3 +360,5 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e
_ = pg.SendSignal(kernel.SignalInfoPriv(sig))
return kernel.ERESTARTSYS
}
+
+// LINT.ThenChange(../../fsimpl/host/tty.go)
diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go
index e37e687c6..1b0356930 100644
--- a/pkg/sentry/fs/host/util.go
+++ b/pkg/sentry/fs/host/util.go
@@ -16,7 +16,6 @@ package host
import (
"os"
- "path"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -28,45 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-func open(parent *inodeOperations, name string) (int, error) {
- if parent == nil && !path.IsAbs(name) {
- return -1, syserror.EINVAL
- }
- name = path.Clean(name)
-
- // Don't follow through symlinks.
- flags := syscall.O_NOFOLLOW
-
- if fd, err := openAt(parent, name, flags|syscall.O_RDWR, 0); err == nil {
- return fd, nil
- }
- // Retry as read-only.
- if fd, err := openAt(parent, name, flags|syscall.O_RDONLY, 0); err == nil {
- return fd, nil
- }
-
- // Retry as write-only.
- if fd, err := openAt(parent, name, flags|syscall.O_WRONLY, 0); err == nil {
- return fd, nil
- }
-
- // Retry as a symlink, by including O_PATH as an option.
- fd, err := openAt(parent, name, linux.O_PATH|flags, 0)
- if err == nil {
- return fd, nil
- }
-
- // Everything failed.
- return -1, err
-}
-
-func openAt(parent *inodeOperations, name string, flags int, perm linux.FileMode) (int, error) {
- if parent == nil {
- return syscall.Open(name, flags, uint32(perm))
- }
- return syscall.Openat(parent.fileState.FD(), name, flags, uint32(perm))
-}
-
func nodeType(s *syscall.Stat_t) fs.InodeType {
switch x := (s.Mode & syscall.S_IFMT); x {
case syscall.S_IFLNK:
@@ -107,51 +67,19 @@ func stableAttr(s *syscall.Stat_t) fs.StableAttr {
}
}
-func owner(mo *superOperations, s *syscall.Stat_t) fs.FileOwner {
- // User requested no translation, just return actual owner.
- if mo.dontTranslateOwnership {
- return fs.FileOwner{auth.KUID(s.Uid), auth.KGID(s.Gid)}
+func owner(s *syscall.Stat_t) fs.FileOwner {
+ return fs.FileOwner{
+ UID: auth.KUID(s.Uid),
+ GID: auth.KGID(s.Gid),
}
-
- // Show only IDs relevant to the sandboxed task. I.e. if we not own the
- // file, no sandboxed task can own the file. In that case, we
- // use OverflowID for UID, implying that the IDs are not mapped in the
- // "root" user namespace.
- //
- // E.g.
- // sandbox's host EUID/EGID is 1/1.
- // some_dir's host UID/GID is 2/1.
- // Task that mounted this fs has virtualized EUID/EGID 5/5.
- //
- // If you executed `ls -n` in the sandboxed task, it would show:
- // drwxwrxwrx [...] 65534 5 [...] some_dir
-
- // Files are owned by OverflowID by default.
- owner := fs.FileOwner{auth.KUID(auth.OverflowUID), auth.KGID(auth.OverflowGID)}
-
- // If we own file on host, let mounting task's initial EUID own
- // the file.
- if s.Uid == hostUID {
- owner.UID = mo.mounter.UID
- }
-
- // If our group matches file's group, make file's group match
- // the mounting task's initial EGID.
- for _, gid := range hostGIDs {
- if s.Gid == gid {
- owner.GID = mo.mounter.GID
- break
- }
- }
- return owner
}
-func unstableAttr(mo *superOperations, s *syscall.Stat_t) fs.UnstableAttr {
+func unstableAttr(s *syscall.Stat_t) fs.UnstableAttr {
return fs.UnstableAttr{
Size: s.Size,
Usage: s.Blocks * 512,
Perms: fs.FilePermsFromMode(linux.FileMode(s.Mode)),
- Owner: owner(mo, s),
+ Owner: owner(s),
AccessTime: ktime.FromUnix(s.Atim.Sec, s.Atim.Nsec),
ModificationTime: ktime.FromUnix(s.Mtim.Sec, s.Mtim.Nsec),
StatusChangeTime: ktime.FromUnix(s.Ctim.Sec, s.Ctim.Nsec),
@@ -165,6 +93,8 @@ type dirInfo struct {
bufp int // location of next record in buf.
}
+// LINT.IfChange
+
// isBlockError unwraps os errors and checks if they are caused by EAGAIN or
// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock.
func isBlockError(err error) bool {
@@ -177,6 +107,8 @@ func isBlockError(err error) bool {
return false
}
+// LINT.ThenChange(../../fsimpl/host/util.go)
+
func hostEffectiveKIDs() (uint32, []uint32, error) {
gids, err := os.Getgroups()
if err != nil {
diff --git a/pkg/sentry/fs/host/util_unsafe.go b/pkg/sentry/fs/host/util_unsafe.go
index 3ab36b088..23bd35d64 100644
--- a/pkg/sentry/fs/host/util_unsafe.go
+++ b/pkg/sentry/fs/host/util_unsafe.go
@@ -26,26 +26,6 @@ import (
// NulByte is a single NUL byte. It is passed to readlinkat as an empty string.
var NulByte byte = '\x00'
-func createLink(fd int, name string, linkName string) error {
- namePtr, err := syscall.BytePtrFromString(name)
- if err != nil {
- return err
- }
- linkNamePtr, err := syscall.BytePtrFromString(linkName)
- if err != nil {
- return err
- }
- _, _, errno := syscall.Syscall(
- syscall.SYS_SYMLINKAT,
- uintptr(unsafe.Pointer(namePtr)),
- uintptr(fd),
- uintptr(unsafe.Pointer(linkNamePtr)))
- if errno != 0 {
- return errno
- }
- return nil
-}
-
func readLink(fd int) (string, error) {
// Buffer sizing copied from os.Readlink.
for l := 128; ; l *= 2 {
@@ -66,27 +46,6 @@ func readLink(fd int) (string, error) {
}
}
-func unlinkAt(fd int, name string, dir bool) error {
- namePtr, err := syscall.BytePtrFromString(name)
- if err != nil {
- return err
- }
- var flags uintptr
- if dir {
- flags = linux.AT_REMOVEDIR
- }
- _, _, errno := syscall.Syscall(
- syscall.SYS_UNLINKAT,
- uintptr(fd),
- uintptr(unsafe.Pointer(namePtr)),
- flags,
- )
- if errno != 0 {
- return errno
- }
- return nil
-}
-
func timespecFromTimestamp(t ktime.Time, omit, setSysTime bool) syscall.Timespec {
if omit {
return syscall.Timespec{0, linux.UTIME_OMIT}
diff --git a/pkg/sentry/fs/host/wait_test.go b/pkg/sentry/fs/host/wait_test.go
index d49c3a635..ce397a5e3 100644
--- a/pkg/sentry/fs/host/wait_test.go
+++ b/pkg/sentry/fs/host/wait_test.go
@@ -20,7 +20,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -34,7 +33,7 @@ func TestWait(t *testing.T) {
defer syscall.Close(fds[1])
ctx := contexttest.Context(t)
- file, err := NewFile(ctx, fds[0], fs.RootOwner)
+ file, err := NewFile(ctx, fds[0])
if err != nil {
syscall.Close(fds[0])
t.Fatalf("NewFile failed: %v", err)
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index 55fb71c16..a34fbc946 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -102,7 +102,6 @@ func (i *Inode) DecRef() {
// destroy releases the Inode and releases the msrc reference taken.
func (i *Inode) destroy() {
- // FIXME(b/38173783): Context is not plumbed here.
ctx := context.Background()
if err := i.WriteOut(ctx); err != nil {
// FIXME(b/65209558): Mark as warning again once noatime is
@@ -397,8 +396,6 @@ func (i *Inode) Getlink(ctx context.Context) (*Dirent, error) {
// AddLink calls i.InodeOperations.AddLink.
func (i *Inode) AddLink() {
if i.overlay != nil {
- // FIXME(b/63117438): Remove this from InodeOperations altogether.
- //
// This interface is only used by ramfs to update metadata of
// children. These filesystems should _never_ have overlay
// Inodes cached as children. So explicitly disallow this
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index 5ada33a32..537c8d257 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -231,7 +231,8 @@ func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name st
upperFile.Dirent.Inode.IncRef()
entry, err := newOverlayEntry(ctx, upperFile.Dirent.Inode, nil, false)
if err != nil {
- cleanupUpper(ctx, o.upper, name)
+ werr := fmt.Errorf("newOverlayEntry failed: %v", err)
+ cleanupUpper(ctx, o.upper, name, werr)
return nil, err
}
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index 928c90aa0..e3a715c1f 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -143,7 +143,10 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i
}
var writeLen int64
- for event := i.events.Front(); event != nil; event = event.Next() {
+ for it := i.events.Front(); it != nil; {
+ event := it
+ it = it.Next()
+
// Does the buffer have enough remaining space to hold the event we're
// about to write out?
if dst.NumBytes() < int64(event.sizeOf()) {
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index c7981f66e..b414ddaee 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -273,19 +273,6 @@ func (mns *MountNamespace) DecRef() {
mns.DecRefWithDestructor(mns.destroy)
}
-// Freeze freezes the entire mount tree.
-func (mns *MountNamespace) Freeze() {
- mns.mu.Lock()
- defer mns.mu.Unlock()
-
- // We only want to freeze Dirents with active references, not Dirents referenced
- // by a mount's MountSource.
- mns.flushMountSourceRefsLocked()
-
- // Freeze the entire shebang.
- mns.root.Freeze()
-}
-
// withMountLocked prevents further walks to `node`, because `node` is about to
// be a mount point.
func (mns *MountNamespace) withMountLocked(node *Dirent, fn func() error) error {
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
index 94deb553b..1fc9c703c 100644
--- a/pkg/sentry/fs/proc/mounts.go
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -170,7 +170,8 @@ func superBlockOpts(mountPath string, msrc *fs.MountSource) string {
// NOTE(b/147673608): If the mount is a cgroup, we also need to include
// the cgroup name in the options. For now we just read that from the
// path.
- // TODO(gvisor.dev/issues/190): Once gVisor has full cgroup support, we
+ //
+ // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
// should get this value from the cgroup itself, and not rely on the
// path.
if msrc.FilesystemType == "cgroup" {
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 95d5817ff..bd18177d4 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -40,47 +40,48 @@ import (
// LINT.IfChange
-// newNet creates a new proc net entry.
-func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode {
+// newNetDir creates a new proc net entry.
+func newNetDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ k := t.Kernel()
+
var contents map[string]*fs.Inode
- // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
- // network namespace of the calling process. We should make this per-process,
- // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
- if s := p.k.RootNetworkNamespace().Stack(); s != nil {
+ if s := t.NetworkNamespace().Stack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
+ // network namespace.
contents = map[string]*fs.Inode{
- "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc),
- "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc),
+ "dev": seqfile.NewSeqFileInode(t, &netDev{s: s}, msrc),
+ "snmp": seqfile.NewSeqFileInode(t, &netSnmp{s: s}, msrc),
// The following files are simple stubs until they are
// implemented in netstack, if the file contains a
// header the stub is just the header otherwise it is
// an empty file.
- "arp": newStaticProcInode(ctx, msrc, []byte("IP address HW type Flags HW address Mask Device\n")),
+ "arp": newStaticProcInode(t, msrc, []byte("IP address HW type Flags HW address Mask Device\n")),
- "netlink": newStaticProcInode(ctx, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")),
- "netstat": newStaticProcInode(ctx, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")),
- "packet": newStaticProcInode(ctx, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")),
- "protocols": newStaticProcInode(ctx, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")),
+ "netlink": newStaticProcInode(t, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")),
+ "netstat": newStaticProcInode(t, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")),
+ "packet": newStaticProcInode(t, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")),
+ "protocols": newStaticProcInode(t, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")),
// Linux sets psched values to: nsec per usec, psched
// tick in ns, 1000000, high res timer ticks per sec
// (ClockGetres returns 1ns resolution).
- "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))),
- "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function\n")),
- "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc),
- "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc),
- "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc),
- "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc),
+ "psched": newStaticProcInode(t, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))),
+ "ptype": newStaticProcInode(t, msrc, []byte("Type Device Function\n")),
+ "route": seqfile.NewSeqFileInode(t, &netRoute{s: s}, msrc),
+ "tcp": seqfile.NewSeqFileInode(t, &netTCP{k: k}, msrc),
+ "udp": seqfile.NewSeqFileInode(t, &netUDP{k: k}, msrc),
+ "unix": seqfile.NewSeqFileInode(t, &netUnix{k: k}, msrc),
}
if s.SupportsIPv6() {
- contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc)
- contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte(""))
- contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc)
- contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n"))
+ contents["if_inet6"] = seqfile.NewSeqFileInode(t, &ifinet6{s: s}, msrc)
+ contents["ipv6_route"] = newStaticProcInode(t, msrc, []byte(""))
+ contents["tcp6"] = seqfile.NewSeqFileInode(t, &netTCP6{k: k}, msrc)
+ contents["udp6"] = newStaticProcInode(t, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n"))
}
}
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
+ d := ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(t, d, msrc, fs.SpecialDirectory, t)
}
// ifinet6 implements seqfile.SeqSource for /proc/net/if_inet6.
@@ -837,4 +838,4 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
return data, 0
}
-// LINT.ThenChange(../../fsimpl/proc/tasks_net.go)
+// LINT.ThenChange(../../fsimpl/proc/task_net.go)
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
index c8abb5052..c659224a7 100644
--- a/pkg/sentry/fs/proc/proc.go
+++ b/pkg/sentry/fs/proc/proc.go
@@ -70,6 +70,7 @@ func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string
"loadavg": seqfile.NewSeqFileInode(ctx, &loadavgData{}, msrc),
"meminfo": seqfile.NewSeqFileInode(ctx, &meminfoData{k}, msrc),
"mounts": newProcInode(ctx, ramfs.NewSymlink(ctx, fs.RootOwner, "self/mounts"), msrc, fs.Symlink, nil),
+ "net": newProcInode(ctx, ramfs.NewSymlink(ctx, fs.RootOwner, "self/net"), msrc, fs.Symlink, nil),
"self": newSelf(ctx, pidns, msrc),
"stat": seqfile.NewSeqFileInode(ctx, &statData{k}, msrc),
"thread-self": newThreadSelf(ctx, pidns, msrc),
@@ -86,7 +87,6 @@ func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string
}
// Add more contents that need proc to be initialized.
- p.AddChild(ctx, "net", p.newNetDir(ctx, k, msrc))
p.AddChild(ctx, "sys", p.newSysDir(ctx, msrc))
return newProcInode(ctx, p, msrc, fs.SpecialDirectory, nil), nil
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index d4c4b533d..702fdd392 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -80,7 +80,7 @@ func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir
}
// Truncate implements fs.InodeOperations.Truncate.
-func (tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error {
+func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error {
return nil
}
@@ -196,7 +196,7 @@ func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *f
}
// Truncate implements fs.InodeOperations.Truncate.
-func (tcpSack) Truncate(context.Context, *fs.Inode, int64) error {
+func (*tcpSack) Truncate(context.Context, *fs.Inode, int64) error {
return nil
}
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 8ab8d8a02..4d42eac83 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -57,6 +57,16 @@ func getTaskMM(t *kernel.Task) (*mm.MemoryManager, error) {
return m, nil
}
+func checkTaskState(t *kernel.Task) error {
+ switch t.ExitState() {
+ case kernel.TaskExitZombie:
+ return syserror.EACCES
+ case kernel.TaskExitDead:
+ return syserror.ESRCH
+ }
+ return nil
+}
+
// taskDir represents a task-level directory.
//
// +stateify savable
@@ -72,24 +82,27 @@ var _ fs.InodeOperations = (*taskDir)(nil)
// newTaskDir creates a new proc task entry.
func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode {
contents := map[string]*fs.Inode{
- "auxv": newAuxvec(t, msrc),
- "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
- "comm": newComm(t, msrc),
- "environ": newExecArgInode(t, msrc, environExecArg),
- "exe": newExe(t, msrc),
- "fd": newFdDir(t, msrc),
- "fdinfo": newFdInfoDir(t, msrc),
- "gid_map": newGIDMap(t, msrc),
- "io": newIO(t, msrc, isThreadGroup),
- "maps": newMaps(t, msrc),
- "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
- "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
- "ns": newNamespaceDir(t, msrc),
- "smaps": newSmaps(t, msrc),
- "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns),
- "statm": newStatm(t, msrc),
- "status": newStatus(t, msrc, p.pidns),
- "uid_map": newUIDMap(t, msrc),
+ "auxv": newAuxvec(t, msrc),
+ "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
+ "comm": newComm(t, msrc),
+ "environ": newExecArgInode(t, msrc, environExecArg),
+ "exe": newExe(t, msrc),
+ "fd": newFdDir(t, msrc),
+ "fdinfo": newFdInfoDir(t, msrc),
+ "gid_map": newGIDMap(t, msrc),
+ "io": newIO(t, msrc, isThreadGroup),
+ "maps": newMaps(t, msrc),
+ "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
+ "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
+ "net": newNetDir(t, msrc),
+ "ns": newNamespaceDir(t, msrc),
+ "oom_score": newOOMScore(t, msrc),
+ "oom_score_adj": newOOMScoreAdj(t, msrc),
+ "smaps": newSmaps(t, msrc),
+ "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns),
+ "statm": newStatm(t, msrc),
+ "status": newStatus(t, msrc, p.pidns),
+ "uid_map": newUIDMap(t, msrc),
}
if isThreadGroup {
contents["task"] = p.newSubtasks(t, msrc)
@@ -251,11 +264,12 @@ func newExe(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
}
func (e *exe) executable() (file fsbridge.File, err error) {
+ if err := checkTaskState(e.t); err != nil {
+ return nil, err
+ }
e.t.WithMuLocked(func(t *kernel.Task) {
mm := t.MemoryManager()
if mm == nil {
- // TODO(b/34851096): Check shouldn't allow Readlink once the
- // Task is zombied.
err = syserror.EACCES
return
}
@@ -265,7 +279,7 @@ func (e *exe) executable() (file fsbridge.File, err error) {
// (with locks held).
file = mm.Executable()
if file == nil {
- err = syserror.ENOENT
+ err = syserror.ESRCH
}
})
return
@@ -310,11 +324,22 @@ func newNamespaceSymlink(t *kernel.Task, msrc *fs.MountSource, name string) *fs.
return newProcInode(t, n, msrc, fs.Symlink, t)
}
+// Readlink reads the symlink value.
+func (n *namespaceSymlink) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if err := checkTaskState(n.t); err != nil {
+ return "", err
+ }
+ return n.Symlink.Readlink(ctx, inode)
+}
+
// Getlink implements fs.InodeOperations.Getlink.
func (n *namespaceSymlink) Getlink(ctx context.Context, inode *fs.Inode) (*fs.Dirent, error) {
if !kernel.ContextCanTrace(ctx, n.t, false) {
return nil, syserror.EACCES
}
+ if err := checkTaskState(n.t); err != nil {
+ return nil, err
+ }
// Create a new regular file to fake the namespace file.
iops := fsutil.NewNoReadWriteFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0777), linux.PROC_SUPER_MAGIC)
@@ -796,4 +821,95 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc
return int64(n), err
}
+// newOOMScore returns a oom_score file. It is a stub that always returns 0.
+// TODO(gvisor.dev/issue/1967)
+func newOOMScore(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newStaticProcInode(t, msrc, []byte("0\n"))
+}
+
+// oomScoreAdj is a file containing the oom_score adjustment for a task.
+//
+// +stateify savable
+type oomScoreAdj struct {
+ fsutil.SimpleFileInode
+
+ t *kernel.Task
+}
+
+// +stateify savable
+type oomScoreAdjFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ t *kernel.Task
+}
+
+// newOOMScoreAdj returns a oom_score_adj file.
+func newOOMScoreAdj(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ i := &oomScoreAdj{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ t: t,
+ }
+ return newProcInode(t, i, msrc, fs.SpecialFile, t)
+}
+
+// Truncate implements fs.InodeOperations.Truncate. Truncate is called when
+// O_TRUNC is specified for any kind of existing Dirent but is not called via
+// (f)truncate for proc files.
+func (*oomScoreAdj) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (o *oomScoreAdj) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &oomScoreAdjFile{t: o.t}), nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *oomScoreAdjFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if f.t.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "%d\n", f.t.OOMScoreAdj())
+ if offset >= int64(buf.Len()) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, buf.Bytes()[offset:])
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *oomScoreAdjFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit input size so as not to impact performance if input size is large.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
+ if f.t.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
+ if err := f.t.SetOOMScoreAdj(v); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
// LINT.ThenChange(../../fsimpl/proc/task.go|../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go
index d5be56c3f..bc117ca6a 100644
--- a/pkg/sentry/fs/tmpfs/fs.go
+++ b/pkg/sentry/fs/tmpfs/fs.go
@@ -44,9 +44,6 @@ const (
// lookup.
cacheRevalidate = "revalidate"
- // TODO(edahlgren/mpratt): support a tmpfs size limit.
- // size = "size"
-
// Permissions that exceed modeMask will be rejected.
modeMask = 01777
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
index 6aa17bfc1..89168220a 100644
--- a/pkg/sentry/fsbridge/vfs.go
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -26,22 +26,22 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// fsFile implements File interface over vfs.FileDescription.
+// VFSFile implements File interface over vfs.FileDescription.
//
// +stateify savable
-type vfsFile struct {
+type VFSFile struct {
file *vfs.FileDescription
}
-var _ File = (*vfsFile)(nil)
+var _ File = (*VFSFile)(nil)
// NewVFSFile creates a new File over fs.File.
func NewVFSFile(file *vfs.FileDescription) File {
- return &vfsFile{file: file}
+ return &VFSFile{file: file}
}
// PathnameWithDeleted implements File.
-func (f *vfsFile) PathnameWithDeleted(ctx context.Context) string {
+func (f *VFSFile) PathnameWithDeleted(ctx context.Context) string {
root := vfs.RootFromContext(ctx)
defer root.DecRef()
@@ -51,7 +51,7 @@ func (f *vfsFile) PathnameWithDeleted(ctx context.Context) string {
}
// ReadFull implements File.
-func (f *vfsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+func (f *VFSFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
var total int64
for dst.NumBytes() > 0 {
n, err := f.file.PRead(ctx, dst, offset+total, vfs.ReadOptions{})
@@ -67,12 +67,12 @@ func (f *vfsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset i
}
// ConfigureMMap implements File.
-func (f *vfsFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+func (f *VFSFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
return f.file.ConfigureMMap(ctx, opts)
}
// Type implements File.
-func (f *vfsFile) Type(ctx context.Context) (linux.FileMode, error) {
+func (f *VFSFile) Type(ctx context.Context) (linux.FileMode, error) {
stat, err := f.file.Stat(ctx, vfs.StatOptions{})
if err != nil {
return 0, err
@@ -81,15 +81,21 @@ func (f *vfsFile) Type(ctx context.Context) (linux.FileMode, error) {
}
// IncRef implements File.
-func (f *vfsFile) IncRef() {
+func (f *VFSFile) IncRef() {
f.file.IncRef()
}
// DecRef implements File.
-func (f *vfsFile) DecRef() {
+func (f *VFSFile) DecRef() {
f.file.DecRef()
}
+// FileDescription returns the FileDescription represented by f. It does not
+// take an additional reference on the returned FileDescription.
+func (f *VFSFile) FileDescription() *vfs.FileDescription {
+ return f.file
+}
+
// fsLookup implements Lookup interface using fs.File.
//
// +stateify savable
@@ -115,8 +121,6 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry)
//
// remainingTraversals is not configurable in VFS2, all callers are using the
// default anyways.
-//
-// TODO(gvisor.dev/issue/1623): Check mount has read and exec permission.
func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
creds := auth.CredentialsFromContext(ctx)
@@ -134,5 +138,5 @@ func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.Open
if err != nil {
return nil, err
}
- return &vfsFile{file: fd}, nil
+ return &VFSFile{file: fd}, nil
}
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
index abd4f24e7..64f1b142c 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -42,6 +42,11 @@ type FilesystemType struct {
root *vfs.Dentry
}
+// Name implements vfs.FilesystemType.Name.
+func (*FilesystemType) Name() string {
+ return Name
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fst *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
fst.initOnce.Do(func() {
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index 6f78f478f..d83d75b3d 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -45,6 +45,7 @@ go_library(
"//pkg/sentry/fsimpl/ext/disklayout",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/vfs",
"//pkg/sync",
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
index 373d23b74..7176af6d1 100644
--- a/pkg/sentry/fsimpl/ext/ext.go
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -30,6 +30,9 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// Name is the name of this filesystem.
+const Name = "ext"
+
// FilesystemType implements vfs.FilesystemType.
type FilesystemType struct{}
@@ -91,8 +94,13 @@ func isCompatible(sb disklayout.SuperBlock) bool {
return true
}
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
-func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
// TODO(b/134676337): Ensure that the user is mounting readonly. If not,
// EACCESS should be returned according to mount(2). Filesystem independent
// flags (like readonly) are currently not available in pkg/sentry/vfs.
@@ -103,7 +111,7 @@ func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFile
}
fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)}
- fs.vfsfs.Init(vfsObj, &fs)
+ fs.vfsfs.Init(vfsObj, &fsType, &fs)
fs.sb, err = readSuperBlock(dev)
if err != nil {
return nil, nil, err
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index e05429d41..afea58f65 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -22,6 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -255,6 +257,15 @@ func (fs *filesystem) statTo(stat *linux.Statfs) {
// TODO(b/134676337): Set Statfs.Flags and Statfs.FSID.
}
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ _, inode, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+ return inode.checkPermissions(rp.Credentials(), ats)
+}
+
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
vfsd, inode, err := fs.walk(rp, false)
@@ -453,8 +464,19 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return syserror.EROFS
}
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return nil, err
+ }
+
+ // TODO(b/134676337): Support sockets.
+ return nil, syserror.ECONNREFUSED
+}
+
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
_, _, err := fs.walk(rp, false)
if err != nil {
return nil, err
@@ -463,7 +485,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([
}
// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
_, _, err := fs.walk(rp, false)
if err != nil {
return "", err
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 6962083f5..a39a37318 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -186,7 +186,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt
}
func (in *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
- return vfs.GenericCheckPermissions(creds, ats, in.isDir(), uint16(in.diskInode.Mode()), in.diskInode.UID(), in.diskInode.GID())
+ return vfs.GenericCheckPermissions(creds, ats, in.diskInode.Mode(), in.diskInode.UID(), in.diskInode.GID())
}
// statTo writes the statx fields to the output parameter.
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 4ba76a1e8..99d1e3f8f 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
licenses(["notice"])
@@ -46,6 +46,7 @@ go_library(
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
"//pkg/syserror",
@@ -53,3 +54,13 @@ go_library(
"//pkg/usermem",
],
)
+
+go_test(
+ name = "gofer_test",
+ srcs = ["gofer_test.go"],
+ library = ":gofer",
+ deps = [
+ "//pkg/p9",
+ "//pkg/sentry/contexttest",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 5dbfc6250..49d9f859b 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -56,14 +56,19 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
fd.mu.Lock()
defer fd.mu.Unlock()
+ d := fd.dentry()
if fd.dirents == nil {
- ds, err := fd.dentry().getDirents(ctx)
+ ds, err := d.getDirents(ctx)
if err != nil {
return err
}
fd.dirents = ds
}
+ if d.fs.opts.interop != InteropModeShared {
+ d.touchAtime(fd.vfsfd.Mount())
+ }
+
for fd.off < int64(len(fd.dirents)) {
if err := cb.Handle(fd.dirents[fd.off]); err != nil {
return err
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 5cfb0dc4c..cd744bf5e 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -118,7 +120,7 @@ func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *
if !d.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
afterSymlink:
@@ -313,7 +315,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if err != nil {
return err
}
- if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
if parent.isDeleted() {
@@ -354,7 +356,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if err := create(parent, name); err != nil {
return err
}
- parent.touchCMtime(ctx)
+ if fs.opts.interop != InteropModeShared {
+ parent.touchCMtime()
+ }
delete(parent.negativeChildren, name)
parent.dirents = nil
return nil
@@ -377,7 +381,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
if err != nil {
return err
}
- if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -433,14 +437,19 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
flags := uint32(0)
if dir {
if child != nil && !child.isDir() {
+ vfsObj.AbortDeleteDentry(childVFSD)
return syserror.ENOTDIR
}
flags = linux.AT_REMOVEDIR
} else {
if child != nil && child.isDir() {
+ vfsObj.AbortDeleteDentry(childVFSD)
return syserror.EISDIR
}
if rp.MustBeDir() {
+ if childVFSD != nil {
+ vfsObj.AbortDeleteDentry(childVFSD)
+ }
return syserror.ENOTDIR
}
}
@@ -452,7 +461,10 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
return err
}
if fs.opts.interop != InteropModeShared {
- parent.touchCMtime(ctx)
+ parent.touchCMtime()
+ if dir {
+ parent.decLinks()
+ }
parent.cacheNegativeChildLocked(name)
parent.dirents = nil
}
@@ -499,6 +511,18 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) {
putDentrySlice(*ds)
}
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return d.checkPermissions(creds, ats)
+}
+
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
var ds *[]*dentry
@@ -512,7 +536,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
if !d.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
}
@@ -556,8 +580,13 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string) error {
creds := rp.Credentials()
- _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
- return err
+ if _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)); err != nil {
+ return err
+ }
+ if fs.opts.interop != InteropModeShared {
+ parent.incLinks()
+ }
+ return nil
})
}
@@ -603,7 +632,7 @@ afterTrailingSymlink:
return nil, err
}
// Check for search permission in the parent directory.
- if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
// Determine whether or not we need to create a file.
@@ -640,7 +669,7 @@ afterTrailingSymlink:
// Preconditions: fs.renameMu must be locked.
func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(opts)
- if err := d.checkPermissions(rp.Credentials(), ats, d.isDir()); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
return nil, err
}
mnt := rp.Mount()
@@ -701,7 +730,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked.
func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
- if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
}
if d.isDeleted() {
@@ -780,7 +809,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
d.IncRef() // reference held by child on its parent d
d.vfsd.InsertChild(&child.vfsd, name)
if d.fs.opts.interop != InteropModeShared {
- d.touchCMtime(ctx)
delete(d.negativeChildren, name)
d.dirents = nil
}
@@ -812,6 +840,9 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
childVFSFD = &fd.vfsfd
}
+ if d.fs.opts.interop != InteropModeShared {
+ d.touchCMtime()
+ }
return childVFSFD, nil
}
@@ -863,7 +894,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
return err
}
}
- if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
vfsObj := rp.VirtualFilesystem()
@@ -883,7 +914,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
return syserror.EINVAL
}
if oldParent != newParent {
- if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return err
}
}
@@ -894,7 +925,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
if oldParent != newParent {
- if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
newParent.dirMu.Lock()
@@ -949,6 +980,13 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
oldParent.dirents = nil
delete(newParent.negativeChildren, newName)
newParent.dirents = nil
+ if renamed.isDir() {
+ oldParent.decLinks()
+ newParent.incLinks()
+ }
+ oldParent.touchCMtime()
+ newParent.touchCMtime()
+ renamed.touchCtime()
}
vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, &newParent.vfsd, newName, replacedVFSD)
return nil
@@ -1034,8 +1072,15 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return fs.unlinkAt(ctx, rp, false /* dir */)
}
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+//
+// TODO(gvisor.dev/issue/1476): Implement BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+ return nil, syserror.ECONNREFUSED
+}
+
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(&ds)
@@ -1043,11 +1088,11 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([
if err != nil {
return nil, err
}
- return d.listxattr(ctx)
+ return d.listxattr(ctx, rp.Credentials(), size)
}
// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(&ds)
@@ -1055,7 +1100,7 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, nam
if err != nil {
return "", err
}
- return d.getxattr(ctx, name)
+ return d.getxattr(ctx, rp.Credentials(), &opts)
}
// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
@@ -1067,7 +1112,7 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
if err != nil {
return err
}
- return d.setxattr(ctx, &opts)
+ return d.setxattr(ctx, rp.Credentials(), &opts)
}
// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
@@ -1079,7 +1124,7 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath,
if err != nil {
return err
}
- return d.removexattr(ctx, name)
+ return d.removexattr(ctx, rp.Credentials(), name)
}
// PrependPath implements vfs.FilesystemImpl.PrependPath.
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index d00850e25..2485cdb53 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -34,6 +34,7 @@ package gofer
import (
"fmt"
"strconv"
+ "strings"
"sync"
"sync/atomic"
"syscall"
@@ -44,6 +45,7 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -72,6 +74,9 @@ type filesystem struct {
// client is the client used by this filesystem. client is immutable.
client *p9.Client
+ // clock is a realtime clock used to set timestamps in file operations.
+ clock ktime.Clock
+
// uid and gid are the effective KUID and KGID of the filesystem's creator,
// and are used as the owner and group for files that don't specify one.
// uid and gid are immutable.
@@ -199,6 +204,11 @@ const (
InteropModeShared
)
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
@@ -371,10 +381,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
uid: creds.EffectiveKUID,
gid: creds.EffectiveKGID,
client: client,
+ clock: ktime.RealtimeClockFromContext(ctx),
dentries: make(map[*dentry]struct{}),
specialFileFDs: make(map[*specialFileFD]struct{}),
}
- fs.vfsfs.Init(vfsObj, fs)
+ fs.vfsfs.Init(vfsObj, &fstype, fs)
// Construct the root dentry.
root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
@@ -434,7 +445,8 @@ type dentry struct {
// refs is the reference count. Each dentry holds a reference on its
// parent, even if disowned. refs is accessed using atomic memory
- // operations.
+ // operations. When refs reaches 0, the dentry may be added to the cache or
+ // destroyed. If refs==-1 the dentry has already been destroyed.
refs int64
// fs is the owning filesystem. fs is immutable.
@@ -485,6 +497,11 @@ type dentry struct {
// locked to mutate it).
size uint64
+ // nlink counts the number of hard links to this dentry. It's updated and
+ // accessed using atomic operations. It's not protected by metadataMu like the
+ // other metadata fields.
+ nlink uint32
+
mapsMu sync.Mutex
// If this dentry represents a regular file, mappings tracks mappings of
@@ -604,6 +621,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
if mask.BTime {
d.btime = dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds)
}
+ if mask.NLink {
+ d.nlink = uint32(attr.NLink)
+ }
d.vfsd.Init(d)
fs.syncMu.Lock()
@@ -645,6 +665,9 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
if mask.BTime {
atomic.StoreInt64(&d.btime, dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds))
}
+ if mask.NLink {
+ atomic.StoreUint32(&d.nlink, uint32(attr.NLink))
+ }
if mask.Size {
d.dataMu.Lock()
atomic.StoreUint64(&d.size, attr.Size)
@@ -687,10 +710,7 @@ func (d *dentry) fileType() uint32 {
func (d *dentry) statTo(stat *linux.Statx) {
stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME
stat.Blksize = atomic.LoadUint32(&d.blockSize)
- stat.Nlink = 1
- if d.isDir() {
- stat.Nlink = 2
- }
+ stat.Nlink = atomic.LoadUint32(&d.nlink)
stat.UID = atomic.LoadUint32(&d.uid)
stat.GID = atomic.LoadUint32(&d.gid)
stat.Mode = uint16(atomic.LoadUint32(&d.mode))
@@ -703,7 +723,7 @@ func (d *dentry) statTo(stat *linux.Statx) {
stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime))
stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime))
stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime))
- // TODO(jamieliu): device number
+ // TODO(gvisor.dev/issue/1198): device number
}
func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error {
@@ -713,7 +733,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 {
return syserror.EPERM
}
- if err := vfs.CheckSetStat(creds, stat, uint16(atomic.LoadUint32(&d.mode))&^linux.S_IFMT, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
if err := mnt.CheckBeginWrite(); err != nil {
@@ -765,10 +786,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
// data, so there's no cache to truncate either.)
return nil
}
- now, haveNow := nowFromContext(ctx)
- if !haveNow {
- ctx.Warningf("gofer.dentry.setStat: current time not available")
- }
+ now := d.fs.clock.Now().Nanoseconds()
if stat.Mask&linux.STATX_MODE != 0 {
atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
}
@@ -780,25 +798,19 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
}
if setLocalAtime {
if stat.Atime.Nsec == linux.UTIME_NOW {
- if haveNow {
- atomic.StoreInt64(&d.atime, now)
- }
+ atomic.StoreInt64(&d.atime, now)
} else {
atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime))
}
}
if setLocalMtime {
if stat.Mtime.Nsec == linux.UTIME_NOW {
- if haveNow {
- atomic.StoreInt64(&d.mtime, now)
- }
+ atomic.StoreInt64(&d.mtime, now)
} else {
atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime))
}
}
- if haveNow {
- atomic.StoreInt64(&d.ctime, now)
- }
+ atomic.StoreInt64(&d.ctime, now)
if stat.Mask&linux.STATX_SIZE != 0 {
d.dataMu.Lock()
oldSize := d.size
@@ -835,8 +847,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
return nil
}
-func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error {
- return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&d.mode))&0777, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
+func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
}
// IncRef implements vfs.DentryImpl.IncRef.
@@ -850,7 +862,7 @@ func (d *dentry) IncRef() {
func (d *dentry) TryIncRef() bool {
for {
refs := atomic.LoadInt64(&d.refs)
- if refs == 0 {
+ if refs <= 0 {
return false
}
if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
@@ -873,13 +885,20 @@ func (d *dentry) DecRef() {
// checkCachingLocked should be called after d's reference count becomes 0 or it
// becomes disowned.
//
+// It may be called on a destroyed dentry. For example,
+// renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times
+// for the same dentry when the dentry is visited more than once in the same
+// operation. One of the calls may destroy the dentry, so subsequent calls will
+// do nothing.
+//
// Preconditions: d.fs.renameMu must be locked for writing.
func (d *dentry) checkCachingLocked() {
// Dentries with a non-zero reference count must be retained. (The only way
// to obtain a reference on a dentry with zero references is via path
// resolution, which requires renameMu, so if d.refs is zero then it will
// remain zero while we hold renameMu for writing.)
- if atomic.LoadInt64(&d.refs) != 0 {
+ refs := atomic.LoadInt64(&d.refs)
+ if refs > 0 {
if d.cached {
d.fs.cachedDentries.Remove(d)
d.fs.cachedDentriesLen--
@@ -887,6 +906,10 @@ func (d *dentry) checkCachingLocked() {
}
return
}
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ return
+ }
// Non-child dentries with zero references are no longer reachable by path
// resolution and should be dropped immediately.
if d.vfsd.Parent() == nil || d.vfsd.IsDisowned() {
@@ -939,9 +962,22 @@ func (d *dentry) checkCachingLocked() {
}
}
+// destroyLocked destroys the dentry. It may flushes dirty pages from cache,
+// close p9 file and remove reference on parent dentry.
+//
// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. d is
// not a child dentry.
func (d *dentry) destroyLocked() {
+ switch atomic.LoadInt64(&d.refs) {
+ case 0:
+ // Mark the dentry destroyed.
+ atomic.StoreInt64(&d.refs, -1)
+ case -1:
+ panic("dentry.destroyLocked() called on already destroyed dentry")
+ default:
+ panic("dentry.destroyLocked() called with references on the dentry")
+ }
+
ctx := context.Background()
d.handleMu.Lock()
if !d.handle.file.isNil() {
@@ -961,7 +997,10 @@ func (d *dentry) destroyLocked() {
d.handle.close(ctx)
}
d.handleMu.Unlock()
- d.file.close(ctx)
+ if !d.file.isNil() {
+ d.file.close(ctx)
+ d.file = p9file{}
+ }
// Remove d from the set of all dentries.
d.fs.syncMu.Lock()
delete(d.fs.dentries, d)
@@ -986,21 +1025,50 @@ func (d *dentry) setDeleted() {
atomic.StoreUint32(&d.deleted, 1)
}
-func (d *dentry) listxattr(ctx context.Context) ([]string, error) {
- return nil, syserror.ENOTSUP
+// We only support xattrs prefixed with "user." (see b/148380782). Currently,
+// there is no need to expose any other xattrs through a gofer.
+func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
+ xattrMap, err := d.file.listXattr(ctx, size)
+ if err != nil {
+ return nil, err
+ }
+ xattrs := make([]string, 0, len(xattrMap))
+ for x := range xattrMap {
+ if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) {
+ xattrs = append(xattrs, x)
+ }
+ }
+ return xattrs, nil
}
-func (d *dentry) getxattr(ctx context.Context, name string) (string, error) {
- // TODO(jamieliu): add vfs.GetxattrOptions.Size
- return d.file.getXattr(ctx, name, linux.XATTR_SIZE_MAX)
+func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) {
+ if err := d.checkPermissions(creds, vfs.MayRead); err != nil {
+ return "", err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return "", syserror.EOPNOTSUPP
+ }
+ return d.file.getXattr(ctx, opts.Name, opts.Size)
}
-func (d *dentry) setxattr(ctx context.Context, opts *vfs.SetxattrOptions) error {
+func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetxattrOptions) error {
+ if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags)
}
-func (d *dentry) removexattr(ctx context.Context, name string) error {
- return syserror.ENOTSUP
+func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name string) error {
+ if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ return d.file.removeXattr(ctx, name)
}
// Preconditions: d.isRegularFile() || d.isDirectory().
@@ -1045,13 +1113,13 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// using the old file descriptor, preventing us from safely
// closing it. We could handle this by invalidating existing
// memmap.Translations, but this is expensive. Instead, use
- // dup2() to make the old file descriptor refer to the new file
+ // dup3 to make the old file descriptor refer to the new file
// description, then close the new file descriptor (which is no
// longer needed). Racing callers may use the old or new file
// description, but this doesn't matter since they refer to the
// same file (unless d.fs.opts.overlayfsStaleRead is true,
// which we handle separately).
- if err := syscall.Dup2(int(h.fd), int(d.handle.fd)); err != nil {
+ if err := syscall.Dup3(int(h.fd), int(d.handle.fd), syscall.O_CLOEXEC); err != nil {
d.handleMu.Unlock()
ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err)
h.close(ctx)
@@ -1094,6 +1162,26 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
return nil
}
+// incLinks increments link count.
+//
+// Preconditions: d.nlink != 0 && d.nlink < math.MaxUint32.
+func (d *dentry) incLinks() {
+ v := atomic.AddUint32(&d.nlink, 1)
+ if v < 2 {
+ panic(fmt.Sprintf("dentry.nlink is invalid (was 0 or overflowed): %d", v))
+ }
+}
+
+// decLinks decrements link count.
+//
+// Preconditions: d.nlink > 1.
+func (d *dentry) decLinks() {
+ v := atomic.AddUint32(&d.nlink, ^uint32(0))
+ if v == 0 {
+ panic(fmt.Sprintf("dentry.nlink must be greater than 0: %d", v))
+ }
+}
+
// fileDescription is embedded by gofer implementations of
// vfs.FileDescriptionImpl.
type fileDescription struct {
@@ -1112,7 +1200,8 @@ func (fd *fileDescription) dentry() *dentry {
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
d := fd.dentry()
- if d.fs.opts.interop == InteropModeShared && opts.Mask&(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME|linux.STATX_SIZE|linux.STATX_BLOCKS|linux.STATX_BTIME) != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
+ const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME)
+ if d.fs.opts.interop == InteropModeShared && opts.Mask&(validMask) != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
// TODO(jamieliu): Use specialFileFD.handle.file for the getattr if
// available?
if err := d.updateFromGetattr(ctx); err != nil {
@@ -1130,21 +1219,21 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
}
// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
-func (fd *fileDescription) Listxattr(ctx context.Context) ([]string, error) {
- return fd.dentry().listxattr(ctx)
+func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ return fd.dentry().listxattr(ctx, auth.CredentialsFromContext(ctx), size)
}
// Getxattr implements vfs.FileDescriptionImpl.Getxattr.
-func (fd *fileDescription) Getxattr(ctx context.Context, name string) (string, error) {
- return fd.dentry().getxattr(ctx, name)
+func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) {
+ return fd.dentry().getxattr(ctx, auth.CredentialsFromContext(ctx), &opts)
}
// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
- return fd.dentry().setxattr(ctx, &opts)
+ return fd.dentry().setxattr(ctx, auth.CredentialsFromContext(ctx), &opts)
}
// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
- return fd.dentry().removexattr(ctx, name)
+ return fd.dentry().removexattr(ctx, auth.CredentialsFromContext(ctx), name)
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
new file mode 100644
index 000000000..82bc239db
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -0,0 +1,64 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync/atomic"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+)
+
+func TestDestroyIdempotent(t *testing.T) {
+ fs := filesystem{
+ dentries: make(map[*dentry]struct{}),
+ opts: filesystemOptions{
+ // Test relies on no dentry being held in the cache.
+ maxCachedDentries: 0,
+ },
+ }
+
+ ctx := contexttest.Context(t)
+ attr := &p9.Attr{
+ Mode: p9.ModeRegular,
+ }
+ mask := p9.AttrMask{
+ Mode: true,
+ Size: true,
+ }
+ parent, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr)
+ if err != nil {
+ t.Fatalf("fs.newDentry(): %v", err)
+ }
+
+ child, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr)
+ if err != nil {
+ t.Fatalf("fs.newDentry(): %v", err)
+ }
+ parent.IncRef() // reference held by child on its parent.
+ parent.vfsd.InsertChild(&child.vfsd, "child")
+
+ child.checkCachingLocked()
+ if got := atomic.LoadInt64(&child.refs); got != -1 {
+ t.Fatalf("child.refs=%d, want: -1", got)
+ }
+ // Parent will also be destroyed when child reference is removed.
+ if got := atomic.LoadInt64(&parent.refs); got != -1 {
+ t.Fatalf("parent.refs=%d, want: -1", got)
+ }
+ child.checkCachingLocked()
+ child.checkCachingLocked()
+}
diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go
index 755ac2985..87f0b877f 100644
--- a/pkg/sentry/fsimpl/gofer/p9file.go
+++ b/pkg/sentry/fsimpl/gofer/p9file.go
@@ -85,6 +85,13 @@ func (f p9file) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAt
return err
}
+func (f p9file) listXattr(ctx context.Context, size uint64) (map[string]struct{}, error) {
+ ctx.UninterruptibleSleepStart(false)
+ xattrs, err := f.file.ListXattr(size)
+ ctx.UninterruptibleSleepFinish(false)
+ return xattrs, err
+}
+
func (f p9file) getXattr(ctx context.Context, name string, size uint64) (string, error) {
ctx.UninterruptibleSleepStart(false)
val, err := f.file.GetXattr(name, size)
@@ -99,6 +106,13 @@ func (f p9file) setXattr(ctx context.Context, name, value string, flags uint32)
return err
}
+func (f p9file) removeXattr(ctx context.Context, name string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.RemoveXattr(name)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
func (f p9file) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error {
ctx.UninterruptibleSleepStart(false)
err := f.file.Allocate(mode, offset, length)
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 54c1031a7..857f7c74e 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -104,7 +104,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
putDentryReadWriter(rw)
if d.fs.opts.interop != InteropModeShared {
// Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed().
- d.touchAtime(ctx, fd.vfsfd.Mount())
+ d.touchAtime(fd.vfsfd.Mount())
}
return n, err
}
@@ -126,6 +126,11 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
if opts.Flags != 0 {
return 0, syserror.EOPNOTSUPP
}
+ limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(limit)
d := fd.dentry()
d.metadataMu.Lock()
@@ -134,10 +139,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Compare Linux's mm/filemap.c:__generic_file_write_iter() =>
// file_update_time(). This is d.touchCMtime(), but without locking
// d.metadataMu (recursively).
- if now, ok := nowFromContext(ctx); ok {
- atomic.StoreInt64(&d.mtime, now)
- atomic.StoreInt64(&d.ctime, now)
- }
+ d.touchCMtimeLocked()
}
if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
// Write dirty cached pages that will be touched by the write back to
@@ -361,8 +363,15 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro
rw.d.handleMu.RLock()
if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, srcs, rw.off)
- rw.d.handleMu.RUnlock()
rw.off += n
+ rw.d.dataMu.Lock()
+ if rw.off > rw.d.size {
+ atomic.StoreUint64(&rw.d.size, rw.off)
+ // The remote file's size will implicitly be extended to the correct
+ // value when we write back to it.
+ }
+ rw.d.dataMu.Unlock()
+ rw.d.handleMu.RUnlock()
return n, err
}
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 08c691c47..507e0e276 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -76,7 +76,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
// hold here since specialFileFD doesn't client-cache data. Just buffer the
// read instead.
if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
- d.touchAtime(ctx, fd.vfsfd.Mount())
+ d.touchAtime(fd.vfsfd.Mount())
}
buf := make([]byte, dst.NumBytes())
n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
@@ -107,9 +107,17 @@ func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
return 0, syserror.EOPNOTSUPP
}
+ if fd.dentry().fileType() == linux.S_IFREG {
+ limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(limit)
+ }
+
// Do a buffered write. See rationale in PRead.
if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
- d.touchCMtime(ctx)
+ d.touchCMtime()
}
buf := make([]byte, src.NumBytes())
// Don't do partial writes if we get a partial read from src.
diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go
index adf43be60..2ec819f86 100644
--- a/pkg/sentry/fsimpl/gofer/symlink.go
+++ b/pkg/sentry/fsimpl/gofer/symlink.go
@@ -27,7 +27,7 @@ func (d *dentry) isSymlink() bool {
// Precondition: d.isSymlink().
func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
if d.fs.opts.interop != InteropModeShared {
- d.touchAtime(ctx, mnt)
+ d.touchAtime(mnt)
d.dataMu.Lock()
if d.haveTarget {
target := d.target
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
index 7598ec6a8..2608e7e1d 100644
--- a/pkg/sentry/fsimpl/gofer/time.go
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -18,8 +18,6 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/context"
- ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -38,23 +36,12 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp {
}
}
-func nowFromContext(ctx context.Context) (int64, bool) {
- if clock := ktime.RealtimeClockFromContext(ctx); clock != nil {
- return clock.Now().Nanoseconds(), true
- }
- return 0, false
-}
-
// Preconditions: fs.interop != InteropModeShared.
-func (d *dentry) touchAtime(ctx context.Context, mnt *vfs.Mount) {
+func (d *dentry) touchAtime(mnt *vfs.Mount) {
if err := mnt.CheckBeginWrite(); err != nil {
return
}
- now, ok := nowFromContext(ctx)
- if !ok {
- mnt.EndWrite()
- return
- }
+ now := d.fs.clock.Now().Nanoseconds()
d.metadataMu.Lock()
atomic.StoreInt64(&d.atime, now)
d.metadataMu.Unlock()
@@ -63,13 +50,25 @@ func (d *dentry) touchAtime(ctx context.Context, mnt *vfs.Mount) {
// Preconditions: fs.interop != InteropModeShared. The caller has successfully
// called vfs.Mount.CheckBeginWrite().
-func (d *dentry) touchCMtime(ctx context.Context) {
- now, ok := nowFromContext(ctx)
- if !ok {
- return
- }
+func (d *dentry) touchCtime() {
+ now := d.fs.clock.Now().Nanoseconds()
+ d.metadataMu.Lock()
+ atomic.StoreInt64(&d.ctime, now)
+ d.metadataMu.Unlock()
+}
+
+// Preconditions: fs.interop != InteropModeShared. The caller has successfully
+// called vfs.Mount.CheckBeginWrite().
+func (d *dentry) touchCMtime() {
+ now := d.fs.clock.Now().Nanoseconds()
d.metadataMu.Lock()
atomic.StoreInt64(&d.mtime, now)
atomic.StoreInt64(&d.ctime, now)
d.metadataMu.Unlock()
}
+
+func (d *dentry) touchCMtimeLocked() {
+ now := d.fs.clock.Now().Nanoseconds()
+ atomic.StoreInt64(&d.mtime, now)
+ atomic.StoreInt64(&d.ctime, now)
+}
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
new file mode 100644
index 000000000..82e1fb74b
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -0,0 +1,34 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "host",
+ srcs = [
+ "host.go",
+ "ioctl_unsafe.go",
+ "tty.go",
+ "util.go",
+ "util_unsafe.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/unimpl",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
new file mode 100644
index 000000000..97fa7f7ab
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -0,0 +1,667 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package host provides a filesystem implementation for host files imported as
+// file descriptors.
+package host
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// filesystemType implements vfs.FilesystemType.
+type filesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (filesystemType) GetFilesystem(context.Context, *vfs.VirtualFilesystem, *auth.Credentials, string, vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ panic("cannot instaniate a host filesystem")
+}
+
+// Name implements FilesystemType.Name.
+func (filesystemType) Name() string {
+ return "none"
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+}
+
+// NewMount returns a new disconnected mount in vfsObj that may be passed to ImportFD.
+func NewMount(vfsObj *vfs.VirtualFilesystem) (*vfs.Mount, error) {
+ fs := &filesystem{}
+ fs.Init(vfsObj, &filesystemType{})
+ vfsfs := fs.VFSFilesystem()
+ // NewDisconnectedMount will take an additional reference on vfsfs.
+ defer vfsfs.DecRef()
+ return vfsObj.NewDisconnectedMount(vfsfs, nil, &vfs.MountOptions{})
+}
+
+// ImportFD sets up and returns a vfs.FileDescription from a donated fd.
+func ImportFD(mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) {
+ fs, ok := mnt.Filesystem().Impl().(*kernfs.Filesystem)
+ if !ok {
+ return nil, fmt.Errorf("can't import host FDs into filesystems of type %T", mnt.Filesystem().Impl())
+ }
+
+ // Retrieve metadata.
+ var s unix.Stat_t
+ if err := unix.Fstat(hostFD, &s); err != nil {
+ return nil, err
+ }
+
+ fileMode := linux.FileMode(s.Mode)
+ fileType := fileMode.FileType()
+
+ // Determine if hostFD is seekable. If not, this syscall will return ESPIPE
+ // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character
+ // devices.
+ _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR)
+ seekable := err != syserror.ESPIPE
+
+ i := &inode{
+ hostFD: hostFD,
+ seekable: seekable,
+ isTTY: isTTY,
+ canMap: canMap(uint32(fileType)),
+ ino: fs.NextIno(),
+ mode: fileMode,
+ // For simplicity, set offset to 0. Technically, we should use the existing
+ // offset on the host if the file is seekable.
+ offset: 0,
+ }
+
+ // Non-seekable files can't be memory mapped, assert this.
+ if !i.seekable && i.canMap {
+ panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped")
+ }
+
+ d := &kernfs.Dentry{}
+ d.Init(i)
+ // i.open will take a reference on d.
+ defer d.DecRef()
+
+ return i.open(d.VFSDentry(), mnt)
+}
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ // When the reference count reaches zero, the host fd is closed.
+ refs.AtomicRefCount
+
+ // hostFD contains the host fd that this file was originally created from,
+ // which must be available at time of restore.
+ //
+ // This field is initialized at creation time and is immutable.
+ hostFD int
+
+ // seekable is false if the host fd points to a file representing a stream,
+ // e.g. a socket or a pipe. Such files are not seekable and can return
+ // EWOULDBLOCK for I/O operations.
+ //
+ // This field is initialized at creation time and is immutable.
+ seekable bool
+
+ // isTTY is true if this file represents a TTY.
+ //
+ // This field is initialized at creation time and is immutable.
+ isTTY bool
+
+ // canMap specifies whether we allow the file to be memory mapped.
+ //
+ // This field is initialized at creation time and is immutable.
+ canMap bool
+
+ // ino is an inode number unique within this filesystem.
+ //
+ // This field is initialized at creation time and is immutable.
+ ino uint64
+
+ // modeMu protects mode.
+ modeMu sync.Mutex
+
+ // mode is a cached version of the file mode on the host. Note that it may
+ // become out of date if the mode is changed on the host, e.g. with chmod.
+ //
+ // Generally, it is better to retrieve the mode from the host through an
+ // fstat syscall. We only use this value in inode.Mode(), which cannot
+ // return an error, if the syscall to host fails.
+ //
+ // FIXME(b/152294168): Plumb error into Inode.Mode() return value so we
+ // can get rid of this.
+ mode linux.FileMode
+
+ // offsetMu protects offset.
+ offsetMu sync.Mutex
+
+ // offset specifies the current file offset.
+ offset int64
+}
+
+// Note that these flags may become out of date, since they can be modified
+// on the host, e.g. with fcntl.
+func fileFlagsFromHostFD(fd int) (int, error) {
+ flags, err := unix.FcntlInt(uintptr(fd), syscall.F_GETFL, 0)
+ if err != nil {
+ log.Warningf("Failed to get file flags for donated FD %d: %v", fd, err)
+ return 0, err
+ }
+ // TODO(gvisor.dev/issue/1672): implement behavior corresponding to these allowed flags.
+ flags &= syscall.O_ACCMODE | syscall.O_DIRECT | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND
+ return flags, nil
+}
+
+// CheckPermissions implements kernfs.Inode.
+func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ mode, uid, gid, err := i.getPermissions()
+ if err != nil {
+ return err
+ }
+ return vfs.GenericCheckPermissions(creds, ats, mode, uid, gid)
+}
+
+// Mode implements kernfs.Inode.
+func (i *inode) Mode() linux.FileMode {
+ mode, _, _, err := i.getPermissions()
+ if err != nil {
+ return i.mode
+ }
+
+ return linux.FileMode(mode)
+}
+
+func (i *inode) getPermissions() (linux.FileMode, auth.KUID, auth.KGID, error) {
+ // Retrieve metadata.
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ return 0, 0, 0, err
+ }
+
+ // Update cached mode.
+ i.modeMu.Lock()
+ i.mode = linux.FileMode(s.Mode)
+ i.modeMu.Unlock()
+ return linux.FileMode(s.Mode), auth.KUID(s.Uid), auth.KGID(s.Gid), nil
+}
+
+// Stat implements kernfs.Inode.
+func (i *inode) Stat(_ *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ if opts.Mask&linux.STATX__RESERVED != 0 {
+ return linux.Statx{}, syserror.EINVAL
+ }
+ if opts.Sync&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE {
+ return linux.Statx{}, syserror.EINVAL
+ }
+
+ // Limit our host call only to known flags.
+ mask := opts.Mask & linux.STATX_ALL
+ var s unix.Statx_t
+ err := unix.Statx(i.hostFD, "", int(unix.AT_EMPTY_PATH|opts.Sync), int(mask), &s)
+ // Fallback to fstat(2), if statx(2) is not supported on the host.
+ //
+ // TODO(b/151263641): Remove fallback.
+ if err == syserror.ENOSYS {
+ return i.fstat(opts)
+ } else if err != nil {
+ return linux.Statx{}, err
+ }
+
+ ls := linux.Statx{Mask: mask}
+ // Unconditionally fill blksize, attributes, and device numbers, as indicated
+ // by /include/uapi/linux/stat.h.
+ //
+ // RdevMajor/RdevMinor are left as zero, so as not to expose host device
+ // numbers.
+ //
+ // TODO(gvisor.dev/issue/1672): Use kernfs-specific, internally defined
+ // device numbers. If we use the device number from the host, it may collide
+ // with another sentry-internal device number. We handle device/inode
+ // numbers without relying on the host to prevent collisions.
+ ls.Blksize = s.Blksize
+ ls.Attributes = s.Attributes
+ ls.AttributesMask = s.Attributes_mask
+
+ if mask&linux.STATX_TYPE != 0 {
+ ls.Mode |= s.Mode & linux.S_IFMT
+ }
+ if mask&linux.STATX_MODE != 0 {
+ ls.Mode |= s.Mode &^ linux.S_IFMT
+ }
+ if mask&linux.STATX_NLINK != 0 {
+ ls.Nlink = s.Nlink
+ }
+ if mask&linux.STATX_UID != 0 {
+ ls.UID = s.Uid
+ }
+ if mask&linux.STATX_GID != 0 {
+ ls.GID = s.Gid
+ }
+ if mask&linux.STATX_ATIME != 0 {
+ ls.Atime = unixToLinuxStatxTimestamp(s.Atime)
+ }
+ if mask&linux.STATX_BTIME != 0 {
+ ls.Btime = unixToLinuxStatxTimestamp(s.Btime)
+ }
+ if mask&linux.STATX_CTIME != 0 {
+ ls.Ctime = unixToLinuxStatxTimestamp(s.Ctime)
+ }
+ if mask&linux.STATX_MTIME != 0 {
+ ls.Mtime = unixToLinuxStatxTimestamp(s.Mtime)
+ }
+ if mask&linux.STATX_SIZE != 0 {
+ ls.Size = s.Size
+ }
+ if mask&linux.STATX_BLOCKS != 0 {
+ ls.Blocks = s.Blocks
+ }
+
+ // Use our own internal inode number.
+ if mask&linux.STATX_INO != 0 {
+ ls.Ino = i.ino
+ }
+
+ // Update cached mode.
+ if (mask&linux.STATX_TYPE != 0) && (mask&linux.STATX_MODE != 0) {
+ i.modeMu.Lock()
+ i.mode = linux.FileMode(s.Mode)
+ i.modeMu.Unlock()
+ }
+ return ls, nil
+}
+
+// fstat is a best-effort fallback for inode.Stat() if the host does not
+// support statx(2).
+//
+// We ignore the mask and sync flags in opts and simply supply
+// STATX_BASIC_STATS, as fstat(2) itself does not allow the specification
+// of a mask or sync flags. fstat(2) does not provide any metadata
+// equivalent to Statx.Attributes, Statx.AttributesMask, or Statx.Btime, so
+// those fields remain empty.
+func (i *inode) fstat(opts vfs.StatOptions) (linux.Statx, error) {
+ var s unix.Stat_t
+ if err := unix.Fstat(i.hostFD, &s); err != nil {
+ return linux.Statx{}, err
+ }
+
+ // Note that rdev numbers are left as 0; do not expose host device numbers.
+ ls := linux.Statx{
+ Mask: linux.STATX_BASIC_STATS,
+ Blksize: uint32(s.Blksize),
+ Nlink: uint32(s.Nlink),
+ UID: s.Uid,
+ GID: s.Gid,
+ Mode: uint16(s.Mode),
+ Size: uint64(s.Size),
+ Blocks: uint64(s.Blocks),
+ Atime: timespecToStatxTimestamp(s.Atim),
+ Ctime: timespecToStatxTimestamp(s.Ctim),
+ Mtime: timespecToStatxTimestamp(s.Mtim),
+ }
+
+ // Use our own internal inode number.
+ //
+ // TODO(gvisor.dev/issue/1672): Use a kernfs-specific device number as well.
+ // If we use the device number from the host, it may collide with another
+ // sentry-internal device number. We handle device/inode numbers without
+ // relying on the host to prevent collisions.
+ ls.Ino = i.ino
+
+ return ls, nil
+}
+
+// SetStat implements kernfs.Inode.
+func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ s := opts.Stat
+
+ m := s.Mask
+ if m == 0 {
+ return nil
+ }
+ if m&^(linux.STATX_MODE|linux.STATX_SIZE|linux.STATX_ATIME|linux.STATX_MTIME) != 0 {
+ return syserror.EPERM
+ }
+ mode, uid, gid, err := i.getPermissions()
+ if err != nil {
+ return err
+ }
+ if err := vfs.CheckSetStat(ctx, creds, &s, mode.Permissions(), uid, gid); err != nil {
+ return err
+ }
+
+ if m&linux.STATX_MODE != 0 {
+ if err := syscall.Fchmod(i.hostFD, uint32(s.Mode)); err != nil {
+ return err
+ }
+ i.modeMu.Lock()
+ i.mode = linux.FileMode(s.Mode)
+ i.modeMu.Unlock()
+ }
+ if m&linux.STATX_SIZE != 0 {
+ if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
+ return err
+ }
+ }
+ if m&(linux.STATX_ATIME|linux.STATX_MTIME) != 0 {
+ ts := [2]syscall.Timespec{
+ toTimespec(s.Atime, m&linux.STATX_ATIME == 0),
+ toTimespec(s.Mtime, m&linux.STATX_MTIME == 0),
+ }
+ if err := setTimestamps(i.hostFD, &ts); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// DecRef implements kernfs.Inode.
+func (i *inode) DecRef() {
+ i.AtomicRefCount.DecRefWithDestructor(i.Destroy)
+}
+
+// Destroy implements kernfs.Inode.
+func (i *inode) Destroy() {
+ if err := unix.Close(i.hostFD); err != nil {
+ log.Warningf("failed to close host fd %d: %v", i.hostFD, err)
+ }
+}
+
+// Open implements kernfs.Inode.
+func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return i.open(vfsd, rp.Mount())
+}
+
+func (i *inode) open(d *vfs.Dentry, mnt *vfs.Mount) (*vfs.FileDescription, error) {
+ mode, _, _, err := i.getPermissions()
+ if err != nil {
+ return nil, err
+ }
+ fileType := mode.FileType()
+ if fileType == syscall.S_IFSOCK {
+ if i.isTTY {
+ return nil, errors.New("cannot use host socket as TTY")
+ }
+ // TODO(gvisor.dev/issue/1672): support importing sockets.
+ return nil, errors.New("importing host sockets not supported")
+ }
+
+ // TODO(gvisor.dev/issue/1672): Whitelist specific file types here, so that
+ // we don't allow importing arbitrary file types without proper support.
+ var (
+ vfsfd *vfs.FileDescription
+ fdImpl vfs.FileDescriptionImpl
+ )
+ if i.isTTY {
+ fd := &ttyFD{
+ fileDescription: fileDescription{inode: i},
+ termios: linux.DefaultSlaveTermios,
+ }
+ vfsfd = &fd.vfsfd
+ fdImpl = fd
+ } else {
+ // For simplicity, set offset to 0. Technically, we should
+ // only set to 0 on files that are not seekable (sockets, pipes, etc.),
+ // and use the offset from the host fd otherwise.
+ fd := &fileDescription{inode: i}
+ vfsfd = &fd.vfsfd
+ fdImpl = fd
+ }
+
+ flags, err := fileFlagsFromHostFD(i.hostFD)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := vfsfd.Init(fdImpl, uint32(flags), mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
+}
+
+// fileDescription is embedded by host fd implementations of FileDescriptionImpl.
+//
+// TODO(gvisor.dev/issue/1672): Implement Waitable interface.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+
+ // inode is vfsfd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode), but
+ // cached to reduce indirections and casting. fileDescription does not hold
+ // a reference on the inode through the inode field (since one is already
+ // held via the Dentry).
+ //
+ // inode is immutable after fileDescription creation.
+ inode *inode
+}
+
+// SetStat implements vfs.FileDescriptionImpl.
+func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ return f.inode.SetStat(ctx, nil, creds, opts)
+}
+
+// Stat implements vfs.FileDescriptionImpl.
+func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ return f.inode.Stat(nil, opts)
+}
+
+// Release implements vfs.FileDescriptionImpl.
+func (f *fileDescription) Release() {
+ // noop
+}
+
+// PRead implements FileDescriptionImpl.
+func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ return 0, syserror.ESPIPE
+ }
+
+ return readFromHostFD(ctx, i.hostFD, dst, offset, opts.Flags)
+}
+
+// Read implements FileDescriptionImpl.
+func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags)
+ if isBlockError(err) {
+ // If we got any data at all, return it as a "completed" partial read
+ // rather than retrying until complete.
+ if n != 0 {
+ err = nil
+ } else {
+ err = syserror.ErrWouldBlock
+ }
+ }
+ return n, err
+ }
+ // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so.
+ i.offsetMu.Lock()
+ n, err := readFromHostFD(ctx, i.hostFD, dst, i.offset, opts.Flags)
+ i.offset += n
+ i.offsetMu.Unlock()
+ return n, err
+}
+
+func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) {
+ // TODO(gvisor.dev/issue/1672): Support select preadv2 flags.
+ if flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ var reader safemem.Reader
+ if offset == -1 {
+ reader = safemem.FromIOReader{fd.NewReadWriter(hostFD)}
+ } else {
+ reader = safemem.FromVecReaderFunc{
+ func(srcs [][]byte) (int64, error) {
+ n, err := unix.Preadv(hostFD, srcs, offset)
+ return int64(n), err
+ },
+ }
+ }
+ n, err := dst.CopyOutFrom(ctx, reader)
+ return int64(n), err
+}
+
+// PWrite implements FileDescriptionImpl.
+func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ return 0, syserror.ESPIPE
+ }
+
+ return writeToHostFD(ctx, i.hostFD, src, offset, opts.Flags)
+}
+
+// Write implements FileDescriptionImpl.
+func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ n, err := writeToHostFD(ctx, i.hostFD, src, -1, opts.Flags)
+ if isBlockError(err) {
+ err = syserror.ErrWouldBlock
+ }
+ return n, err
+ }
+ // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so.
+ // TODO(gvisor.dev/issue/1672): Write to end of file and update offset if O_APPEND is set on this file.
+ i.offsetMu.Lock()
+ n, err := writeToHostFD(ctx, i.hostFD, src, i.offset, opts.Flags)
+ i.offset += n
+ i.offsetMu.Unlock()
+ return n, err
+}
+
+func writeToHostFD(ctx context.Context, hostFD int, src usermem.IOSequence, offset int64, flags uint32) (int64, error) {
+ // TODO(gvisor.dev/issue/1672): Support select pwritev2 flags.
+ if flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ var writer safemem.Writer
+ if offset == -1 {
+ writer = safemem.FromIOWriter{fd.NewReadWriter(hostFD)}
+ } else {
+ writer = safemem.FromVecWriterFunc{
+ func(srcs [][]byte) (int64, error) {
+ n, err := unix.Pwritev(hostFD, srcs, offset)
+ return int64(n), err
+ },
+ }
+ }
+ n, err := src.CopyInTo(ctx, writer)
+ return int64(n), err
+}
+
+// Seek implements FileDescriptionImpl.
+//
+// Note that we do not support seeking on directories, since we do not even
+// allow directory fds to be imported at all.
+func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ return 0, syserror.ESPIPE
+ }
+
+ i.offsetMu.Lock()
+ defer i.offsetMu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ if offset < 0 {
+ return i.offset, syserror.EINVAL
+ }
+ i.offset = offset
+
+ case linux.SEEK_CUR:
+ // Check for overflow. Note that underflow cannot occur, since i.offset >= 0.
+ if offset > math.MaxInt64-i.offset {
+ return i.offset, syserror.EOVERFLOW
+ }
+ if i.offset+offset < 0 {
+ return i.offset, syserror.EINVAL
+ }
+ i.offset += offset
+
+ case linux.SEEK_END:
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ return i.offset, err
+ }
+ size := s.Size
+
+ // Check for overflow. Note that underflow cannot occur, since size >= 0.
+ if offset > math.MaxInt64-size {
+ return i.offset, syserror.EOVERFLOW
+ }
+ if size+offset < 0 {
+ return i.offset, syserror.EINVAL
+ }
+ i.offset = size + offset
+
+ case linux.SEEK_DATA, linux.SEEK_HOLE:
+ // Modifying the offset in the host file table should not matter, since
+ // this is the only place where we use it.
+ //
+ // For reading and writing, we always rely on our internal offset.
+ n, err := unix.Seek(i.hostFD, offset, int(whence))
+ if err != nil {
+ return i.offset, err
+ }
+ i.offset = n
+
+ default:
+ // Invalid whence.
+ return i.offset, syserror.EINVAL
+ }
+
+ return i.offset, nil
+}
+
+// Sync implements FileDescriptionImpl.
+func (f *fileDescription) Sync(context.Context) error {
+ // TODO(gvisor.dev/issue/1672): Currently we do not support the SyncData optimization, so we always sync everything.
+ return unix.Fsync(f.inode.hostFD)
+}
+
+// ConfigureMMap implements FileDescriptionImpl.
+func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error {
+ if !f.inode.canMap {
+ return syserror.ENODEV
+ }
+ // TODO(gvisor.dev/issue/1672): Implement ConfigureMMap and Mappable interface.
+ return syserror.ENODEV
+}
diff --git a/pkg/sentry/fsimpl/host/ioctl_unsafe.go b/pkg/sentry/fsimpl/host/ioctl_unsafe.go
new file mode 100644
index 000000000..0983bf7d8
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/ioctl_unsafe.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+func ioctlGetTermios(fd int) (*linux.Termios, error) {
+ var t linux.Termios
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TCGETS, uintptr(unsafe.Pointer(&t)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &t, nil
+}
+
+func ioctlSetTermios(fd int, req uint64, t *linux.Termios) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(unsafe.Pointer(t)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+func ioctlGetWinsize(fd int) (*linux.Winsize, error) {
+ var w linux.Winsize
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCGWINSZ, uintptr(unsafe.Pointer(&w)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &w, nil
+}
+
+func ioctlSetWinsize(fd int, w *linux.Winsize) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCSWINSZ, uintptr(unsafe.Pointer(w)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go
new file mode 100644
index 000000000..8936afb06
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/tty.go
@@ -0,0 +1,379 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// ttyFD implements vfs.FileDescriptionImpl for a host file descriptor
+// that wraps a TTY FD.
+type ttyFD struct {
+ fileDescription
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // session is the session attached to this ttyFD.
+ session *kernel.Session
+
+ // fgProcessGroup is the foreground process group that is currently
+ // connected to this TTY.
+ fgProcessGroup *kernel.ProcessGroup
+
+ // termios contains the terminal attributes for this TTY.
+ termios linux.KernelTermios
+}
+
+// InitForegroundProcessGroup sets the foreground process group and session for
+// the TTY. This should only be called once, after the foreground process group
+// has been created, but before it has started running.
+func (t *ttyFD) InitForegroundProcessGroup(pg *kernel.ProcessGroup) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.fgProcessGroup != nil {
+ panic("foreground process group is already set")
+ }
+ t.fgProcessGroup = pg
+ t.session = pg.Session()
+}
+
+// ForegroundProcessGroup returns the foreground process for the TTY.
+func (t *ttyFD) ForegroundProcessGroup() *kernel.ProcessGroup {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.fgProcessGroup
+}
+
+// Release implements fs.FileOperations.Release.
+func (t *ttyFD) Release() {
+ t.mu.Lock()
+ t.fgProcessGroup = nil
+ t.mu.Unlock()
+
+ t.fileDescription.Release()
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+//
+// Reading from a TTY is only allowed for foreground process groups. Background
+// process groups will either get EIO or a SIGTTIN.
+func (t *ttyFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Are we allowed to do the read?
+ // drivers/tty/n_tty.c:n_tty_read()=>job_control()=>tty_check_change().
+ if err := t.checkChange(ctx, linux.SIGTTIN); err != nil {
+ return 0, err
+ }
+
+ // Do the read.
+ return t.fileDescription.PRead(ctx, dst, offset, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.
+//
+// Reading from a TTY is only allowed for foreground process groups. Background
+// process groups will either get EIO or a SIGTTIN.
+func (t *ttyFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Are we allowed to do the read?
+ // drivers/tty/n_tty.c:n_tty_read()=>job_control()=>tty_check_change().
+ if err := t.checkChange(ctx, linux.SIGTTIN); err != nil {
+ return 0, err
+ }
+
+ // Do the read.
+ return t.fileDescription.Read(ctx, dst, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (t *ttyFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+ }
+ return t.fileDescription.PWrite(ctx, src, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (t *ttyFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+ }
+ return t.fileDescription.Write(ctx, src, opts)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (t *ttyFD) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Ignore arg[0]. This is the real FD:
+ fd := t.inode.hostFD
+ ioctl := args[1].Uint64()
+ switch ioctl {
+ case linux.TCGETS:
+ termios, err := ioctlGetTermios(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TCSETS, linux.TCSETSW, linux.TCSETSF:
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+
+ var termios linux.Termios
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetTermios(fd, ioctl, &termios)
+ if err == nil {
+ t.termios.FromTermios(termios)
+ }
+ return 0, err
+
+ case linux.TIOCGPGRP:
+ // Args: pid_t *argp
+ // When successful, equivalent to *argp = tcgetpgrp(fd).
+ // Get the process group ID of the foreground process group on this
+ // terminal.
+
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Map the ProcessGroup into a ProcessGroupID in the task's PID namespace.
+ pgID := pidns.IDOfProcessGroup(t.fgProcessGroup)
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSPGRP:
+ // Args: const pid_t *argp
+ // Equivalent to tcsetpgrp(fd, *argp).
+ // Set the foreground process group ID of this terminal.
+
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check that we are allowed to set the process group.
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ // drivers/tty/tty_io.c:tiocspgrp() converts -EIO from tty_check_change()
+ // to -ENOTTY.
+ if err == syserror.EIO {
+ return 0, syserror.ENOTTY
+ }
+ return 0, err
+ }
+
+ // Check that calling task's process group is in the TTY session.
+ if task.ThreadGroup().Session() != t.session {
+ return 0, syserror.ENOTTY
+ }
+
+ var pgID kernel.ProcessGroupID
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ // pgID must be non-negative.
+ if pgID < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Process group with pgID must exist in this PID namespace.
+ pidns := task.PIDNamespace()
+ pg := pidns.ProcessGroupWithID(pgID)
+ if pg == nil {
+ return 0, syserror.ESRCH
+ }
+
+ // Check that new process group is in the TTY session.
+ if pg.Session() != t.session {
+ return 0, syserror.EPERM
+ }
+
+ t.fgProcessGroup = pg
+ return 0, nil
+
+ case linux.TIOCGWINSZ:
+ // Args: struct winsize *argp
+ // Get window size.
+ winsize, err := ioctlGetWinsize(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSWINSZ:
+ // Args: const struct winsize *argp
+ // Set window size.
+
+ // Unlike setting the termios, any process group (even background ones) can
+ // set the winsize.
+
+ var winsize linux.Winsize
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetWinsize(fd, &winsize)
+ return 0, err
+
+ // Unimplemented commands.
+ case linux.TIOCSETD,
+ linux.TIOCSBRK,
+ linux.TIOCCBRK,
+ linux.TCSBRK,
+ linux.TCSBRKP,
+ linux.TIOCSTI,
+ linux.TIOCCONS,
+ linux.FIONBIO,
+ linux.TIOCEXCL,
+ linux.TIOCNXCL,
+ linux.TIOCGEXCL,
+ linux.TIOCNOTTY,
+ linux.TIOCSCTTY,
+ linux.TIOCGSID,
+ linux.TIOCGETD,
+ linux.TIOCVHANGUP,
+ linux.TIOCGDEV,
+ linux.TIOCMGET,
+ linux.TIOCMSET,
+ linux.TIOCMBIC,
+ linux.TIOCMBIS,
+ linux.TIOCGICOUNT,
+ linux.TCFLSH,
+ linux.TIOCSSERIAL,
+ linux.TIOCGPTPEER:
+
+ unimpl.EmitUnimplementedEvent(ctx)
+ fallthrough
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// checkChange checks that the process group is allowed to read, write, or
+// change the state of the TTY.
+//
+// This corresponds to Linux drivers/tty/tty_io.c:tty_check_change(). The logic
+// is a bit convoluted, but documented inline.
+//
+// Preconditions: t.mu must be held.
+func (t *ttyFD) checkChange(ctx context.Context, sig linux.Signal) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ // No task? Linux does not have an analog for this case, but
+ // tty_check_change is more of a blacklist of cases than a
+ // whitelist, and is surprisingly permissive. Allowing the
+ // change seems most appropriate.
+ return nil
+ }
+
+ tg := task.ThreadGroup()
+ pg := tg.ProcessGroup()
+
+ // If the session for the task is different than the session for the
+ // controlling TTY, then the change is allowed. Seems like a bad idea,
+ // but that's exactly what linux does.
+ if tg.Session() != t.fgProcessGroup.Session() {
+ return nil
+ }
+
+ // If we are the foreground process group, then the change is allowed.
+ if pg == t.fgProcessGroup {
+ return nil
+ }
+
+ // We are not the foreground process group.
+
+ // Is the provided signal blocked or ignored?
+ if (task.SignalMask()&linux.SignalSetOf(sig) != 0) || tg.SignalHandlers().IsIgnored(sig) {
+ // If the signal is SIGTTIN, then we are attempting to read
+ // from the TTY. Don't send the signal and return EIO.
+ if sig == linux.SIGTTIN {
+ return syserror.EIO
+ }
+
+ // Otherwise, we are writing or changing terminal state. This is allowed.
+ return nil
+ }
+
+ // If the process group is an orphan, return EIO.
+ if pg.IsOrphan() {
+ return syserror.EIO
+ }
+
+ // Otherwise, send the signal to the process group and return ERESTARTSYS.
+ //
+ // Note that Linux also unconditionally sets TIF_SIGPENDING on current,
+ // but this isn't necessary in gVisor because the rationale given in
+ // 040b6362d58f "tty: fix leakage of -ERESTARTSYS to userland" doesn't
+ // apply: the sentry will handle -ERESTARTSYS in
+ // kernel.runApp.execute() even if the kernel.Task isn't interrupted.
+ //
+ // Linux ignores the result of kill_pgrp().
+ _ = pg.SendSignal(kernel.SignalInfoPriv(sig))
+ return kernel.ERESTARTSYS
+}
diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go
new file mode 100644
index 000000000..2bc757b1a
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/util.go
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func toTimespec(ts linux.StatxTimestamp, omit bool) syscall.Timespec {
+ if omit {
+ return syscall.Timespec{
+ Sec: 0,
+ Nsec: unix.UTIME_OMIT,
+ }
+ }
+ return syscall.Timespec{
+ Sec: ts.Sec,
+ Nsec: int64(ts.Nsec),
+ }
+}
+
+func unixToLinuxStatxTimestamp(ts unix.StatxTimestamp) linux.StatxTimestamp {
+ return linux.StatxTimestamp{Sec: ts.Sec, Nsec: ts.Nsec}
+}
+
+func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp {
+ return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)}
+}
+
+// wouldBlock returns true for file types that can return EWOULDBLOCK
+// for blocking operations, e.g. pipes, character devices, and sockets.
+func wouldBlock(fileType uint32) bool {
+ return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK
+}
+
+// canMap returns true if a file with fileType is allowed to be memory mapped.
+// This is ported over from VFS1, but it's probably not the best way for us
+// to check if a file can be memory mapped.
+func canMap(fileType uint32) bool {
+ // TODO(gvisor.dev/issue/1672): Also allow "special files" to be mapped (see fs/host:canMap()).
+ //
+ // TODO(b/38213152): Some obscure character devices can be mapped.
+ return fileType == syscall.S_IFREG
+}
+
+// isBlockError checks if an error is EAGAIN or EWOULDBLOCK.
+// If so, they can be transformed into syserror.ErrWouldBlock.
+func isBlockError(err error) bool {
+ return err == syserror.EAGAIN || err == syserror.EWOULDBLOCK
+}
diff --git a/pkg/sentry/kernel/pipe/buffer_test.go b/pkg/sentry/fsimpl/host/util_unsafe.go
index 4d54b8b8f..5136ac844 100644
--- a/pkg/sentry/kernel/pipe/buffer_test.go
+++ b/pkg/sentry/fsimpl/host/util_unsafe.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,21 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package pipe
+package host
import (
- "testing"
+ "syscall"
"unsafe"
-
- "gvisor.dev/gvisor/pkg/usermem"
)
-func TestBufferSize(t *testing.T) {
- bufferSize := unsafe.Sizeof(buffer{})
- if bufferSize < usermem.PageSize {
- t.Errorf("buffer is less than a page")
- }
- if bufferSize > (2 * usermem.PageSize) {
- t.Errorf("buffer is greater than two pages")
+func setTimestamps(fd int, ts *[2]syscall.Timespec) error {
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_UTIMENSAT,
+ uintptr(fd),
+ 0, /* path */
+ uintptr(unsafe.Pointer(ts)),
+ 0, /* flags */
+ 0, 0)
+ if errno != 0 {
+ return errno
}
+ return nil
}
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index e73f1f857..b3d6299d0 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -35,6 +35,7 @@ go_library(
"//pkg/refs",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index d092ccb2a..d8bddbafa 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -61,9 +61,10 @@ func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vf
return &fd.vfsfd, nil
}
-// SetStat implements Inode.SetStat.
-func (f *DynamicBytesFile) SetStat(*vfs.Filesystem, vfs.SetStatOptions) error {
- // DynamicBytesFiles are immutable.
+// SetStat implements Inode.SetStat. By default DynamicBytesFile doesn't allow
+// inode attributes to be changed. Override SetStat() making it call
+// f.InodeAttrs to allow it.
+func (*DynamicBytesFile) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
@@ -122,7 +123,7 @@ func (fd *DynamicBytesFD) Release() {}
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
- return fd.inode.Stat(fs), nil
+ return fd.inode.Stat(fs, opts)
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 5650512e0..bfa786c88 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -17,6 +17,7 @@ package kernfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -107,9 +108,13 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
fs.mu.Lock()
defer fs.mu.Unlock()
+ opts := vfs.StatOptions{Mask: linux.STATX_INO}
// Handle ".".
if fd.off == 0 {
- stat := fd.inode().Stat(vfsFS)
+ stat, err := fd.inode().Stat(vfsFS, opts)
+ if err != nil {
+ return err
+ }
dirent := vfs.Dirent{
Name: ".",
Type: linux.DT_DIR,
@@ -125,7 +130,10 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
// Handle "..".
if fd.off == 1 {
parentInode := vfsd.ParentOrSelf().Impl().(*Dentry).inode
- stat := parentInode.Stat(vfsFS)
+ stat, err := parentInode.Stat(vfsFS, opts)
+ if err != nil {
+ return err
+ }
dirent := vfs.Dirent{
Name: "..",
Type: linux.FileMode(stat.Mode).DirentType(),
@@ -146,7 +154,10 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
childIdx := fd.off - 2
for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() {
inode := it.Dentry.Impl().(*Dentry).inode
- stat := inode.Stat(vfsFS)
+ stat, err := inode.Stat(vfsFS, opts)
+ if err != nil {
+ return err
+ }
dirent := vfs.Dirent{
Name: it.Name,
Type: linux.FileMode(stat.Mode).DirentType(),
@@ -190,12 +201,12 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int
func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := fd.filesystem()
inode := fd.inode()
- return inode.Stat(fs), nil
+ return inode.Stat(fs, opts)
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- fs := fd.filesystem()
+ creds := auth.CredentialsFromContext(ctx)
inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
- return inode.SetStat(fs, opts)
+ return inode.SetStat(ctx, fd.filesystem(), creds, opts)
}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 292f58afd..baf81b4db 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -12,16 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// This file implements vfs.FilesystemImpl for kernfs.
-
package kernfs
+// This file implements vfs.FilesystemImpl for kernfs.
+
import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -61,6 +63,9 @@ afterSymlink:
rp.Advance()
return nextVFSD, nil
}
+ if len(name) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
d.dirMu.Lock()
nextVFSD, err := rp.ResolveChild(vfsd, name)
if err != nil {
@@ -74,16 +79,22 @@ afterSymlink:
}
// Resolve any symlink at current path component.
if rp.ShouldFollowSymlink() && next.isSymlink() {
- // TODO: VFS2 needs something extra for /proc/[pid]/fd/ "magic symlinks".
- target, err := next.inode.Readlink(ctx)
+ targetVD, targetPathname, err := next.inode.Getlink(ctx)
if err != nil {
return nil, err
}
- if err := rp.HandleSymlink(target); err != nil {
- return nil, err
+ if targetVD.Ok() {
+ err := rp.HandleJump(targetVD)
+ targetVD.DecRef()
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ if err := rp.HandleSymlink(targetPathname); err != nil {
+ return nil, err
+ }
}
goto afterSymlink
-
}
rp.Advance()
return &next.vfsd, nil
@@ -189,6 +200,9 @@ func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *v
if pc == "." || pc == ".." {
return "", syserror.EEXIST
}
+ if len(pc) > linux.NAME_MAX {
+ return "", syserror.ENAMETOOLONG
+ }
childVFSD, err := rp.ResolveChild(parentVFSD, pc)
if err != nil {
return "", err
@@ -229,6 +243,19 @@ func (fs *Filesystem) Sync(ctx context.Context) error {
return nil
}
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ defer fs.processDeferredDecRefs()
+
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ if err != nil {
+ return err
+ }
+ return inode.CheckPermissions(ctx, creds, ats)
+}
+
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
fs.mu.RLock()
@@ -418,6 +445,9 @@ afterTrailingSymlink:
if pc == "." || pc == ".." {
return nil, syserror.EISDIR
}
+ if len(pc) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
// Determine whether or not we need to create a file.
childVFSD, err := rp.ResolveChild(parentVFSD, pc)
if err != nil {
@@ -446,19 +476,25 @@ afterTrailingSymlink:
}
childDentry := childVFSD.Impl().(*Dentry)
childInode := childDentry.inode
- if rp.ShouldFollowSymlink() {
- if childDentry.isSymlink() {
- target, err := childInode.Readlink(ctx)
+ if rp.ShouldFollowSymlink() && childDentry.isSymlink() {
+ targetVD, targetPathname, err := childInode.Getlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if targetVD.Ok() {
+ err := rp.HandleJump(targetVD)
+ targetVD.DecRef()
if err != nil {
return nil, err
}
- if err := rp.HandleSymlink(target); err != nil {
+ } else {
+ if err := rp.HandleSymlink(targetPathname); err != nil {
return nil, err
}
- // rp.Final() may no longer be true since we now need to resolve the
- // symlink target.
- goto afterTrailingSymlink
}
+ // rp.Final() may no longer be true since we now need to resolve the
+ // symlink target.
+ goto afterTrailingSymlink
}
if err := childInode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
@@ -622,7 +658,7 @@ func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
if opts.Stat.Mask == 0 {
return nil
}
- return inode.SetStat(fs.VFSFilesystem(), opts)
+ return inode.SetStat(ctx, fs.VFSFilesystem(), rp.Credentials(), opts)
}
// StatAt implements vfs.FilesystemImpl.StatAt.
@@ -634,7 +670,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return linux.Statx{}, err
}
- return inode.Stat(fs.VFSFilesystem()), nil
+ return inode.Stat(fs.VFSFilesystem(), opts)
}
// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
@@ -646,7 +682,7 @@ func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
if err != nil {
return linux.Statfs{}, err
}
- // TODO: actually implement statfs
+ // TODO(gvisor.dev/issue/1193): actually implement statfs.
return linux.Statfs{}, syserror.ENOSYS
}
@@ -714,8 +750,20 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return nil
}
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return nil, err
+ }
+ return nil, syserror.ECONNREFUSED
+}
+
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
fs.mu.RLock()
_, _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
@@ -728,7 +776,7 @@ func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([
}
// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
fs.mu.RLock()
_, _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 099d70a16..65f09af5d 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -36,20 +36,20 @@ type InodeNoopRefCount struct {
}
// IncRef implements Inode.IncRef.
-func (n *InodeNoopRefCount) IncRef() {
+func (InodeNoopRefCount) IncRef() {
}
// DecRef implements Inode.DecRef.
-func (n *InodeNoopRefCount) DecRef() {
+func (InodeNoopRefCount) DecRef() {
}
// TryIncRef implements Inode.TryIncRef.
-func (n *InodeNoopRefCount) TryIncRef() bool {
+func (InodeNoopRefCount) TryIncRef() bool {
return true
}
// Destroy implements Inode.Destroy.
-func (n *InodeNoopRefCount) Destroy() {
+func (InodeNoopRefCount) Destroy() {
}
// InodeDirectoryNoNewChildren partially implements the Inode interface.
@@ -58,27 +58,27 @@ func (n *InodeNoopRefCount) Destroy() {
type InodeDirectoryNoNewChildren struct{}
// NewFile implements Inode.NewFile.
-func (*InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
return nil, syserror.EPERM
}
// NewDir implements Inode.NewDir.
-func (*InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
return nil, syserror.EPERM
}
// NewLink implements Inode.NewLink.
-func (*InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
return nil, syserror.EPERM
}
// NewSymlink implements Inode.NewSymlink.
-func (*InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
return nil, syserror.EPERM
}
// NewNode implements Inode.NewNode.
-func (*InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
return nil, syserror.EPERM
}
@@ -90,62 +90,62 @@ type InodeNotDirectory struct {
}
// HasChildren implements Inode.HasChildren.
-func (*InodeNotDirectory) HasChildren() bool {
+func (InodeNotDirectory) HasChildren() bool {
return false
}
// NewFile implements Inode.NewFile.
-func (*InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
panic("NewFile called on non-directory inode")
}
// NewDir implements Inode.NewDir.
-func (*InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
panic("NewDir called on non-directory inode")
}
// NewLink implements Inode.NewLinkink.
-func (*InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
panic("NewLink called on non-directory inode")
}
// NewSymlink implements Inode.NewSymlink.
-func (*InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
panic("NewSymlink called on non-directory inode")
}
// NewNode implements Inode.NewNode.
-func (*InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
panic("NewNode called on non-directory inode")
}
// Unlink implements Inode.Unlink.
-func (*InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error {
+func (InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error {
panic("Unlink called on non-directory inode")
}
// RmDir implements Inode.RmDir.
-func (*InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error {
+func (InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error {
panic("RmDir called on non-directory inode")
}
// Rename implements Inode.Rename.
-func (*InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) {
+func (InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) {
panic("Rename called on non-directory inode")
}
// Lookup implements Inode.Lookup.
-func (*InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
panic("Lookup called on non-directory inode")
}
// IterDirents implements Inode.IterDirents.
-func (*InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
panic("IterDirents called on non-directory inode")
}
// Valid implements Inode.Valid.
-func (*InodeNotDirectory) Valid(context.Context) bool {
+func (InodeNotDirectory) Valid(context.Context) bool {
return true
}
@@ -157,17 +157,17 @@ func (*InodeNotDirectory) Valid(context.Context) bool {
type InodeNoDynamicLookup struct{}
// Lookup implements Inode.Lookup.
-func (*InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
return nil, syserror.ENOENT
}
// IterDirents implements Inode.IterDirents.
-func (*InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
return offset, nil
}
// Valid implements Inode.Valid.
-func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool {
+func (InodeNoDynamicLookup) Valid(ctx context.Context) bool {
return true
}
@@ -177,10 +177,15 @@ func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool {
type InodeNotSymlink struct{}
// Readlink implements Inode.Readlink.
-func (*InodeNotSymlink) Readlink(context.Context) (string, error) {
+func (InodeNotSymlink) Readlink(context.Context) (string, error) {
return "", syserror.EINVAL
}
+// Getlink implements Inode.Getlink.
+func (InodeNotSymlink) Getlink(context.Context) (vfs.VirtualDentry, string, error) {
+ return vfs.VirtualDentry{}, "", syserror.EINVAL
+}
+
// InodeAttrs partially implements the Inode interface, specifically the
// inodeMetadata sub interface. InodeAttrs provides functionality related to
// inode attributes.
@@ -219,7 +224,7 @@ func (a *InodeAttrs) Mode() linux.FileMode {
// Stat partially implements Inode.Stat. Note that this function doesn't provide
// all the stat fields, and the embedder should consider extending the result
// with filesystem-specific fields.
-func (a *InodeAttrs) Stat(*vfs.Filesystem) linux.Statx {
+func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) {
var stat linux.Statx
stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK
stat.Ino = atomic.LoadUint64(&a.ino)
@@ -228,13 +233,23 @@ func (a *InodeAttrs) Stat(*vfs.Filesystem) linux.Statx {
stat.GID = atomic.LoadUint32(&a.gid)
stat.Nlink = atomic.LoadUint32(&a.nlink)
- // TODO: Implement other stat fields like timestamps.
+ // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
- return stat
+ return stat, nil
}
// SetStat implements Inode.SetStat.
-func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
+func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
+ return syserror.EPERM
+ }
+ if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
+ return err
+ }
+
stat := opts.Stat
if stat.Mask&linux.STATX_MODE != 0 {
for {
@@ -256,19 +271,17 @@ func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
// Note that not all fields are modifiable. For example, the file type and
// inode numbers are immutable after node creation.
- // TODO: Implement other stat fields like timestamps.
+ // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
return nil
}
// CheckPermissions implements Inode.CheckPermissions.
func (a *InodeAttrs) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
- mode := a.Mode()
return vfs.GenericCheckPermissions(
creds,
ats,
- mode.FileType() == linux.ModeDirectory,
- uint16(mode),
+ a.Mode(),
auth.KUID(atomic.LoadUint32(&a.uid)),
auth.KGID(atomic.LoadUint32(&a.gid)),
)
@@ -554,3 +567,16 @@ func (s *StaticDirectory) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs
fd.Init(rp.Mount(), vfsd, &s.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*StaticDirectory) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// AlwaysValid partially implements kernfs.inodeDynamicLookup.
+type AlwaysValid struct{}
+
+// Valid implements kernfs.inodeDynamicLookup.
+func (*AlwaysValid) Valid(context.Context) bool {
+ return true
+}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index c74fa999b..ad76b9f64 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -63,9 +63,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
)
-// FilesystemType implements vfs.FilesystemType.
-type FilesystemType struct{}
-
// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory
// filesystem. Concrete implementations are expected to embed this in their own
// Filesystem type.
@@ -138,8 +135,8 @@ func (fs *Filesystem) processDeferredDecRefsLocked() {
// Init initializes a kernfs filesystem. This should be called from during
// vfs.FilesystemType.NewFilesystem for the concrete filesystem embedding
// kernfs.
-func (fs *Filesystem) Init(vfsObj *vfs.VirtualFilesystem) {
- fs.vfsfs.Init(vfsObj, fs)
+func (fs *Filesystem) Init(vfsObj *vfs.VirtualFilesystem, fsType vfs.FilesystemType) {
+ fs.vfsfs.Init(vfsObj, fsType, fs)
}
// VFSFilesystem returns the generic vfs filesystem object.
@@ -176,8 +173,6 @@ type Dentry struct {
vfsd vfs.Dentry
inode Inode
- refs uint64
-
// flags caches useful information about the dentry from the inode. See the
// dflags* consts above. Must be accessed by atomic ops.
flags uint32
@@ -302,7 +297,8 @@ type Inode interface {
// this inode. The returned file description should hold a reference on the
// inode for its lifetime.
//
- // Precondition: !rp.Done(). vfsd.Impl() must be a kernfs Dentry.
+ // Precondition: rp.Done(). vfsd.Impl() must be the kernfs Dentry containing
+ // the inode on which Open() is being called.
Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error)
}
@@ -320,7 +316,7 @@ type inodeMetadata interface {
// CheckPermissions checks that creds may access this inode for the
// requested access type, per the the rules of
// fs/namei.c:generic_permission().
- CheckPermissions(ctx context.Context, creds *auth.Credentials, atx vfs.AccessTypes) error
+ CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error
// Mode returns the (struct stat)::st_mode value for this inode. This is
// separated from Stat for performance.
@@ -328,11 +324,13 @@ type inodeMetadata interface {
// Stat returns the metadata for this inode. This corresponds to
// vfs.FilesystemImpl.StatAt.
- Stat(fs *vfs.Filesystem) linux.Statx
+ Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error)
// SetStat updates the metadata for this inode. This corresponds to
- // vfs.FilesystemImpl.SetStatAt.
- SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error
+ // vfs.FilesystemImpl.SetStatAt. Implementations are responsible for checking
+ // if the operation can be performed (see vfs.CheckSetStat() for common
+ // checks).
+ SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error
}
// Precondition: All methods in this interface may only be called on directory
@@ -416,7 +414,21 @@ type inodeDynamicLookup interface {
}
type inodeSymlink interface {
- // Readlink resolves the target of a symbolic link. If an inode is not a
+ // Readlink returns the target of a symbolic link. If an inode is not a
// symlink, the implementation should return EINVAL.
Readlink(ctx context.Context) (string, error)
+
+ // Getlink returns the target of a symbolic link, as used by path
+ // resolution:
+ //
+ // - If the inode is a "magic link" (a link whose target is most accurately
+ // represented as a VirtualDentry), Getlink returns (ok VirtualDentry, "",
+ // nil). A reference is taken on the returned VirtualDentry.
+ //
+ // - If the inode is an ordinary symlink, Getlink returns (zero-value
+ // VirtualDentry, symlink target, nil).
+ //
+ // - If the inode is not a symlink, Getlink returns (zero-value
+ // VirtualDentry, "", EINVAL).
+ Getlink(ctx context.Context) (vfs.VirtualDentry, string, error)
}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 0459fb305..465451f35 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -91,7 +91,7 @@ type attrs struct {
kernfs.InodeAttrs
}
-func (a *attrs) SetStat(fs *vfs.Filesystem, opt vfs.SetStatOptions) error {
+func (*attrs) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
@@ -187,9 +187,13 @@ func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, err
return nil, syserror.EPERM
}
-func (fst *fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+func (fsType) Name() string {
+ return "kernfs"
+}
+
+func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
fs := &filesystem{}
- fs.Init(vfsObj)
+ fs.Init(vfsObj, &fst)
root := fst.rootFn(creds, fs)
return fs.VFSFilesystem(), root.VFSDentry(), nil
}
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
index 0ee7eb9b7..018aa503c 100644
--- a/pkg/sentry/fsimpl/kernfs/symlink.go
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -18,6 +18,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// StaticSymlink provides an Inode implementation for symlinks that point to
@@ -52,3 +54,13 @@ func (s *StaticSymlink) Init(creds *auth.Credentials, ino uint64, target string)
func (s *StaticSymlink) Readlink(_ context.Context) (string, error) {
return s.target, nil
}
+
+// Getlink implements Inode.Getlink.
+func (s *StaticSymlink) Getlink(_ context.Context) (vfs.VirtualDentry, string, error) {
+ return vfs.VirtualDentry{}, s.target, nil
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*StaticSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD
new file mode 100644
index 000000000..0d411606f
--- /dev/null
+++ b/pkg/sentry/fsimpl/pipefs/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "pipefs",
+ srcs = ["pipefs.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
new file mode 100644
index 000000000..faf3179bc
--- /dev/null
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -0,0 +1,148 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipefs provides the filesystem implementation backing
+// Kernel.PipeMount.
+package pipefs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type filesystemType struct{}
+
+// Name implements vfs.FilesystemType.Name.
+func (filesystemType) Name() string {
+ return "pipefs"
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (filesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ panic("pipefs.filesystemType.GetFilesystem should never be called")
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+
+ // TODO(gvisor.dev/issue/1193):
+ //
+ // - kernfs does not provide a way to implement statfs, from which we
+ // should indicate PIPEFS_MAGIC.
+ //
+ // - kernfs does not provide a way to override names for
+ // vfs.FilesystemImpl.PrependPath(); pipefs inodes should use synthetic
+ // name fmt.Sprintf("pipe:[%d]", inode.ino).
+}
+
+// NewFilesystem sets up and returns a new vfs.Filesystem implemented by
+// pipefs.
+func NewFilesystem(vfsObj *vfs.VirtualFilesystem) *vfs.Filesystem {
+ fs := &filesystem{}
+ fs.Init(vfsObj, filesystemType{})
+ return fs.VFSFilesystem()
+}
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+ kernfs.InodeNoopRefCount
+
+ pipe *pipe.VFSPipe
+
+ ino uint64
+ uid auth.KUID
+ gid auth.KGID
+ // We use the creation timestamp for all of atime, mtime, and ctime.
+ ctime ktime.Time
+}
+
+func newInode(ctx context.Context, fs *kernfs.Filesystem) *inode {
+ creds := auth.CredentialsFromContext(ctx)
+ return &inode{
+ pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ ino: fs.NextIno(),
+ uid: creds.EffectiveKUID,
+ gid: creds.EffectiveKGID,
+ ctime: ktime.NowFromContext(ctx),
+ }
+}
+
+const pipeMode = 0600 | linux.S_IFIFO
+
+// CheckPermissions implements kernfs.Inode.CheckPermissions.
+func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, pipeMode, i.uid, i.gid)
+}
+
+// Mode implements kernfs.Inode.Mode.
+func (i *inode) Mode() linux.FileMode {
+ return pipeMode
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds())
+ return linux.Statx{
+ Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS,
+ Blksize: usermem.PageSize,
+ Nlink: 1,
+ UID: uint32(i.uid),
+ GID: uint32(i.gid),
+ Mode: pipeMode,
+ Ino: i.ino,
+ Size: 0,
+ Blocks: 0,
+ Atime: ts,
+ Ctime: ts,
+ Mtime: ts,
+ // TODO(gvisor.dev/issue/1197): Device number.
+ }, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ return syserror.EPERM
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // FIXME(b/38173783): kernfs does not plumb Context here.
+ return i.pipe.Open(context.Background(), rp.Mount(), vfsd, opts.Flags)
+}
+
+// NewConnectedPipeFDs returns a pair of FileDescriptions representing the read
+// and write ends of a newly-created pipe, as for pipe(2) and pipe2(2).
+//
+// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem().
+func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+ fs := mnt.Filesystem().Impl().(*kernfs.Filesystem)
+ inode := newInode(ctx, fs)
+ var d kernfs.Dentry
+ d.Init(inode)
+ defer d.DecRef()
+ return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags)
+}
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index a83245866..17c1342b5 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -8,10 +8,11 @@ go_library(
"filesystem.go",
"subtasks.go",
"task.go",
+ "task_fds.go",
"task_files.go",
+ "task_net.go",
"tasks.go",
"tasks_files.go",
- "tasks_net.go",
"tasks_sys.go",
],
visibility = ["//pkg/sentry:internal"],
@@ -19,8 +20,9 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/log",
+ "//pkg/refs",
"//pkg/safemem",
- "//pkg/sentry/fs",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
@@ -53,6 +55,7 @@ go_test(
"//pkg/fspath",
"//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 5c19d5522..104fc9030 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -36,8 +36,13 @@ type FilesystemType struct{}
var _ vfs.FilesystemType = (*FilesystemType)(nil)
-// GetFilesystem implements vfs.FilesystemType.
-func (ft *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
k := kernel.KernelFromContext(ctx)
if k == nil {
return nil, nil, fmt.Errorf("procfs requires a kernel")
@@ -48,7 +53,7 @@ func (ft *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virtual
}
procfs := &kernfs.Filesystem{}
- procfs.VFSFilesystem().Init(vfsObj, procfs)
+ procfs.VFSFilesystem().Init(vfsObj, &ft, procfs)
var cgroups map[string]string
if opts.InternalData != nil {
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index f3f4e49b4..a21313666 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -34,6 +35,7 @@ type subtasksInode struct {
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeAttrs
kernfs.OrderedChildren
+ kernfs.AlwaysValid
task *kernel.Task
pidns *kernel.PIDNamespace
@@ -61,11 +63,6 @@ func newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, inoGen InoGenera
return dentry
}
-// Valid implements kernfs.inodeDynamicLookup.
-func (i *subtasksInode) Valid(ctx context.Context) bool {
- return true
-}
-
// Lookup implements kernfs.inodeDynamicLookup.
func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
tid, err := strconv.ParseUint(name, 10, 32)
@@ -121,8 +118,18 @@ func (i *subtasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.O
}
// Stat implements kernfs.Inode.
-func (i *subtasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx {
- stat := i.InodeAttrs.Stat(vsfs)
- stat.Nlink += uint32(i.task.ThreadGroup().Count())
- return stat
+func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(vsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ if opts.Mask&linux.STATX_NLINK != 0 {
+ stat.Nlink += uint32(i.task.ThreadGroup().Count())
+ }
+ return stat, nil
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*subtasksInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index 2d814668a..888afc0fd 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -45,28 +45,31 @@ var _ kernfs.Inode = (*taskInode)(nil)
func newTaskInode(inoGen InoGenerator, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) *kernfs.Dentry {
contents := map[string]*kernfs.Dentry{
- "auxv": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &auxvData{task: task}),
- "cmdline": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
- "comm": newComm(task, inoGen.NextIno(), 0444),
- "environ": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
- //"exe": newExe(t, msrc),
- //"fd": newFdDir(t, msrc),
- //"fdinfo": newFdInfoDir(t, msrc),
- "gid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: true}),
- "io": newTaskOwnedFile(task, inoGen.NextIno(), 0400, newIO(task, isThreadGroup)),
- "maps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mapsData{task: task}),
- //"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
- //"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
+ "auxv": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &auxvData{task: task}),
+ "cmdline": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
+ "comm": newComm(task, inoGen.NextIno(), 0444),
+ "environ": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
+ "exe": newExeSymlink(task, inoGen.NextIno()),
+ "fd": newFDDirInode(task, inoGen),
+ "fdinfo": newFDInfoDirInode(task, inoGen),
+ "gid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: true}),
+ "io": newTaskOwnedFile(task, inoGen.NextIno(), 0400, newIO(task, isThreadGroup)),
+ "maps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mapsData{task: task}),
+ "mountinfo": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mountInfoData{task: task}),
+ "mounts": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mountsData{task: task}),
+ "net": newTaskNetDir(task, inoGen),
"ns": newTaskOwnedDir(task, inoGen.NextIno(), 0511, map[string]*kernfs.Dentry{
"net": newNamespaceSymlink(task, inoGen.NextIno(), "net"),
"pid": newNamespaceSymlink(task, inoGen.NextIno(), "pid"),
"user": newNamespaceSymlink(task, inoGen.NextIno(), "user"),
}),
- "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}),
- "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
- "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}),
- "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
- "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}),
+ "oom_score": newTaskOwnedFile(task, inoGen.NextIno(), 0444, newStaticFile("0\n")),
+ "oom_score_adj": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &oomScoreAdj{task: task}),
+ "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}),
+ "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
+ "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}),
+ "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
+ "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}),
}
if isThreadGroup {
contents["task"] = newSubtasks(task, pidns, inoGen, cgroupControllers)
@@ -104,13 +107,9 @@ func (i *taskInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenO
return fd.VFSFileDescription(), nil
}
-// SetStat implements kernfs.Inode.
-func (i *taskInode) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
- stat := opts.Stat
- if stat.Mask&linux.STATX_MODE != 0 {
- return syserror.EPERM
- }
- return nil
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*taskInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
}
// taskOwnedInode implements kernfs.Inode and overrides inode owner with task
@@ -152,26 +151,28 @@ func newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, childre
}
// Stat implements kernfs.Inode.
-func (i *taskOwnedInode) Stat(fs *vfs.Filesystem) linux.Statx {
- stat := i.Inode.Stat(fs)
- uid, gid := i.getOwner(linux.FileMode(stat.Mode))
- stat.UID = uint32(uid)
- stat.GID = uint32(gid)
- return stat
+func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.Inode.Stat(fs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ if opts.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 {
+ uid, gid := i.getOwner(linux.FileMode(stat.Mode))
+ if opts.Mask&linux.STATX_UID != 0 {
+ stat.UID = uint32(uid)
+ }
+ if opts.Mask&linux.STATX_GID != 0 {
+ stat.GID = uint32(gid)
+ }
+ }
+ return stat, nil
}
// CheckPermissions implements kernfs.Inode.
func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
mode := i.Mode()
uid, gid := i.getOwner(mode)
- return vfs.GenericCheckPermissions(
- creds,
- ats,
- mode.FileType() == linux.ModeDirectory,
- uint16(mode),
- uid,
- gid,
- )
+ return vfs.GenericCheckPermissions(creds, ats, mode, uid, gid)
}
func (i *taskOwnedInode) getOwner(mode linux.FileMode) (auth.KUID, auth.KGID) {
@@ -213,28 +214,12 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData {
return &ioData{ioUsage: t}
}
-func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry {
- // Namespace symlinks should contain the namespace name and the inode number
- // for the namespace instance, so for example user:[123456]. We currently fake
- // the inode number by sticking the symlink inode in its place.
- target := fmt.Sprintf("%s:[%d]", ns, ino)
-
- inode := &kernfs.StaticSymlink{}
- // Note: credentials are overridden by taskOwnedInode.
- inode.Init(task.Credentials(), ino, target)
-
- taskInode := &taskOwnedInode{Inode: inode, owner: task}
- d := &kernfs.Dentry{}
- d.Init(taskInode)
- return d
-}
-
// newCgroupData creates inode that shows cgroup information.
// From man 7 cgroups: "For each cgroup hierarchy of which the process is a
// member, there is one entry containing three colon-separated fields:
// hierarchy-ID:controller-list:cgroup-path"
func newCgroupData(controllers map[string]string) dynamicInode {
- buf := bytes.Buffer{}
+ var buf bytes.Buffer
// The hierarchy ids must be positive integers (for cgroup v1), but the
// exact number does not matter, so long as they are unique. We can
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
new file mode 100644
index 000000000..046265eca
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -0,0 +1,302 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func getTaskFD(t *kernel.Task, fd int32) (*vfs.FileDescription, kernel.FDFlags) {
+ var (
+ file *vfs.FileDescription
+ flags kernel.FDFlags
+ )
+ t.WithMuLocked(func(t *kernel.Task) {
+ if fdt := t.FDTable(); fdt != nil {
+ file, flags = fdt.GetVFS2(fd)
+ }
+ })
+ return file, flags
+}
+
+func taskFDExists(t *kernel.Task, fd int32) bool {
+ file, _ := getTaskFD(t, fd)
+ if file == nil {
+ return false
+ }
+ file.DecRef()
+ return true
+}
+
+type fdDir struct {
+ inoGen InoGenerator
+ task *kernel.Task
+
+ // When produceSymlinks is set, dirents produces for the FDs are reported
+ // as symlink. Otherwise, they are reported as regular files.
+ produceSymlink bool
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, absOffset, relOffset int64) (int64, error) {
+ var fds []int32
+ i.task.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ fds = fdTable.GetFDs()
+ }
+ })
+
+ offset := absOffset + relOffset
+ typ := uint8(linux.DT_REG)
+ if i.produceSymlink {
+ typ = linux.DT_LNK
+ }
+
+ // Find the appropriate starting point.
+ idx := sort.Search(len(fds), func(i int) bool { return fds[i] >= int32(relOffset) })
+ if idx >= len(fds) {
+ return offset, nil
+ }
+ for _, fd := range fds[idx:] {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(fd), 10),
+ Type: typ,
+ Ino: i.inoGen.NextIno(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
+
+// fdDirInode represents the inode for /proc/[pid]/fd directory.
+//
+// +stateify savable
+type fdDirInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+ fdDir
+}
+
+var _ kernfs.Inode = (*fdDirInode)(nil)
+
+func newFDDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry {
+ inode := &fdDirInode{
+ fdDir: fdDir{
+ inoGen: inoGen,
+ task: task,
+ produceSymlink: true,
+ },
+ }
+ inode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ return dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ fdInt, err := strconv.ParseInt(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+ fd := int32(fdInt)
+ if !taskFDExists(i.task, fd) {
+ return nil, syserror.ENOENT
+ }
+ taskDentry := newFDSymlink(i.task, fd, i.inoGen.NextIno())
+ return taskDentry.VFSDentry(), nil
+}
+
+// Open implements kernfs.Inode.
+func (i *fdDirInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ return fd.VFSFileDescription(), nil
+}
+
+// CheckPermissions implements kernfs.Inode.
+//
+// This is to match Linux, which uses a special permission handler to guarantee
+// that a process can still access /proc/self/fd after it has executed
+// setuid. See fs/proc/fd.c:proc_fd_permission.
+func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ err := i.InodeAttrs.CheckPermissions(ctx, creds, ats)
+ if err == nil {
+ // Access granted, no extra check needed.
+ return nil
+ }
+ if t := kernel.TaskFromContext(ctx); t != nil {
+ // Allow access if the task trying to access it is in the thread group
+ // corresponding to this directory.
+ if i.task.ThreadGroup() == t.ThreadGroup() {
+ // Access granted (overridden).
+ return nil
+ }
+ }
+ return err
+}
+
+// fdSymlink is an symlink for the /proc/[pid]/fd/[fd] file.
+//
+// +stateify savable
+type fdSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ task *kernel.Task
+ fd int32
+}
+
+var _ kernfs.Inode = (*fdSymlink)(nil)
+
+func newFDSymlink(task *kernel.Task, fd int32, ino uint64) *kernfs.Dentry {
+ inode := &fdSymlink{
+ task: task,
+ fd: fd,
+ }
+ inode.Init(task.Credentials(), ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *fdSymlink) Readlink(ctx context.Context) (string, error) {
+ file, _ := getTaskFD(s.task, s.fd)
+ if file == nil {
+ return "", syserror.ENOENT
+ }
+ defer file.DecRef()
+ root := vfs.RootFromContext(ctx)
+ defer root.DecRef()
+ return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry())
+}
+
+func (s *fdSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) {
+ file, _ := getTaskFD(s.task, s.fd)
+ if file == nil {
+ return vfs.VirtualDentry{}, "", syserror.ENOENT
+ }
+ defer file.DecRef()
+ vd := file.VirtualDentry()
+ vd.IncRef()
+ return vd, "", nil
+}
+
+// fdInfoDirInode represents the inode for /proc/[pid]/fdinfo directory.
+//
+// +stateify savable
+type fdInfoDirInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+ fdDir
+}
+
+var _ kernfs.Inode = (*fdInfoDirInode)(nil)
+
+func newFDInfoDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry {
+ inode := &fdInfoDirInode{
+ fdDir: fdDir{
+ inoGen: inoGen,
+ task: task,
+ },
+ }
+ inode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ return dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ fdInt, err := strconv.ParseInt(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+ fd := int32(fdInt)
+ if !taskFDExists(i.task, fd) {
+ return nil, syserror.ENOENT
+ }
+ data := &fdInfoData{
+ task: i.task,
+ fd: fd,
+ }
+ dentry := newTaskOwnedFile(i.task, i.inoGen.NextIno(), 0444, data)
+ return dentry.VFSDentry(), nil
+}
+
+// Open implements kernfs.Inode.
+func (i *fdInfoDirInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ return fd.VFSFileDescription(), nil
+}
+
+// fdInfoData implements vfs.DynamicBytesSource for /proc/[pid]/fdinfo/[fd].
+//
+// +stateify savable
+type fdInfoData struct {
+ kernfs.DynamicBytesFile
+ refs.AtomicRefCount
+
+ task *kernel.Task
+ fd int32
+}
+
+var _ dynamicInode = (*fdInfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ file, descriptorFlags := getTaskFD(d.task, d.fd)
+ if file == nil {
+ return syserror.ENOENT
+ }
+ defer file.DecRef()
+ // TODO(b/121266871): Include pos, locks, and other data. For now we only
+ // have flags.
+ // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt
+ flags := uint(file.StatusFlags()) | descriptorFlags.ToLinuxFileFlags()
+ fmt.Fprintf(buf, "flags:\t0%o\n", flags)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index efd3b3453..2c6f8bdfc 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -63,6 +64,16 @@ func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) {
return m, nil
}
+func checkTaskState(t *kernel.Task) error {
+ switch t.ExitState() {
+ case kernel.TaskExitZombie:
+ return syserror.EACCES
+ case kernel.TaskExitDead:
+ return syserror.ESRCH
+ }
+ return nil
+}
+
type bufferWriter struct {
buf *bytes.Buffer
}
@@ -496,7 +507,7 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
-// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider.
+// ioUsage is the /proc/[pid]/io and /proc/[pid]/task/[tid]/io data provider.
type ioUsage interface {
// IOUsage returns the io usage data.
IOUsage() *usage.IO
@@ -525,3 +536,226 @@ func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
return nil
}
+
+// oomScoreAdj is a stub of the /proc/<pid>/oom_score_adj file.
+//
+// +stateify savable
+type oomScoreAdj struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ vfs.WritableDynamicBytesSource = (*oomScoreAdj)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (o *oomScoreAdj) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if o.task.ExitState() == kernel.TaskExitDead {
+ return syserror.ESRCH
+ }
+ fmt.Fprintf(buf, "%d\n", o.task.OOMScoreAdj())
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit input size so as not to impact performance if input size is large.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
+ if o.task.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
+ if err := o.task.SetOOMScoreAdj(v); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+// exeSymlink is an symlink for the /proc/[pid]/exe file.
+//
+// +stateify savable
+type exeSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ task *kernel.Task
+}
+
+var _ kernfs.Inode = (*exeSymlink)(nil)
+
+func newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry {
+ inode := &exeSymlink{task: task}
+ inode.Init(task.Credentials(), ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+// Readlink implements kernfs.Inode.
+func (s *exeSymlink) Readlink(ctx context.Context) (string, error) {
+ if !kernel.ContextCanTrace(ctx, s.task, false) {
+ return "", syserror.EACCES
+ }
+
+ // Pull out the executable for /proc/[pid]/exe.
+ exec, err := s.executable()
+ if err != nil {
+ return "", err
+ }
+ defer exec.DecRef()
+
+ return exec.PathnameWithDeleted(ctx), nil
+}
+
+// Getlink implements kernfs.Inode.Getlink.
+func (s *exeSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) {
+ if !kernel.ContextCanTrace(ctx, s.task, false) {
+ return vfs.VirtualDentry{}, "", syserror.EACCES
+ }
+
+ exec, err := s.executable()
+ if err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+ defer exec.DecRef()
+
+ vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry()
+ vd.IncRef()
+ return vd, "", nil
+}
+
+func (s *exeSymlink) executable() (file fsbridge.File, err error) {
+ if err := checkTaskState(s.task); err != nil {
+ return nil, err
+ }
+
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ mm := t.MemoryManager()
+ if mm == nil {
+ err = syserror.EACCES
+ return
+ }
+
+ // The MemoryManager may be destroyed, in which case
+ // MemoryManager.destroy will simply set the executable to nil
+ // (with locks held).
+ file = mm.Executable()
+ if file == nil {
+ err = syserror.ESRCH
+ }
+ })
+ return
+}
+
+// mountInfoData is used to implement /proc/[pid]/mountinfo.
+//
+// +stateify savable
+type mountInfoData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mountInfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *mountInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var fsctx *kernel.FSContext
+ i.task.WithMuLocked(func(t *kernel.Task) {
+ fsctx = t.FSContext()
+ })
+ if fsctx == nil {
+ // The task has been destroyed. Nothing to show here.
+ return nil
+ }
+ rootDir := fsctx.RootDirectoryVFS2()
+ if !rootDir.Ok() {
+ // Root has been destroyed. Don't try to read mounts.
+ return nil
+ }
+ defer rootDir.DecRef()
+ i.task.Kernel().VFS().GenerateProcMountInfo(ctx, rootDir, buf)
+ return nil
+}
+
+// mountsData is used to implement /proc/[pid]/mounts.
+//
+// +stateify savable
+type mountsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mountsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var fsctx *kernel.FSContext
+ i.task.WithMuLocked(func(t *kernel.Task) {
+ fsctx = t.FSContext()
+ })
+ if fsctx == nil {
+ // The task has been destroyed. Nothing to show here.
+ return nil
+ }
+ rootDir := fsctx.RootDirectoryVFS2()
+ if !rootDir.Ok() {
+ // Root has been destroyed. Don't try to read mounts.
+ return nil
+ }
+ defer rootDir.DecRef()
+ i.task.Kernel().VFS().GenerateProcMounts(ctx, rootDir, buf)
+ return nil
+}
+
+type namespaceSymlink struct {
+ kernfs.StaticSymlink
+
+ task *kernel.Task
+}
+
+func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry {
+ // Namespace symlinks should contain the namespace name and the inode number
+ // for the namespace instance, so for example user:[123456]. We currently fake
+ // the inode number by sticking the symlink inode in its place.
+ target := fmt.Sprintf("%s:[%d]", ns, ino)
+
+ inode := &namespaceSymlink{task: task}
+ // Note: credentials are overridden by taskOwnedInode.
+ inode.Init(task.Credentials(), ino, target)
+
+ taskInode := &taskOwnedInode{Inode: inode, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(taskInode)
+ return d
+}
+
+// Readlink implements Inode.
+func (s *namespaceSymlink) Readlink(ctx context.Context) (string, error) {
+ if err := checkTaskState(s.task); err != nil {
+ return "", err
+ }
+ return s.StaticSymlink.Readlink(ctx)
+}
+
+// Getlink implements Inode.Getlink.
+func (s *namespaceSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) {
+ if err := checkTaskState(s.task); err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+ return s.StaticSymlink.Getlink(ctx)
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index d4e1812d8..6595fcee6 100644
--- a/pkg/sentry/fsimpl/proc/tasks_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -32,17 +31,19 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/usermem"
)
-func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
+func newTaskNetDir(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry {
+ k := task.Kernel()
+ pidns := task.PIDNamespace()
+ root := auth.NewRootCredentials(pidns.UserNamespace())
+
var contents map[string]*kernfs.Dentry
- // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
- // network namespace of the calling process. We should make this per-process,
- // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
- if stack := k.RootNetworkNamespace().Stack(); stack != nil {
+ if stack := task.NetworkNamespace().Stack(); stack != nil {
const (
arp = "IP address HW type Flags HW address Mask Device\n"
netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n"
@@ -53,6 +54,8 @@ func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k
)
psched := fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond))
+ // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
+ // network namespace.
contents = map[string]*kernfs.Dentry{
"dev": newDentry(root, inoGen.NextIno(), 0444, &netDevData{stack: stack}),
"snmp": newDentry(root, inoGen.NextIno(), 0444, &netSnmpData{stack: stack}),
@@ -84,7 +87,7 @@ func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k
}
}
- return kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, contents)
+ return newTaskOwnedDir(task, inoGen.NextIno(), 0555, contents)
}
// ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6.
@@ -203,22 +206,21 @@ var _ dynamicInode = (*netUnixData)(nil)
func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
buf.WriteString("Num RefCount Protocol Flags Type St Inode Path\n")
for _, se := range n.kernel.ListSockets() {
- s := se.Sock.Get()
- if s == nil {
- log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock)
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
continue
}
- sfile := s.(*fs.File)
- if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX {
+ if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX {
s.DecRef()
// Not a unix socket.
continue
}
- sops := sfile.FileOperations.(*unix.SocketOperations)
+ sops := s.Impl().(*unix.SocketVFS2)
addr, err := sops.Endpoint().GetLocalAddress()
if err != nil {
- log.Warningf("Failed to retrieve socket name from %+v: %v", sfile, err)
+ log.Warningf("Failed to retrieve socket name from %+v: %v", s, err)
addr.Addr = "<unknown>"
}
@@ -231,6 +233,15 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
}
+ // Get inode number.
+ var ino uint64
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_INO})
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve ino for socket file: %v", statErr)
+ } else {
+ ino = stat.Ino
+ }
+
// In the socket entry below, the value for the 'Num' field requires
// some consideration. Linux prints the address to the struct
// unix_sock representing a socket in the kernel, but may redact the
@@ -249,14 +260,14 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// the definition of this struct changes over time.
//
// For now, we always redact this pointer.
- fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %5d",
+ fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %8d",
(*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct.
- sfile.ReadRefs()-1, // RefCount, don't count our own ref.
+ s.Refs()-1, // RefCount, don't count our own ref.
0, // Protocol, always 0 for UDS.
sockFlags, // Flags.
sops.Endpoint().Type(), // Type.
sops.State(), // State.
- sfile.InodeID(), // Inode.
+ ino, // Inode.
)
// Path
@@ -338,15 +349,14 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
t := kernel.TaskFromContext(ctx)
for _, se := range k.ListSockets() {
- s := se.Sock.Get()
- if s == nil {
- log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
continue
}
- sfile := s.(*fs.File)
- sops, ok := sfile.FileOperations.(socket.Socket)
+ sops, ok := s.Impl().(socket.SocketVFS2)
if !ok {
- panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
}
if fa, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) {
s.DecRef()
@@ -395,14 +405,15 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
// Unimplemented.
fmt.Fprintf(buf, "%08X ", 0)
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO})
+
// Field: uid.
- uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
- if err != nil {
- log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
+ if statErr != nil || stat.Mask&linux.STATX_UID == 0 {
+ log.Warningf("Failed to retrieve uid for socket file: %v", statErr)
fmt.Fprintf(buf, "%5d ", 0)
} else {
creds := auth.CredentialsFromContext(ctx)
- fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
+ fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow()))
}
// Field: timeout; number of unanswered 0-window probes.
@@ -410,11 +421,16 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
fmt.Fprintf(buf, "%8d ", 0)
// Field: inode.
- fmt.Fprintf(buf, "%8d ", sfile.InodeID())
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve inode for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%8d ", 0)
+ } else {
+ fmt.Fprintf(buf, "%8d ", stat.Ino)
+ }
// Field: refcount. Don't count the ref we obtain while deferencing
// the weakref to this socket.
- fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1)
+ fmt.Fprintf(buf, "%d ", s.Refs()-1)
// Field: Socket struct address. Redacted due to the same reason as
// the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
@@ -496,15 +512,14 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
t := kernel.TaskFromContext(ctx)
for _, se := range d.kernel.ListSockets() {
- s := se.Sock.Get()
- if s == nil {
- log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
continue
}
- sfile := s.(*fs.File)
- sops, ok := sfile.FileOperations.(socket.Socket)
+ sops, ok := s.Impl().(socket.SocketVFS2)
if !ok {
- panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
}
if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
s.DecRef()
@@ -548,25 +563,31 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// Field: retrnsmt. Always 0 for UDP.
fmt.Fprintf(buf, "%08X ", 0)
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO})
+
// Field: uid.
- uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
- if err != nil {
- log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
+ if statErr != nil || stat.Mask&linux.STATX_UID == 0 {
+ log.Warningf("Failed to retrieve uid for socket file: %v", statErr)
fmt.Fprintf(buf, "%5d ", 0)
} else {
creds := auth.CredentialsFromContext(ctx)
- fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
+ fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow()))
}
// Field: timeout. Always 0 for UDP.
fmt.Fprintf(buf, "%8d ", 0)
// Field: inode.
- fmt.Fprintf(buf, "%8d ", sfile.InodeID())
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve inode for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%8d ", 0)
+ } else {
+ fmt.Fprintf(buf, "%8d ", stat.Ino)
+ }
// Field: ref; reference count on the socket inode. Don't count the ref
// we obtain while deferencing the weakref to this socket.
- fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1)
+ fmt.Fprintf(buf, "%d ", s.Refs()-1)
// Field: Socket struct address. Redacted due to the same reason as
// the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
@@ -667,9 +688,9 @@ func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error {
if line.prefix == "Tcp" {
tcp := stat.(*inet.StatSNMPTCP)
// "Tcp" needs special processing because MaxConn is signed. RFC 2012.
- fmt.Sprintf("%s: %s %d %s\n", line.prefix, sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:]))
+ fmt.Fprintf(buf, "%s: %s %d %s\n", line.prefix, sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:]))
} else {
- fmt.Sprintf("%s: %s\n", line.prefix, sprintSlice(toSlice(stat)))
+ fmt.Fprintf(buf, "%s: %s\n", line.prefix, sprintSlice(toSlice(stat)))
}
}
return nil
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 10c08fa90..9f2ef8200 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -46,6 +46,7 @@ type tasksInode struct {
kernfs.InodeDirectoryNoNewChildren
kernfs.InodeAttrs
kernfs.OrderedChildren
+ kernfs.AlwaysValid
inoGen InoGenerator
pidns *kernel.PIDNamespace
@@ -66,23 +67,23 @@ var _ kernfs.Inode = (*tasksInode)(nil)
func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) (*tasksInode, *kernfs.Dentry) {
root := auth.NewRootCredentials(pidns.UserNamespace())
contents := map[string]*kernfs.Dentry{
- "cpuinfo": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(cpuInfoData(k))),
- //"filesystems": newDentry(root, inoGen.NextIno(), 0444, &filesystemsData{}),
- "loadavg": newDentry(root, inoGen.NextIno(), 0444, &loadavgData{}),
- "sys": newSysDir(root, inoGen, k),
- "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}),
- "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"),
- "net": newNetDir(root, inoGen, k),
- "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{k: k}),
- "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}),
- "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{k: k}),
+ "cpuinfo": newDentry(root, inoGen.NextIno(), 0444, newStaticFileSetStat(cpuInfoData(k))),
+ "filesystems": newDentry(root, inoGen.NextIno(), 0444, &filesystemsData{}),
+ "loadavg": newDentry(root, inoGen.NextIno(), 0444, &loadavgData{}),
+ "sys": newSysDir(root, inoGen, k),
+ "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}),
+ "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"),
+ "net": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/net"),
+ "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}),
+ "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}),
+ "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}),
}
inode := &tasksInode{
pidns: pidns,
inoGen: inoGen,
- selfSymlink: newSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(),
- threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(),
+ selfSymlink: newSelfSymlink(root, inoGen.NextIno(), pidns).VFSDentry(),
+ threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), pidns).VFSDentry(),
cgroupControllers: cgroupControllers,
}
inode.InodeAttrs.Init(root, inoGen.NextIno(), linux.ModeDirectory|0555)
@@ -121,11 +122,6 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro
return taskDentry.VFSDentry(), nil
}
-// Valid implements kernfs.inodeDynamicLookup.
-func (i *tasksInode) Valid(ctx context.Context) bool {
- return true
-}
-
// IterDirents implements kernfs.inodeDynamicLookup.
func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
// fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
@@ -211,17 +207,36 @@ func (i *tasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.Open
return fd.VFSFileDescription(), nil
}
-func (i *tasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx {
- stat := i.InodeAttrs.Stat(vsfs)
+func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(vsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
- // Add dynamic children to link count.
- for _, tg := range i.pidns.ThreadGroups() {
- if leader := tg.Leader(); leader != nil {
- stat.Nlink++
+ if opts.Mask&linux.STATX_NLINK != 0 {
+ // Add dynamic children to link count.
+ for _, tg := range i.pidns.ThreadGroups() {
+ if leader := tg.Leader(); leader != nil {
+ stat.Nlink++
+ }
}
}
- return stat
+ return stat, nil
+}
+
+// staticFileSetStat implements a special static file that allows inode
+// attributes to be set. This is to support /proc files that are readonly, but
+// allow attributes to be set.
+type staticFileSetStat struct {
+ dynamicBytesFileSetAttr
+ vfs.StaticData
+}
+
+var _ dynamicInode = (*staticFileSetStat)(nil)
+
+func newStaticFileSetStat(data string) *staticFileSetStat {
+ return &staticFileSetStat{StaticData: vfs.StaticData{Data: data}}
}
func cpuInfoData(k *kernel.Kernel) string {
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 434998910..4621e2de0 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,9 +41,9 @@ type selfSymlink struct {
var _ kernfs.Inode = (*selfSymlink)(nil)
-func newSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+func newSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
inode := &selfSymlink{pidns: pidns}
- inode.Init(creds, ino, linux.ModeSymlink|perm)
+ inode.Init(creds, ino, linux.ModeSymlink|0777)
d := &kernfs.Dentry{}
d.Init(inode)
@@ -62,6 +63,16 @@ func (s *selfSymlink) Readlink(ctx context.Context) (string, error) {
return strconv.FormatUint(uint64(tgid), 10), nil
}
+func (s *selfSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) {
+ target, err := s.Readlink(ctx)
+ return vfs.VirtualDentry{}, target, err
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*selfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
type threadSelfSymlink struct {
kernfs.InodeAttrs
kernfs.InodeNoopRefCount
@@ -72,9 +83,9 @@ type threadSelfSymlink struct {
var _ kernfs.Inode = (*threadSelfSymlink)(nil)
-func newThreadSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+func newThreadSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
inode := &threadSelfSymlink{pidns: pidns}
- inode.Init(creds, ino, linux.ModeSymlink|perm)
+ inode.Init(creds, ino, linux.ModeSymlink|0777)
d := &kernfs.Dentry{}
d.Init(inode)
@@ -95,6 +106,28 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) {
return fmt.Sprintf("%d/task/%d", tgid, tid), nil
}
+func (s *threadSelfSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) {
+ target, err := s.Readlink(ctx)
+ return vfs.VirtualDentry{}, target, err
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// dynamicBytesFileSetAttr implements a special file that allows inode
+// attributes to be set. This is to support /proc files that are readonly, but
+// allow attributes to be set.
+type dynamicBytesFileSetAttr struct {
+ kernfs.DynamicBytesFile
+}
+
+// SetStat implements Inode.SetStat.
+func (d *dynamicBytesFileSetAttr) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ return d.DynamicBytesFile.InodeAttrs.SetStat(ctx, fs, creds, opts)
+}
+
// cpuStats contains the breakdown of CPU time for /proc/stat.
type cpuStats struct {
// user is time spent in userspace tasks with non-positive niceness.
@@ -137,22 +170,20 @@ func (c cpuStats) String() string {
//
// +stateify savable
type statData struct {
- kernfs.DynamicBytesFile
-
- // k is the owning Kernel.
- k *kernel.Kernel
+ dynamicBytesFileSetAttr
}
var _ dynamicInode = (*statData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+func (*statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// TODO(b/37226836): We currently export only zero CPU stats. We could
// at least provide some aggregate stats.
var cpu cpuStats
fmt.Fprintf(buf, "cpu %s\n", cpu)
- for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ {
+ k := kernel.KernelFromContext(ctx)
+ for c, max := uint(0), k.ApplicationCores(); c < max; c++ {
fmt.Fprintf(buf, "cpu%d %s\n", c, cpu)
}
@@ -176,7 +207,7 @@ func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "ctxt 0\n")
// CLOCK_REALTIME timestamp from boot, in seconds.
- fmt.Fprintf(buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds())
+ fmt.Fprintf(buf, "btime %d\n", k.Timekeeper().BootTime().Seconds())
// Total number of clones.
// TODO(b/37226836): Count this.
@@ -203,13 +234,13 @@ func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
//
// +stateify savable
type loadavgData struct {
- kernfs.DynamicBytesFile
+ dynamicBytesFileSetAttr
}
var _ dynamicInode = (*loadavgData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+func (*loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// TODO(b/62345059): Include real data in fields.
// Column 1-3: CPU and IO utilization of the last 1, 5, and 10 minute periods.
// Column 4-5: currently running processes and the total number of processes.
@@ -222,17 +253,15 @@ func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
//
// +stateify savable
type meminfoData struct {
- kernfs.DynamicBytesFile
-
- // k is the owning Kernel.
- k *kernel.Kernel
+ dynamicBytesFileSetAttr
}
var _ dynamicInode = (*meminfoData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- mf := d.k.MemoryFile()
+func (*meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ mf := k.MemoryFile()
mf.UpdateUsage()
snapshot, totalUsage := usage.MemoryAccounting.Copy()
totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
@@ -275,7 +304,7 @@ func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
//
// +stateify savable
type uptimeData struct {
- kernfs.DynamicBytesFile
+ dynamicBytesFileSetAttr
}
var _ dynamicInode = (*uptimeData)(nil)
@@ -294,17 +323,15 @@ func (*uptimeData) Generate(ctx context.Context, buf *bytes.Buffer) error {
//
// +stateify savable
type versionData struct {
- kernfs.DynamicBytesFile
-
- // k is the owning Kernel.
- k *kernel.Kernel
+ dynamicBytesFileSetAttr
}
var _ dynamicInode = (*versionData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
-func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- init := v.k.GlobalInit()
+func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ init := k.GlobalInit()
if init == nil {
// Attempted to read before the init Task is created. This can
// only occur during startup, which should never need to read
@@ -335,3 +362,19 @@ func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
return nil
}
+
+// filesystemsData backs /proc/filesystems.
+//
+// +stateify savable
+type filesystemsData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*filesystemsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ k.VFS().GenerateProcFilesystems(buf)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index c5d531fe0..d0f97c137 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -47,10 +48,11 @@ var (
var (
tasksStaticFiles = map[string]testutil.DirentType{
"cpuinfo": linux.DT_REG,
+ "filesystems": linux.DT_REG,
"loadavg": linux.DT_REG,
"meminfo": linux.DT_REG,
"mounts": linux.DT_LNK,
- "net": linux.DT_DIR,
+ "net": linux.DT_LNK,
"self": linux.DT_LNK,
"stat": linux.DT_REG,
"sys": linux.DT_DIR,
@@ -63,21 +65,29 @@ var (
"thread-self": threadSelfLink.NextOff,
}
taskStaticFiles = map[string]testutil.DirentType{
- "auxv": linux.DT_REG,
- "cgroup": linux.DT_REG,
- "cmdline": linux.DT_REG,
- "comm": linux.DT_REG,
- "environ": linux.DT_REG,
- "gid_map": linux.DT_REG,
- "io": linux.DT_REG,
- "maps": linux.DT_REG,
- "ns": linux.DT_DIR,
- "smaps": linux.DT_REG,
- "stat": linux.DT_REG,
- "statm": linux.DT_REG,
- "status": linux.DT_REG,
- "task": linux.DT_DIR,
- "uid_map": linux.DT_REG,
+ "auxv": linux.DT_REG,
+ "cgroup": linux.DT_REG,
+ "cmdline": linux.DT_REG,
+ "comm": linux.DT_REG,
+ "environ": linux.DT_REG,
+ "exe": linux.DT_LNK,
+ "fd": linux.DT_DIR,
+ "fdinfo": linux.DT_DIR,
+ "gid_map": linux.DT_REG,
+ "io": linux.DT_REG,
+ "maps": linux.DT_REG,
+ "mountinfo": linux.DT_REG,
+ "mounts": linux.DT_REG,
+ "net": linux.DT_DIR,
+ "ns": linux.DT_DIR,
+ "oom_score": linux.DT_REG,
+ "oom_score_adj": linux.DT_REG,
+ "smaps": linux.DT_REG,
+ "stat": linux.DT_REG,
+ "statm": linux.DT_REG,
+ "status": linux.DT_REG,
+ "task": linux.DT_DIR,
+ "uid_map": linux.DT_REG,
}
)
@@ -93,17 +103,37 @@ func setup(t *testing.T) *testutil.System {
k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- fsOpts := vfs.GetFilesystemOptions{
- InternalData: &InternalData{
- Cgroups: map[string]string{
- "cpuset": "/foo/cpuset",
- "memory": "/foo/memory",
+
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("NewMountNamespace(): %v", err)
+ }
+ pop := &vfs.PathOperation{
+ Root: mntns.Root(),
+ Start: mntns.Root(),
+ Path: fspath.Parse("/proc"),
+ }
+ if err := k.VFS().MkdirAt(ctx, creds, pop, &vfs.MkdirOptions{Mode: 0777}); err != nil {
+ t.Fatalf("MkDir(/proc): %v", err)
+ }
+
+ pop = &vfs.PathOperation{
+ Root: mntns.Root(),
+ Start: mntns.Root(),
+ Path: fspath.Parse("/proc"),
+ }
+ mntOpts := &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: &InternalData{
+ Cgroups: map[string]string{
+ "cpuset": "/foo/cpuset",
+ "memory": "/foo/memory",
+ },
},
},
}
- mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", Name, &fsOpts)
- if err != nil {
- t.Fatalf("NewMountNamespace(): %v", err)
+ if err := k.VFS().MountAt(ctx, creds, "", pop, Name, mntOpts); err != nil {
+ t.Fatalf("MountAt(/proc): %v", err)
}
return testutil.NewSystem(ctx, t, k.VFS(), mntns)
}
@@ -112,7 +142,7 @@ func TestTasksEmpty(t *testing.T) {
s := setup(t)
defer s.Destroy()
- collector := s.ListDirents(s.PathOpAtRoot("/"))
+ collector := s.ListDirents(s.PathOpAtRoot("/proc"))
s.AssertAllDirentTypes(collector, tasksStaticFiles)
s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs)
}
@@ -138,7 +168,7 @@ func TestTasks(t *testing.T) {
expectedDirents[fmt.Sprintf("%d", i+1)] = linux.DT_DIR
}
- collector := s.ListDirents(s.PathOpAtRoot("/"))
+ collector := s.ListDirents(s.PathOpAtRoot("/proc"))
s.AssertAllDirentTypes(collector, expectedDirents)
s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs)
@@ -178,7 +208,7 @@ func TestTasks(t *testing.T) {
}
// Test lookup.
- for _, path := range []string{"/1", "/2"} {
+ for _, path := range []string{"/proc/1", "/proc/2"} {
fd, err := s.VFS.OpenAt(
s.Ctx,
s.Creds,
@@ -188,6 +218,7 @@ func TestTasks(t *testing.T) {
if err != nil {
t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err)
}
+ defer fd.DecRef()
buf := make([]byte, 1)
bufIOSeq := usermem.BytesIOSequence(buf)
if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR {
@@ -198,10 +229,10 @@ func TestTasks(t *testing.T) {
if _, err := s.VFS.OpenAt(
s.Ctx,
s.Creds,
- s.PathOpAtRoot("/9999"),
+ s.PathOpAtRoot("/proc/9999"),
&vfs.OpenOptions{},
); err != syserror.ENOENT {
- t.Fatalf("wrong error from vfsfs.OpenAt(/9999): %v", err)
+ t.Fatalf("wrong error from vfsfs.OpenAt(/proc/9999): %v", err)
}
}
@@ -299,12 +330,13 @@ func TestTasksOffset(t *testing.T) {
fd, err := s.VFS.OpenAt(
s.Ctx,
s.Creds,
- s.PathOpAtRoot("/"),
+ s.PathOpAtRoot("/proc"),
&vfs.OpenOptions{},
)
if err != nil {
t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
}
+ defer fd.DecRef()
if _, err := fd.Seek(s.Ctx, tc.offset, linux.SEEK_SET); err != nil {
t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err)
}
@@ -341,7 +373,7 @@ func TestTask(t *testing.T) {
t.Fatalf("CreateTask(): %v", err)
}
- collector := s.ListDirents(s.PathOpAtRoot("/1"))
+ collector := s.ListDirents(s.PathOpAtRoot("/proc/1"))
s.AssertAllDirentTypes(collector, taskStaticFiles)
}
@@ -359,14 +391,14 @@ func TestProcSelf(t *testing.T) {
collector := s.WithTemporaryContext(task).ListDirents(&vfs.PathOperation{
Root: s.Root,
Start: s.Root,
- Path: fspath.Parse("/self/"),
+ Path: fspath.Parse("/proc/self/"),
FollowFinalSymlink: true,
})
s.AssertAllDirentTypes(collector, taskStaticFiles)
}
func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.FileDescription) {
- t.Logf("Iterating: /proc%s", fd.MappedName(ctx))
+ t.Logf("Iterating: %s", fd.MappedName(ctx))
var collector testutil.DirentCollector
if err := fd.IterDirents(ctx, &collector); err != nil {
@@ -409,6 +441,7 @@ func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.F
t.Errorf("vfsfs.OpenAt(%v) failed: %v", childPath, err)
continue
}
+ defer child.DecRef()
stat, err := child.Stat(ctx, vfs.StatOptions{})
if err != nil {
t.Errorf("Stat(%v) failed: %v", childPath, err)
@@ -429,6 +462,22 @@ func TestTree(t *testing.T) {
defer s.Destroy()
k := kernel.KernelFromContext(s.Ctx)
+
+ pop := &vfs.PathOperation{
+ Root: s.Root,
+ Start: s.Root,
+ Path: fspath.Parse("test-file"),
+ }
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_CREAT,
+ Mode: 0777,
+ }
+ file, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, opts)
+ if err != nil {
+ t.Fatalf("failed to create test file: %v", err)
+ }
+ defer file.DecRef()
+
var tasks []*kernel.Task
for i := 0; i < 5; i++ {
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
@@ -436,6 +485,8 @@ func TestTree(t *testing.T) {
if err != nil {
t.Fatalf("CreateTask(): %v", err)
}
+ // Add file to populate /proc/[pid]/fd and fdinfo directories.
+ task.FDTable().NewFDVFS2(task, 0, file, kernel.FDFlags{})
tasks = append(tasks, task)
}
@@ -443,11 +494,12 @@ func TestTree(t *testing.T) {
fd, err := s.VFS.OpenAt(
ctx,
auth.CredentialsFromContext(s.Ctx),
- &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse("/")},
+ &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse("/proc")},
&vfs.OpenOptions{},
)
if err != nil {
- t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
+ t.Fatalf("vfsfs.OpenAt(/proc) failed: %v", err)
}
iterateDir(ctx, t, s, fd)
+ fd.DecRef()
}
diff --git a/pkg/sentry/fsimpl/sockfs/BUILD b/pkg/sentry/fsimpl/sockfs/BUILD
new file mode 100644
index 000000000..52084ddb5
--- /dev/null
+++ b/pkg/sentry/fsimpl/sockfs/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "sockfs",
+ srcs = ["sockfs.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
new file mode 100644
index 000000000..3f7ad1d65
--- /dev/null
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sockfs provides a filesystem implementation for anonymous sockets.
+package sockfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// NewFilesystem creates a new sockfs filesystem.
+//
+// Note that there should only ever be one instance of sockfs.Filesystem,
+// backing a global socket mount.
+func NewFilesystem(vfsObj *vfs.VirtualFilesystem) *vfs.Filesystem {
+ fs, _, err := filesystemType{}.GetFilesystem(nil, vfsObj, nil, "", vfs.GetFilesystemOptions{})
+ if err != nil {
+ panic("failed to create sockfs filesystem")
+ }
+ return fs
+}
+
+// filesystemType implements vfs.FilesystemType.
+type filesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (fsType filesystemType) GetFilesystem(_ context.Context, vfsObj *vfs.VirtualFilesystem, _ *auth.Credentials, _ string, _ vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ fs := &filesystem{}
+ fs.Init(vfsObj, fsType)
+ return fs.VFSFilesystem(), nil, nil
+}
+
+// Name implements FilesystemType.Name.
+//
+// Note that registering sockfs is unnecessary, except for the fact that it
+// will not show up under /proc/filesystems as a result. This is a very minor
+// discrepancy from Linux.
+func (filesystemType) Name() string {
+ return "sockfs"
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+}
+
+// inode implements kernfs.Inode.
+//
+// TODO(gvisor.dev/issue/1476): Add device numbers to this inode (which are
+// not included in InodeAttrs) to store the numbers of the appropriate
+// socket device. Override InodeAttrs.Stat() accordingly.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return nil, syserror.ENXIO
+}
+
+// InitSocket initializes a socket FileDescription, with a corresponding
+// Dentry in mnt.
+//
+// fd should be the FileDescription associated with socketImpl, i.e. its first
+// field. mnt should be the global socket mount, Kernel.socketMount.
+func InitSocket(socketImpl vfs.FileDescriptionImpl, fd *vfs.FileDescription, mnt *vfs.Mount, creds *auth.Credentials) error {
+ fsimpl := mnt.Filesystem().Impl()
+ fs := fsimpl.(*kernfs.Filesystem)
+
+ // File mode matches net/socket.c:sock_alloc.
+ filemode := linux.FileMode(linux.S_IFSOCK | 0600)
+ i := &inode{}
+ i.Init(creds, fs.NextIno(), filemode)
+
+ d := &kernfs.Dentry{}
+ d.Init(i)
+
+ opts := &vfs.FileDescriptionOptions{UseDentryMetadata: true}
+ if err := fd.Init(socketImpl, linux.O_RDWR, mnt, d.VFSDentry(), opts); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index c36c4fa11..5c617270e 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -39,10 +39,15 @@ type filesystem struct {
kernfs.Filesystem
}
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
-func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
fs := &filesystem{}
- fs.Filesystem.Init(vfsObj)
+ fs.Filesystem.Init(vfsObj, &fsType)
k := kernel.KernelFromContext(ctx)
maxCPUCores := k.ApplicationCores()
defaultSysDirMode := linux.FileMode(0755)
@@ -94,15 +99,17 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte
return &d.dentry
}
-// SetStat implements kernfs.Inode.SetStat.
-func (d *dir) SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error {
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
}
// Open implements kernfs.Inode.Open.
func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
+ if err := fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts); err != nil {
+ return nil, err
+ }
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
index e4f36f4ae..0e4053a46 100644
--- a/pkg/sentry/fsimpl/testutil/BUILD
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -16,12 +16,14 @@ go_library(
"//pkg/cpuid",
"//pkg/fspath",
"//pkg/memutil",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/sched",
"//pkg/sentry/limits",
"//pkg/sentry/loader",
+ "//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
"//pkg/sentry/platform/kvm",
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index 488478e29..c16a36cdb 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -23,13 +23,16 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/loader"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/time"
@@ -123,10 +126,17 @@ func Boot() (*kernel.Kernel, error) {
// CreateTask creates a new bare bones task for tests.
func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns *vfs.MountNamespace, root, cwd vfs.VirtualDentry) (*kernel.Task, error) {
k := kernel.KernelFromContext(ctx)
+ exe, err := newFakeExecutable(ctx, k.VFS(), auth.CredentialsFromContext(ctx), root)
+ if err != nil {
+ return nil, err
+ }
+ m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation)
+ m.SetExecutable(fsbridge.NewVFSFile(exe))
+
config := &kernel.TaskConfig{
Kernel: k,
ThreadGroup: tc,
- TaskContext: &kernel.TaskContext{Name: name},
+ TaskContext: &kernel.TaskContext{Name: name, MemoryManager: m},
Credentials: auth.CredentialsFromContext(ctx),
NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()),
@@ -135,10 +145,25 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns
AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
MountNamespaceVFS2: mntns,
FSContext: kernel.NewFSContextVFS2(root, cwd, 0022),
+ FDTable: k.NewFDTable(),
}
return k.TaskSet().NewTask(config)
}
+func newFakeExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry) (*vfs.FileDescription, error) {
+ const name = "executable"
+ pop := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(name),
+ }
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_CREAT,
+ Mode: 0777,
+ }
+ return vfsObj.OpenAt(ctx, creds, pop, opts)
+}
+
func createMemoryFile() (*pgalloc.MemoryFile, error) {
const memfileName = "test-memory"
memfd, err := memutil.CreateMemFD(memfileName, 0)
diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go
index e16808c63..0556af877 100644
--- a/pkg/sentry/fsimpl/testutil/testutil.go
+++ b/pkg/sentry/fsimpl/testutil/testutil.go
@@ -162,6 +162,9 @@ func (s *System) ListDirents(pop *vfs.PathOperation) *DirentCollector {
// exactly the specified set of expected entries. AssertAllDirentTypes respects
// collector.skipDots, and implicitly checks for "." and ".." accordingly.
func (s *System) AssertAllDirentTypes(collector *DirentCollector, expected map[string]DirentType) {
+ if expected == nil {
+ expected = make(map[string]DirentType)
+ }
// Also implicitly check for "." and "..", if enabled.
if !collector.skipDots {
expected["."] = linux.DT_DIR
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 57abd5583..4e6cd3491 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -24,6 +24,7 @@ go_library(
"filesystem.go",
"named_pipe.go",
"regular_file.go",
+ "socket_file.go",
"symlink.go",
"tmpfs.go",
],
@@ -46,9 +47,11 @@ go_library(
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
"//pkg/sentry/vfs/lock",
+ "//pkg/sentry/vfs/memxattr",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go
index 84b181b90..83bf885ee 100644
--- a/pkg/sentry/fsimpl/tmpfs/device_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/device_file.go
@@ -15,6 +15,8 @@
package tmpfs
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -33,6 +35,14 @@ func (fs *filesystem) newDeviceFile(creds *auth.Credentials, mode linux.FileMode
major: major,
minor: minor,
}
+ switch kind {
+ case vfs.BlockDevice:
+ mode |= linux.S_IFBLK
+ case vfs.CharDevice:
+ mode |= linux.S_IFCHR
+ default:
+ panic(fmt.Sprintf("invalid DeviceKind: %v", kind))
+ }
file.inode.init(file, fs, creds, mode)
file.inode.nlink = 1 // from parent directory
return &file.inode
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index b4380af38..45712c9b9 100644
--- a/pkg/sentry/fsimpl/tmpfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -34,16 +34,11 @@ type directory struct {
func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *inode {
dir := &directory{}
- dir.inode.init(dir, fs, creds, mode)
+ dir.inode.init(dir, fs, creds, linux.S_IFDIR|mode)
dir.inode.nlink = 2 // from "." and parent directory or ".." for root
return &dir.inode
}
-func (i *inode) isDir() bool {
- _, ok := i.impl.(*directory)
- return ok
-}
-
type directoryFD struct {
fileDescription
vfs.DirectoryFileDescriptionDefaultImpl
@@ -73,6 +68,8 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
fs.mu.Lock()
defer fs.mu.Unlock()
+ fd.inode().touchAtime(fd.vfsfd.Mount())
+
if fd.off == 0 {
if err := cb.Handle(vfs.Dirent{
Name: ".",
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index e1b551422..660f5a29b 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -20,6 +20,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -40,10 +42,13 @@ func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) {
if !d.inode.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
afterSymlink:
+ if len(rp.Component()) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
nextVFSD, err := rp.ResolveComponent(&d.vfsd)
if err != nil {
return nil, err
@@ -55,7 +60,7 @@ afterSymlink:
}
next := nextVFSD.Impl().(*dentry)
if symlink, ok := next.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
- // TODO(gvisor.dev/issues/1197): Symlink traversals updates
+ // TODO(gvisor.dev/issue/1197): Symlink traversals updates
// access time.
if err := rp.HandleSymlink(symlink.target); err != nil {
return nil, err
@@ -124,13 +129,16 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa
if err != nil {
return err
}
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
name := rp.Component()
if name == "." || name == ".." {
return syserror.EEXIST
}
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
// Call parent.vfsd.Child() instead of stepLocked() or rp.ResolveChild(),
// because if the child exists we want to return EEXIST immediately instead
// of attempting symlink/mount traversal.
@@ -151,7 +159,22 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa
return err
}
defer mnt.EndWrite()
- return create(parent, name)
+ if err := create(parent, name); err != nil {
+ return err
+ }
+ parent.inode.touchCMtime()
+ return nil
+}
+
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return err
+ }
+ return d.inode.checkPermissions(creds, ats)
}
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
@@ -166,7 +189,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
if !d.inode.isDir() {
return nil, syserror.ENOTDIR
}
- if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true /* isDir */); err != nil {
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
}
@@ -238,8 +261,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
case linux.S_IFCHR:
childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor)
case linux.S_IFSOCK:
- // Not yet supported.
- return syserror.EPERM
+ childInode = fs.newSocketFile(rp.Credentials(), opts.Mode, opts.Endpoint)
default:
return syserror.EINVAL
}
@@ -289,7 +311,7 @@ afterTrailingSymlink:
return nil, err
}
// Check for search permission in the parent directory.
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
// Reject attempts to open directories with O_CREAT.
@@ -304,7 +326,7 @@ afterTrailingSymlink:
child, err := stepLocked(rp, parent)
if err == syserror.ENOENT {
// Already checked for searchability above; now check for writability.
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -315,7 +337,12 @@ afterTrailingSymlink:
child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode))
parent.vfsd.InsertChild(&child.vfsd, name)
parent.inode.impl.(*directory).childList.PushBack(child)
- return child.open(ctx, rp, &opts, true)
+ fd, err := child.open(ctx, rp, &opts, true)
+ if err != nil {
+ return nil, err
+ }
+ parent.inode.touchCMtime()
+ return fd, nil
}
if err != nil {
return nil, err
@@ -335,7 +362,7 @@ afterTrailingSymlink:
func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(opts)
if !afterCreate {
- if err := d.inode.checkPermissions(rp.Credentials(), ats, d.inode.isDir()); err != nil {
+ if err := d.inode.checkPermissions(rp.Credentials(), ats); err != nil {
return nil, err
}
}
@@ -365,9 +392,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
// Can't open symlinks without O_PATH (which is unimplemented).
return nil, syserror.ELOOP
case *namedPipe:
- return newNamedPipeFD(ctx, impl, rp, &d.vfsd, opts.Flags)
+ return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags)
case *deviceFile:
return rp.VirtualFilesystem().OpenDeviceSpecialFile(ctx, rp.Mount(), &d.vfsd, impl.kind, impl.major, impl.minor, opts)
+ case *socketFile:
+ return nil, syserror.ENXIO
default:
panic(fmt.Sprintf("unknown inode type: %T", d.inode.impl))
}
@@ -385,6 +414,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
if !ok {
return "", syserror.EINVAL
}
+ symlink.inode.touchAtime(rp.Mount())
return symlink.target, nil
}
@@ -416,7 +446,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
defer mnt.EndWrite()
oldParent := oldParentVD.Dentry().Impl().(*dentry)
- if err := oldParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := oldParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
// Call vfs.Dentry.Child() instead of stepLocked() or rp.ResolveChild(),
@@ -433,7 +463,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
if oldParent != newParent {
// Writability is needed to change renamed's "..".
- if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true /* isDir */); err != nil {
+ if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return err
}
}
@@ -443,7 +473,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
}
- if err := newParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := newParent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
replacedVFSD := newParent.vfsd.Child(newName)
@@ -502,7 +532,10 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
oldParent.inode.decLinksLocked()
newParent.inode.incLinksLocked()
}
- // TODO(gvisor.dev/issues/1197): Update timestamps and parent directory
+ oldParent.inode.touchCMtime()
+ newParent.inode.touchCMtime()
+ renamed.inode.touchCtime()
+ // TODO(gvisor.dev/issue/1197): Update timestamps and parent directory
// sizes.
vfsObj.CommitRenameReplaceDentry(renamedVFSD, &newParent.vfsd, newName, replacedVFSD)
return nil
@@ -516,7 +549,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
if err != nil {
return err
}
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
name := rp.Component()
@@ -552,6 +585,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
parent.inode.decLinksLocked() // from child's ".."
child.inode.decLinksLocked()
vfsObj.CommitDeleteDentry(childVFSD)
+ parent.inode.touchCMtime()
return nil
}
@@ -563,7 +597,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
if err != nil {
return err
}
- return d.inode.setStat(opts.Stat)
+ return d.inode.setStat(ctx, rp.Credentials(), &opts.Stat)
}
// StatAt implements vfs.FilesystemImpl.StatAt.
@@ -587,7 +621,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
if err != nil {
return linux.Statfs{}, err
}
- // TODO(gvisor.dev/issues/1197): Actually implement statfs.
+ // TODO(gvisor.dev/issue/1197): Actually implement statfs.
return linux.Statfs{}, syserror.ENOSYS
}
@@ -609,7 +643,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if err != nil {
return err
}
- if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true /* isDir */); err != nil {
+ if err := parent.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
name := rp.Component()
@@ -641,55 +675,68 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
parent.inode.impl.(*directory).childList.Remove(child)
child.inode.decLinksLocked()
vfsObj.CommitDeleteDentry(childVFSD)
+ parent.inode.touchCMtime()
return nil
}
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return nil, err
+ }
+ switch impl := d.inode.impl.(type) {
+ case *socketFile:
+ return impl.ep, nil
+ default:
+ return nil, syserror.ECONNREFUSED
+ }
+}
+
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- _, err := resolveLocked(rp)
+ d, err := resolveLocked(rp)
if err != nil {
return nil, err
}
- // TODO(b/127675828): support extended attributes
- return nil, syserror.ENOTSUP
+ return d.inode.listxattr(size)
}
// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- _, err := resolveLocked(rp)
+ d, err := resolveLocked(rp)
if err != nil {
return "", err
}
- // TODO(b/127675828): support extended attributes
- return "", syserror.ENOTSUP
+ return d.inode.getxattr(rp.Credentials(), &opts)
}
// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
fs.mu.RLock()
defer fs.mu.RUnlock()
- _, err := resolveLocked(rp)
+ d, err := resolveLocked(rp)
if err != nil {
return err
}
- // TODO(b/127675828): support extended attributes
- return syserror.ENOTSUP
+ return d.inode.setxattr(rp.Credentials(), &opts)
}
// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
fs.mu.RLock()
defer fs.mu.RUnlock()
- _, err := resolveLocked(rp)
+ d, err := resolveLocked(rp)
if err != nil {
return err
}
- // TODO(b/127675828): support extended attributes
- return syserror.ENOTSUP
+ return d.inode.removexattr(rp.Credentials(), name)
}
// PrependPath implements vfs.FilesystemImpl.PrependPath.
diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index 0c57fdca3..8d77b3fa8 100644
--- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -16,10 +16,8 @@ package tmpfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -33,27 +31,8 @@ type namedPipe struct {
// * fs.mu must be locked.
// * rp.Mount().CheckBeginWrite() has been called successfully.
func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode {
- file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)}
- file.inode.init(file, fs, creds, mode)
+ file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)}
+ file.inode.init(file, fs, creds, linux.S_IFIFO|mode)
file.inode.nlink = 1 // Only the parent has a link.
return &file.inode
}
-
-// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented
-// entirely via struct embedding.
-type namedPipeFD struct {
- fileDescription
-
- *pipe.VFSPipeFD
-}
-
-func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
- var err error
- var fd namedPipeFD
- fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, vfsd, &fd.vfsfd, flags)
- if err != nil {
- return nil, err
- }
- fd.vfsfd.Init(&fd, flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{})
- return &fd.vfsfd, nil
-}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 711442424..57e5e28ec 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -89,7 +89,7 @@ func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMod
file := &regularFile{
memFile: fs.memFile,
}
- file.inode.init(file, fs, creds, mode)
+ file.inode.init(file, fs, creds, linux.S_IFREG|mode)
file.inode.nlink = 1 // from parent directory
return &file.inode
}
@@ -286,7 +286,8 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
rw := getRegularFileReadWriter(f, offset)
n, err := dst.CopyOutFrom(ctx, rw)
putRegularFileReadWriter(rw)
- return int64(n), err
+ fd.inode().touchAtime(fd.vfsfd.Mount())
+ return n, err
}
// Read implements vfs.FileDescriptionImpl.Read.
@@ -308,14 +309,22 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
return 0, nil
}
f := fd.inode().impl.(*regularFile)
- end := offset + srclen
- if end < offset {
+ if end := offset + srclen; end < offset {
// Overflow.
return 0, syserror.EFBIG
}
+
+ var err error
+ srclen, err = vfs.CheckLimit(ctx, offset, srclen)
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(srclen)
+
f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
+ fd.inode().touchCMtimeLocked()
f.inode.mu.Unlock()
putRegularFileReadWriter(rw)
return n, err
diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/sentry/fsimpl/tmpfs/socket_file.go
index ad3cc24fa..25c2321af 100644
--- a/pkg/tcpip/packet_buffer_state.go
+++ b/pkg/sentry/fsimpl/tmpfs/socket_file.go
@@ -12,16 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package tcpip
+package tmpfs
-import "gvisor.dev/gvisor/pkg/tcpip/buffer"
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+)
-// beforeSave is invoked by stateify.
-func (pk *PacketBuffer) beforeSave() {
- // Non-Data fields may be slices of the Data field. This causes
- // problems for SR, so during save we make each header independent.
- pk.Header = pk.Header.DeepCopy()
- pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...)
- pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...)
- pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...)
+// socketFile is a socket (=S_IFSOCK) tmpfs file.
+type socketFile struct {
+ inode inode
+ ep transport.BoundEndpoint
+}
+
+func (fs *filesystem) newSocketFile(creds *auth.Credentials, mode linux.FileMode, ep transport.BoundEndpoint) *inode {
+ file := &socketFile{ep: ep}
+ file.inode.init(file, fs, creds, mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go
index ebe035dee..d4f59ee5b 100644
--- a/pkg/sentry/fsimpl/tmpfs/stat_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/stat_test.go
@@ -29,7 +29,7 @@ func TestStatAfterCreate(t *testing.T) {
mode := linux.FileMode(0644)
// Run with different file types.
- // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets.
+ // TODO(gvisor.dev/issue/1197): Also test symlinks and sockets.
for _, typ := range []string{"file", "dir", "pipe"} {
t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
var (
@@ -140,7 +140,7 @@ func TestSetStatAtime(t *testing.T) {
Mask: 0,
Atime: linux.NsecToStatxTimestamp(100),
}}); err != nil {
- t.Errorf("SetStat atime without mask failed: %v")
+ t.Errorf("SetStat atime without mask failed: %v", err)
}
// Atime should be unchanged.
if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
@@ -155,7 +155,7 @@ func TestSetStatAtime(t *testing.T) {
Atime: linux.NsecToStatxTimestamp(100),
}
if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
- t.Errorf("SetStat atime with mask failed: %v")
+ t.Errorf("SetStat atime with mask failed: %v", err)
}
if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
t.Errorf("Stat got error: %v", err)
@@ -169,7 +169,7 @@ func TestSetStat(t *testing.T) {
mode := linux.FileMode(0644)
// Run with different file types.
- // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets.
+ // TODO(gvisor.dev/issue/1197): Also test symlinks and sockets.
for _, typ := range []string{"file", "dir", "pipe"} {
t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
var (
@@ -205,7 +205,7 @@ func TestSetStat(t *testing.T) {
Mask: 0,
Atime: linux.NsecToStatxTimestamp(100),
}}); err != nil {
- t.Errorf("SetStat atime without mask failed: %v")
+ t.Errorf("SetStat atime without mask failed: %v", err)
}
// Atime should be unchanged.
if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
@@ -220,7 +220,7 @@ func TestSetStat(t *testing.T) {
Atime: linux.NsecToStatxTimestamp(100),
}
if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
- t.Errorf("SetStat atime with mask failed: %v")
+ t.Errorf("SetStat atime with mask failed: %v", err)
}
if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
t.Errorf("Stat got error: %v", err)
diff --git a/pkg/sentry/fsimpl/tmpfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go
index 5246aca84..47e075ed4 100644
--- a/pkg/sentry/fsimpl/tmpfs/symlink.go
+++ b/pkg/sentry/fsimpl/tmpfs/symlink.go
@@ -15,6 +15,7 @@
package tmpfs
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -27,7 +28,7 @@ func (fs *filesystem) newSymlink(creds *auth.Credentials, target string) *inode
link := &symlink{
target: target,
}
- link.inode.init(link, fs, creds, 0777)
+ link.inode.init(link, fs, creds, linux.S_IFLNK|0777)
link.inode.nlink = 1 // from parent directory
return &link.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 521206305..a59b24d45 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -27,6 +27,7 @@ package tmpfs
import (
"fmt"
"math"
+ "strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -37,6 +38,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/vfs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/memxattr"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -63,6 +65,27 @@ type filesystem struct {
nextInoMinusOne uint64 // accessed using atomic memory operations
}
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// FilesystemOpts is used to pass configuration data to tmpfs.
+type FilesystemOpts struct {
+ // RootFileType is the FileType of the filesystem root. Valid values
+ // are: S_IFDIR, S_IFREG, and S_IFLNK. Defaults to S_IFDIR.
+ RootFileType uint16
+
+ // RootSymlinkTarget is the target of the root symlink. Only valid if
+ // RootFileType == S_IFLNK.
+ RootSymlinkTarget string
+
+ // FilesystemType allows setting a different FilesystemType for this
+ // tmpfs filesystem. This allows tmpfs to "impersonate" other
+ // filesystems, like ramdiskfs and cgroupfs.
+ FilesystemType vfs.FilesystemType
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx)
@@ -74,9 +97,33 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
memFile: memFileProvider.MemoryFile(),
clock: clock,
}
- fs.vfsfs.Init(vfsObj, &fs)
- root := fs.newDentry(fs.newDirectory(creds, 01777))
- return &fs.vfsfs, &root.vfsd, nil
+
+ rootFileType := uint16(linux.S_IFDIR)
+ newFSType := vfs.FilesystemType(&fstype)
+ tmpfsOpts, ok := opts.InternalData.(FilesystemOpts)
+ if ok {
+ if tmpfsOpts.RootFileType != 0 {
+ rootFileType = tmpfsOpts.RootFileType
+ }
+ if tmpfsOpts.FilesystemType != nil {
+ newFSType = tmpfsOpts.FilesystemType
+ }
+ }
+
+ fs.vfsfs.Init(vfsObj, newFSType, &fs)
+
+ var root *inode
+ switch rootFileType {
+ case linux.S_IFREG:
+ root = fs.newRegularFile(creds, 0777)
+ case linux.S_IFLNK:
+ root = fs.newSymlink(creds, tmpfsOpts.RootSymlinkTarget)
+ case linux.S_IFDIR:
+ root = fs.newDirectory(creds, 01777)
+ default:
+ return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType)
+ }
+ return &fs.vfsfs, &fs.newDentry(root).vfsd, nil
}
// Release implements vfs.FilesystemImpl.Release.
@@ -141,10 +188,15 @@ type inode struct {
// filesystem.RmdirAt() drops the reference.
refs int64
+ // xattrs implements extended attributes.
+ //
+ // TODO(b/148380782): Support xattrs other than user.*
+ xattrs memxattr.SimpleExtendedAttributes
+
// Inode metadata. Writing multiple fields atomically requires holding
// mu, othewise atomic operations can be used.
mu sync.Mutex
- mode uint32 // excluding file type bits, which are based on impl
+ mode uint32 // file type and mode
nlink uint32 // protected by filesystem.mu instead of inode.mu
uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
gid uint32 // auth.KGID, but ...
@@ -168,6 +220,9 @@ type inode struct {
const maxLinks = math.MaxUint32
func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) {
+ if mode.FileType() == 0 {
+ panic("file type is required in FileMode")
+ }
i.clock = fs.clock
i.refs = 1
i.mode = uint32(mode)
@@ -242,8 +297,9 @@ func (i *inode) decRef() {
}
}
-func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error {
- return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid)))
+func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ return vfs.GenericCheckPermissions(creds, ats, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid)))
}
// Go won't inline this function, and returning linux.Statx (which is quite
@@ -266,51 +322,50 @@ func (i *inode) statTo(stat *linux.Statx) {
stat.Atime = linux.NsecToStatxTimestamp(i.atime)
stat.Ctime = linux.NsecToStatxTimestamp(i.ctime)
stat.Mtime = linux.NsecToStatxTimestamp(i.mtime)
- // TODO(gvisor.dev/issues/1197): Device number.
+ // TODO(gvisor.dev/issue/1197): Device number.
switch impl := i.impl.(type) {
case *regularFile:
- stat.Mode |= linux.S_IFREG
stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
stat.Size = uint64(atomic.LoadUint64(&impl.size))
// In tmpfs, this will be FileRangeSet.Span() / 512 (but also cached in
// a uint64 accessed using atomic memory operations to avoid taking
// locks).
stat.Blocks = allocatedBlocksForSize(stat.Size)
- case *directory:
- stat.Mode |= linux.S_IFDIR
case *symlink:
- stat.Mode |= linux.S_IFLNK
stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
stat.Size = uint64(len(impl.target))
stat.Blocks = allocatedBlocksForSize(stat.Size)
- case *namedPipe:
- stat.Mode |= linux.S_IFIFO
case *deviceFile:
- switch impl.kind {
- case vfs.BlockDevice:
- stat.Mode |= linux.S_IFBLK
- case vfs.CharDevice:
- stat.Mode |= linux.S_IFCHR
- }
stat.RdevMajor = impl.major
stat.RdevMinor = impl.minor
+ case *socketFile, *directory, *namedPipe:
+ // Nothing to do.
default:
panic(fmt.Sprintf("unknown inode type: %T", i.impl))
}
}
-func (i *inode) setStat(stat linux.Statx) error {
+func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error {
if stat.Mask == 0 {
return nil
}
+ if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME|linux.STATX_SIZE) != 0 {
+ return syserror.EPERM
+ }
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ return err
+ }
i.mu.Lock()
+ defer i.mu.Unlock()
var (
needsMtimeBump bool
needsCtimeBump bool
)
mask := stat.Mask
if mask&linux.STATX_MODE != 0 {
- atomic.StoreUint32(&i.mode, uint32(stat.Mode))
+ ft := atomic.LoadUint32(&i.mode) & linux.S_IFMT
+ atomic.StoreUint32(&i.mode, ft|uint32(stat.Mode&^linux.S_IFMT))
needsCtimeBump = true
}
if mask&linux.STATX_UID != 0 {
@@ -338,29 +393,41 @@ func (i *inode) setStat(stat linux.Statx) error {
return syserror.EINVAL
}
}
+ now := i.clock.Now().Nanoseconds()
if mask&linux.STATX_ATIME != 0 {
- atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped())
+ if stat.Atime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.atime, now)
+ } else {
+ atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped())
+ }
needsCtimeBump = true
}
if mask&linux.STATX_MTIME != 0 {
- atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped())
+ if stat.Mtime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.mtime, now)
+ } else {
+ atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped())
+ }
needsCtimeBump = true
// Ignore the mtime bump, since we just set it ourselves.
needsMtimeBump = false
}
if mask&linux.STATX_CTIME != 0 {
- atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped())
+ if stat.Ctime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.ctime, now)
+ } else {
+ atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped())
+ }
// Ignore the ctime bump, since we just set it ourselves.
needsCtimeBump = false
}
- now := i.clock.Now().Nanoseconds()
if needsMtimeBump {
atomic.StoreInt64(&i.mtime, now)
}
if needsCtimeBump {
atomic.StoreInt64(&i.ctime, now)
}
- i.mu.Unlock()
+
return nil
}
@@ -419,6 +486,8 @@ func (i *inode) direntType() uint8 {
return linux.DT_DIR
case *symlink:
return linux.DT_LNK
+ case *socketFile:
+ return linux.DT_SOCK
case *deviceFile:
switch impl.kind {
case vfs.BlockDevice:
@@ -433,6 +502,96 @@ func (i *inode) direntType() uint8 {
}
}
+func (i *inode) isDir() bool {
+ return linux.FileMode(i.mode).FileType() == linux.S_IFDIR
+}
+
+func (i *inode) touchAtime(mnt *vfs.Mount) {
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return
+ }
+ now := i.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.atime, now)
+ i.mu.Unlock()
+ mnt.EndWrite()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite().
+func (i *inode) touchCtime() {
+ now := i.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.ctime, now)
+ i.mu.Unlock()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite().
+func (i *inode) touchCMtime() {
+ now := i.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.mtime, now)
+ atomic.StoreInt64(&i.ctime, now)
+ i.mu.Unlock()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite() and holds
+// inode.mu.
+func (i *inode) touchCMtimeLocked() {
+ now := i.clock.Now().Nanoseconds()
+ atomic.StoreInt64(&i.mtime, now)
+ atomic.StoreInt64(&i.ctime, now)
+}
+
+func (i *inode) listxattr(size uint64) ([]string, error) {
+ return i.xattrs.Listxattr(size)
+}
+
+func (i *inode) getxattr(creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) {
+ if err := i.checkPermissions(creds, vfs.MayRead); err != nil {
+ return "", err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return "", syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return "", syserror.ENODATA
+ }
+ return i.xattrs.Getxattr(opts)
+}
+
+func (i *inode) setxattr(creds *auth.Credentials, opts *vfs.SetxattrOptions) error {
+ if err := i.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return i.xattrs.Setxattr(opts)
+}
+
+func (i *inode) removexattr(creds *auth.Credentials, name string) error {
+ if err := i.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return i.xattrs.Removexattr(name)
+}
+
+// Extended attributes in the user.* namespace are only supported for regular
+// files and directories.
+func (i *inode) userXattrSupported() bool {
+ filetype := linux.S_IFMT & atomic.LoadUint32(&i.mode)
+ return filetype == linux.S_IFREG || filetype == linux.S_IFDIR
+}
+
// fileDescription is embedded by tmpfs implementations of
// vfs.FileDescriptionImpl.
type fileDescription struct {
@@ -457,5 +616,26 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- return fd.inode().setStat(opts.Stat)
+ creds := auth.CredentialsFromContext(ctx)
+ return fd.inode().setStat(ctx, creds, &opts.Stat)
+}
+
+// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
+func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ return fd.inode().listxattr(size)
+}
+
+// Getxattr implements vfs.FileDescriptionImpl.Getxattr.
+func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) {
+ return fd.inode().getxattr(auth.CredentialsFromContext(ctx), &opts)
+}
+
+// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
+func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
+ return fd.inode().setxattr(auth.CredentialsFromContext(ctx), &opts)
+}
+
+// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
+func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
+ return fd.inode().removexattr(auth.CredentialsFromContext(ctx), name)
}
diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go
index c16667e7f..029af3025 100644
--- a/pkg/sentry/inet/namespace.go
+++ b/pkg/sentry/inet/namespace.go
@@ -23,7 +23,10 @@ type Namespace struct {
// creator allows kernel to create new network stack for network namespaces.
// If nil, no networking will function if network is namespaced.
- creator NetworkStackCreator
+ //
+ // At afterLoad(), creator will be used to create network stack. Stateify
+ // needs to wait for this field to be loaded before calling afterLoad().
+ creator NetworkStackCreator `state:"wait"`
// isRoot indicates whether this is the root network namespace.
isRoot bool
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index beba29a09..e47af66d6 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -169,6 +169,9 @@ go_library(
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
"//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/fsimpl/pipefs",
+ "//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/hostcpu",
"//pkg/sentry/inet",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 8bffb78fc..592650923 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -296,8 +296,10 @@ func (*readyCallback) Callback(w *waiter.Entry) {
e.waitingList.Remove(entry)
e.readyList.PushBack(entry)
entry.curList = &e.readyList
+ e.listsMu.Unlock()
e.Notify(waiter.EventIn)
+ return
}
e.listsMu.Unlock()
diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go
index a0d35d350..8e9f200d0 100644
--- a/pkg/sentry/kernel/epoll/epoll_state.go
+++ b/pkg/sentry/kernel/epoll/epoll_state.go
@@ -38,11 +38,14 @@ func (e *EventPoll) afterLoad() {
}
}
- for it := e.waitingList.Front(); it != nil; it = it.Next() {
- if it.id.File.Readiness(it.mask) != 0 {
- e.waitingList.Remove(it)
- e.readyList.PushBack(it)
- it.curList = &e.readyList
+ for it := e.waitingList.Front(); it != nil; {
+ entry := it
+ it = it.Next()
+
+ if entry.id.File.Readiness(entry.mask) != 0 {
+ e.waitingList.Remove(entry)
+ e.readyList.PushBack(entry)
+ entry.curList = &e.readyList
e.Notify(waiter.EventIn)
}
}
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 58001d56c..ed40b5303 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -191,10 +191,12 @@ func (f *FDTable) Size() int {
return int(size)
}
-// forEach iterates over all non-nil files.
+// forEach iterates over all non-nil files in sorted order.
//
// It is the caller's responsibility to acquire an appropriate lock.
func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) {
+ // retries tracks the number of failed TryIncRef attempts for the same FD.
+ retries := 0
fd := int32(0)
for {
file, fileVFS2, flags, ok := f.getAll(fd)
@@ -204,17 +206,26 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes
switch {
case file != nil:
if !file.TryIncRef() {
+ retries++
+ if retries > 1000 {
+ panic(fmt.Sprintf("File in FD table has been destroyed. FD: %d, File: %+v, FileOps: %+v", fd, file, file.FileOperations))
+ }
continue // Race caught.
}
fn(fd, file, nil, flags)
file.DecRef()
case fileVFS2 != nil:
if !fileVFS2.TryIncRef() {
+ retries++
+ if retries > 1000 {
+ panic(fmt.Sprintf("File in FD table has been destroyed. FD: %d, File: %+v, Impl: %+v", fd, fileVFS2, fileVFS2.Impl()))
+ }
continue // Race caught.
}
fn(fd, nil, fileVFS2, flags)
fileVFS2.DecRef()
}
+ retries = 0
fd++
}
}
@@ -296,6 +307,61 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return fds, nil
}
+// NewFDsVFS2 allocates new FDs guaranteed to be the lowest number available
+// greater than or equal to the fd parameter. All files will share the set
+// flags. Success is guaranteed to be all or none.
+func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return nil, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if fd >= end {
+ return nil, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // From f.next to find available fd.
+ if fd < f.next {
+ fd = f.next
+ }
+
+ // Install all entries.
+ for i := fd; i < end && len(fds) < len(files); i++ {
+ if d, _, _ := f.getVFS2(i); d == nil {
+ f.setVFS2(i, files[len(fds)], flags) // Set the descriptor.
+ fds = append(fds, i) // Record the file descriptor.
+ }
+ }
+
+ // Failure? Unwind existing FDs.
+ if len(fds) < len(files) {
+ for _, i := range fds {
+ f.setVFS2(i, nil, FDFlags{}) // Zap entry.
+ }
+ return nil, syscall.EMFILE
+ }
+
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fds[len(fds)-1] + 1
+ }
+
+ return fds, nil
+}
+
// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
// the given file description. If it succeeds, it takes a reference on file.
func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
@@ -327,7 +393,7 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc
fd = f.next
}
for fd < end {
- if d, _, _ := f.get(fd); d == nil {
+ if d, _, _ := f.getVFS2(fd); d == nil {
f.setVFS2(fd, file, flags)
if fd == f.next {
// Update next search start position.
@@ -447,7 +513,10 @@ func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) {
}
}
-// GetFDs returns a list of valid fds.
+// GetFDs returns a sorted list of valid fds.
+//
+// Precondition: The caller must be running on the task goroutine, or Task.mu
+// must be locked.
func (f *FDTable) GetFDs() []int32 {
fds := make([]int32, 0, int(atomic.LoadInt32(&f.used)))
f.forEach(func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) {
@@ -522,7 +591,9 @@ func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) {
case orig2 != nil:
orig2.IncRef()
}
- f.setAll(fd, nil, nil, FDFlags{}) // Zap entry.
+ if orig != nil || orig2 != nil {
+ f.setAll(fd, nil, nil, FDFlags{}) // Zap entry.
+ }
return orig, orig2
}
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 8b76750e9..fef60e636 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -50,6 +50,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/timerfd"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -225,6 +227,11 @@ type Kernel struct {
// by extMu.
nextSocketEntry uint64
+ // socketMount is a disconnected vfs.Mount, not included in k.vfs,
+ // representing a sockfs.filesystem. socketMount is used to back
+ // VirtualDentries representing anonymous sockets.
+ socketMount *vfs.Mount
+
// deviceRegistry is used to save/restore device.SimpleDevices.
deviceRegistry struct{} `state:".(*device.Registry)"`
@@ -248,6 +255,10 @@ type Kernel struct {
// VFS keeps the filesystem state used across the kernel.
vfs vfs.VirtualFilesystem
+ // pipeMount is the Mount used for pipes created by the pipe() and pipe2()
+ // syscalls (as opposed to named pipes created by mknod()).
+ pipeMount *vfs.Mount
+
// If set to true, report address space activation waits as if the task is in
// external wait so that the watchdog doesn't report the task stuck.
SleepForAddressSpaceActivation bool
@@ -348,6 +359,29 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic}
k.futexes = futex.NewManager()
k.netlinkPorts = port.New()
+
+ if VFS2Enabled {
+ if err := k.vfs.Init(); err != nil {
+ return fmt.Errorf("failed to initialize VFS: %v", err)
+ }
+
+ pipeFilesystem := pipefs.NewFilesystem(&k.vfs)
+ defer pipeFilesystem.DecRef()
+ pipeMount, err := k.vfs.NewDisconnectedMount(pipeFilesystem, nil, &vfs.MountOptions{})
+ if err != nil {
+ return fmt.Errorf("failed to create pipefs mount: %v", err)
+ }
+ k.pipeMount = pipeMount
+
+ socketFilesystem := sockfs.NewFilesystem(&k.vfs)
+ defer socketFilesystem.DecRef()
+ socketMount, err := k.vfs.NewDisconnectedMount(socketFilesystem, nil, &vfs.MountOptions{})
+ if err != nil {
+ return fmt.Errorf("failed to initialize socket mount: %v", err)
+ }
+ k.socketMount = socketMount
+ }
+
return nil
}
@@ -467,6 +501,11 @@ func (k *Kernel) flushMountSourceRefs() error {
//
// Precondition: Must be called with the kernel paused.
func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) (err error) {
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ if VFS2Enabled {
+ return nil
+ }
+
ts.mu.RLock()
defer ts.mu.RUnlock()
for t := range ts.Root.tids {
@@ -484,7 +523,7 @@ func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error)
}
func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
- // TODO(gvisor.dev/issues/1663): Add save support for VFS2.
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
return ts.forEachFDPaused(func(file *fs.File, _ *vfs.FileDescription) error {
if flags := file.Flags(); !flags.Write {
return nil
@@ -533,17 +572,32 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
}
func (ts *TaskSet) unregisterEpollWaiters() {
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ if VFS2Enabled {
+ return
+ }
+
ts.mu.RLock()
defer ts.mu.RUnlock()
+
+ // Tasks that belong to the same process could potentially point to the
+ // same FDTable. So we retain a map of processed ones to avoid
+ // processing the same FDTable multiple times.
+ processed := make(map[*FDTable]struct{})
for t := range ts.Root.tids {
// We can skip locking Task.mu here since the kernel is paused.
- if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
- if e, ok := file.FileOperations.(*epoll.EventPoll); ok {
- e.UnregisterEpollWaiters()
- }
- })
+ if t.fdTable == nil {
+ continue
}
+ if _, ok := processed[t.fdTable]; ok {
+ continue
+ }
+ t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ if e, ok := file.FileOperations.(*epoll.EventPoll); ok {
+ e.UnregisterEpollWaiters()
+ }
+ })
+ processed[t.fdTable] = struct{}{}
}
}
@@ -755,6 +809,8 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} {
return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return ctx.k.RootNetworkNamespace().Stack()
case ktime.CtxRealtimeClock:
return ctx.k.RealtimeClock()
case limits.CtxLimits:
@@ -1003,9 +1059,15 @@ func (k *Kernel) pauseTimeLocked() {
// This means we'll iterate FDTables shared by multiple tasks repeatedly,
// but ktime.Timer.Pause is idempotent so this is harmless.
if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
- if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
- tfd.PauseTimer()
+ t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
+ if VFS2Enabled {
+ if tfd, ok := fd.Impl().(*vfs.TimerFileDescription); ok {
+ tfd.PauseTimer()
+ }
+ } else {
+ if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
+ tfd.PauseTimer()
+ }
}
})
}
@@ -1033,9 +1095,15 @@ func (k *Kernel) resumeTimeLocked() {
}
}
if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
- if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
- tfd.ResumeTimer()
+ t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
+ if VFS2Enabled {
+ if tfd, ok := fd.Impl().(*vfs.TimerFileDescription); ok {
+ tfd.ResumeTimer()
+ }
+ } else {
+ if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
+ tfd.ResumeTimer()
+ }
}
})
}
@@ -1398,9 +1466,10 @@ func (k *Kernel) SupervisorContext() context.Context {
// +stateify savable
type SocketEntry struct {
socketEntry
- k *Kernel
- Sock *refs.WeakRef
- ID uint64 // Socket table entry number.
+ k *Kernel
+ Sock *refs.WeakRef
+ SockVFS2 *vfs.FileDescription
+ ID uint64 // Socket table entry number.
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
@@ -1423,7 +1492,30 @@ func (k *Kernel) RecordSocket(sock *fs.File) {
k.extMu.Unlock()
}
+// RecordSocketVFS2 adds a VFS2 socket to the system-wide socket table for
+// tracking.
+//
+// Precondition: Caller must hold a reference to sock.
+//
+// Note that the socket table will not hold a reference on the
+// vfs.FileDescription, because we do not support weak refs on VFS2 files.
+func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) {
+ k.extMu.Lock()
+ id := k.nextSocketEntry
+ k.nextSocketEntry++
+ s := &SocketEntry{
+ k: k,
+ ID: id,
+ SockVFS2: sock,
+ }
+ k.sockets.PushBack(s)
+ k.extMu.Unlock()
+}
+
// ListSockets returns a snapshot of all sockets.
+//
+// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef()
+// to get a reference on a socket in the table.
func (k *Kernel) ListSockets() []*SocketEntry {
k.extMu.Lock()
var socks []*SocketEntry
@@ -1434,6 +1526,11 @@ func (k *Kernel) ListSockets() []*SocketEntry {
return socks
}
+// SocketMount returns the global socket mount.
+func (k *Kernel) SocketMount() *vfs.Mount {
+ return k.socketMount
+}
+
// supervisorContext is a privileged context.
type supervisorContext struct {
context.NoopSleeper
@@ -1481,6 +1578,8 @@ func (ctx supervisorContext) Value(key interface{}) interface{} {
return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return ctx.k.RootNetworkNamespace().Stack()
case ktime.CtxRealtimeClock:
return ctx.k.RealtimeClock()
case limits.CtxLimits:
@@ -1529,3 +1628,8 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
func (k *Kernel) VFS() *vfs.VirtualFilesystem {
return &k.vfs
}
+
+// PipeMount returns the pipefs mount.
+func (k *Kernel) PipeMount() *vfs.Mount {
+ return k.pipeMount
+}
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 4c049d5b4..f29dc0472 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -1,25 +1,10 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-go_template_instance(
- name = "buffer_list",
- out = "buffer_list.go",
- package = "pipe",
- prefix = "buffer",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*buffer",
- "Linker": "*buffer",
- },
-)
-
go_library(
name = "pipe",
srcs = [
- "buffer.go",
- "buffer_list.go",
"device.go",
"node.go",
"pipe.go",
@@ -33,8 +18,8 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/amutex",
+ "//pkg/buffer",
"//pkg/context",
- "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
@@ -51,7 +36,6 @@ go_test(
name = "pipe_test",
size = "small",
srcs = [
- "buffer_test.go",
"node_test.go",
"pipe_test.go",
],
diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go
deleted file mode 100644
index fe3be5dbd..000000000
--- a/pkg/sentry/kernel/pipe/buffer.go
+++ /dev/null
@@ -1,115 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package pipe
-
-import (
- "io"
-
- "gvisor.dev/gvisor/pkg/safemem"
- "gvisor.dev/gvisor/pkg/sync"
-)
-
-// buffer encapsulates a queueable byte buffer.
-//
-// Note that the total size is slightly less than two pages. This
-// is done intentionally to ensure that the buffer object aligns
-// with runtime internals. We have no hard size or alignment
-// requirements. This two page size will effectively minimize
-// internal fragmentation, but still have a large enough chunk
-// to limit excessive segmentation.
-//
-// +stateify savable
-type buffer struct {
- data [8144]byte
- read int
- write int
- bufferEntry
-}
-
-// Reset resets internal data.
-//
-// This must be called before use.
-func (b *buffer) Reset() {
- b.read = 0
- b.write = 0
-}
-
-// Empty indicates the buffer is empty.
-//
-// This indicates there is no data left to read.
-func (b *buffer) Empty() bool {
- return b.read == b.write
-}
-
-// Full indicates the buffer is full.
-//
-// This indicates there is no capacity left to write.
-func (b *buffer) Full() bool {
- return b.write == len(b.data)
-}
-
-// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
-func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.write:]))
- n, err := safemem.CopySeq(dst, srcs)
- b.write += int(n)
- return n, err
-}
-
-// WriteFromReader writes to the buffer from an io.Reader.
-func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) {
- dst := b.data[b.write:]
- if count < int64(len(dst)) {
- dst = b.data[b.write:][:count]
- }
- n, err := r.Read(dst)
- b.write += n
- return int64(n), err
-}
-
-// ReadToBlocks implements safemem.Reader.ReadToBlocks.
-func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write]))
- n, err := safemem.CopySeq(dsts, src)
- b.read += int(n)
- return n, err
-}
-
-// ReadToWriter reads from the buffer into an io.Writer.
-func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) {
- src := b.data[b.read:b.write]
- if count < int64(len(src)) {
- src = b.data[b.read:][:count]
- }
- n, err := w.Write(src)
- if !dup {
- b.read += n
- }
- return int64(n), err
-}
-
-// bufferPool is a pool for buffers.
-var bufferPool = sync.Pool{
- New: func() interface{} {
- return new(buffer)
- },
-}
-
-// newBuffer grabs a new buffer from the pool.
-func newBuffer() *buffer {
- b := bufferPool.Get().(*buffer)
- b.Reset()
- return b
-}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 08410283f..62c8691f1 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -20,6 +20,7 @@ import (
"sync/atomic"
"syscall"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sync"
@@ -70,10 +71,10 @@ type Pipe struct {
// mu protects all pipe internal state below.
mu sync.Mutex `state:"nosave"`
- // data is the buffer queue of pipe contents.
+ // view is the underlying set of buffers.
//
// This is protected by mu.
- data bufferList
+ view buffer.View
// max is the maximum size of the pipe in bytes. When this max has been
// reached, writers will get EWOULDBLOCK.
@@ -81,11 +82,6 @@ type Pipe struct {
// This is protected by mu.
max int64
- // size is the current size of the pipe in bytes.
- //
- // This is protected by mu.
- size int64
-
// hadWriter indicates if this pipe ever had a writer. Note that this
// does not necessarily indicate there is *currently* a writer, just
// that there has been a writer at some point since the pipe was
@@ -196,7 +192,7 @@ type readOps struct {
limit func(int64)
// read performs the actual read operation.
- read func(*buffer) (int64, error)
+ read func(*buffer.View) (int64, error)
}
// read reads data from the pipe into dst and returns the number of bytes
@@ -213,7 +209,7 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
defer p.mu.Unlock()
// Is the pipe empty?
- if p.size == 0 {
+ if p.view.Size() == 0 {
if !p.HasWriters() {
// There are no writers, return EOF.
return 0, nil
@@ -222,71 +218,13 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
}
// Limit how much we consume.
- if ops.left() > p.size {
- ops.limit(p.size)
- }
-
- done := int64(0)
- for ops.left() > 0 {
- // Pop the first buffer.
- first := p.data.Front()
- if first == nil {
- break
- }
-
- // Copy user data.
- n, err := ops.read(first)
- done += int64(n)
- p.size -= n
-
- // Empty buffer?
- if first.Empty() {
- // Push to the free list.
- p.data.Remove(first)
- bufferPool.Put(first)
- }
-
- // Handle errors.
- if err != nil {
- return done, err
- }
- }
-
- return done, nil
-}
-
-// dup duplicates all data from this pipe into the given writer.
-//
-// There is no blocking behavior implemented here. The writer may propagate
-// some blocking error. All the writes must be complete writes.
-func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- // Is the pipe empty?
- if p.size == 0 {
- if !p.HasWriters() {
- // See above.
- return 0, nil
- }
- return 0, syserror.ErrWouldBlock
- }
-
- // Limit how much we consume.
- if ops.left() > p.size {
- ops.limit(p.size)
+ if ops.left() > p.view.Size() {
+ ops.limit(p.view.Size())
}
- done := int64(0)
- for buf := p.data.Front(); buf != nil; buf = buf.Next() {
- n, err := ops.read(buf)
- done += n
- if err != nil {
- return done, err
- }
- }
-
- return done, nil
+ // Copy user data; the read op is responsible for trimming.
+ done, err := ops.read(&p.view)
+ return done, err
}
type writeOps struct {
@@ -297,7 +235,7 @@ type writeOps struct {
limit func(int64)
// write should write to the provided buffer.
- write func(*buffer) (int64, error)
+ write func(*buffer.View) (int64, error)
}
// write writes data from sv into the pipe and returns the number of bytes
@@ -317,35 +255,28 @@ func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
// POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
// atomic, but requires no atomicity for writes larger than this.
wanted := ops.left()
- if avail := p.max - p.size; wanted > avail {
+ avail := p.max - p.view.Size()
+ if wanted > avail {
if wanted <= p.atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
ops.limit(avail)
}
- done := int64(0)
- for ops.left() > 0 {
- // Need a new buffer?
- last := p.data.Back()
- if last == nil || last.Full() {
- // Add a new buffer to the data list.
- last = newBuffer()
- p.data.PushBack(last)
- }
-
- // Copy user data.
- n, err := ops.write(last)
- done += int64(n)
- p.size += n
+ // Copy user data.
+ done, err := ops.write(&p.view)
+ if err != nil {
+ return done, err
+ }
- // Handle errors.
- if err != nil {
- return done, err
- }
+ if done < avail {
+ // Non-failure, but short write.
+ return done, nil
}
- if wanted > done {
- // Partial write due to full pipe.
+ if done < wanted {
+ // Partial write due to full pipe. Note that this could also be
+ // the short write case above, we would expect a second call
+ // and the write to return zero bytes in this case.
return done, syserror.ErrWouldBlock
}
@@ -396,7 +327,7 @@ func (p *Pipe) HasWriters() bool {
// Precondition: mu must be held.
func (p *Pipe) rReadinessLocked() waiter.EventMask {
ready := waiter.EventMask(0)
- if p.HasReaders() && p.data.Front() != nil {
+ if p.HasReaders() && p.view.Size() != 0 {
ready |= waiter.EventIn
}
if !p.HasWriters() && p.hadWriter {
@@ -422,7 +353,7 @@ func (p *Pipe) rReadiness() waiter.EventMask {
// Precondition: mu must be held.
func (p *Pipe) wReadinessLocked() waiter.EventMask {
ready := waiter.EventMask(0)
- if p.HasWriters() && p.size < p.max {
+ if p.HasWriters() && p.view.Size() < p.max {
ready |= waiter.EventOut
}
if !p.HasReaders() {
@@ -451,7 +382,7 @@ func (p *Pipe) rwReadiness() waiter.EventMask {
func (p *Pipe) queued() int64 {
p.mu.Lock()
defer p.mu.Unlock()
- return p.size
+ return p.view.Size()
}
// FifoSize implements fs.FifoSizer.FifoSize.
@@ -474,7 +405,7 @@ func (p *Pipe) SetFifoSize(size int64) (int64, error) {
}
p.mu.Lock()
defer p.mu.Unlock()
- if size < p.size {
+ if size < p.view.Size() {
return 0, syserror.EBUSY
}
p.max = size
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index 80158239e..5a1d4fd57 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sync"
@@ -49,9 +50,10 @@ func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error)
limit: func(l int64) {
dst = dst.TakeFirst64(l)
},
- read: func(buf *buffer) (int64, error) {
- n, err := dst.CopyOutFrom(ctx, buf)
+ read: func(view *buffer.View) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, view)
dst = dst.DropFirst64(n)
+ view.TrimFront(n)
return n, err
},
})
@@ -70,16 +72,15 @@ func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool)
limit: func(l int64) {
count = l
},
- read: func(buf *buffer) (int64, error) {
- n, err := buf.ReadToWriter(w, count, dup)
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadToWriter(w, count)
+ if !dup {
+ view.TrimFront(n)
+ }
count -= n
return n, err
},
}
- if dup {
- // There is no notification for dup operations.
- return p.dup(ctx, ops)
- }
n, err := p.read(ctx, ops)
if n > 0 {
p.Notify(waiter.EventOut)
@@ -96,8 +97,8 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error)
limit: func(l int64) {
src = src.TakeFirst64(l)
},
- write: func(buf *buffer) (int64, error) {
- n, err := src.CopyInTo(ctx, buf)
+ write: func(view *buffer.View) (int64, error) {
+ n, err := src.CopyInTo(ctx, view)
src = src.DropFirst64(n)
return n, err
},
@@ -117,8 +118,8 @@ func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, e
limit: func(l int64) {
count = l
},
- write: func(buf *buffer) (int64, error) {
- n, err := buf.WriteFromReader(r, count)
+ write: func(view *buffer.View) (int64, error) {
+ n, err := view.WriteFromReader(r, count)
count -= n
return n, err
},
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index a5675bd70..b54f08a30 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -49,38 +49,42 @@ type VFSPipe struct {
}
// NewVFSPipe returns an initialized VFSPipe.
-func NewVFSPipe(sizeBytes, atomicIOBytes int64) *VFSPipe {
+func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe {
var vp VFSPipe
- initPipe(&vp.pipe, true /* isNamed */, sizeBytes, atomicIOBytes)
+ initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes)
return &vp
}
-// NewVFSPipeFD opens a named pipe. Named pipes have special blocking semantics
-// during open:
+// ReaderWriterPair returns read-only and write-only FDs for vp.
//
-// "Normally, opening the FIFO blocks until the other end is opened also. A
-// process can open a FIFO in nonblocking mode. In this case, opening for
-// read-only will succeed even if no-one has opened on the write side yet,
-// opening for write-only will fail with ENXIO (no such device or address)
-// unless the other end has already been opened. Under Linux, opening a FIFO
-// for read and write will succeed both in blocking and nonblocking mode. POSIX
-// leaves this behavior undefined. This can be used to open a FIFO for writing
-// while there are no readers available." - fifo(7)
-func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
+// Preconditions: statusFlags should not contain an open access mode.
+func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+ return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags)
+}
+
+// Open opens the pipe represented by vp.
+func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, error) {
vp.mu.Lock()
defer vp.mu.Unlock()
- readable := vfs.MayReadFileWithOpenFlags(flags)
- writable := vfs.MayWriteFileWithOpenFlags(flags)
+ readable := vfs.MayReadFileWithOpenFlags(statusFlags)
+ writable := vfs.MayWriteFileWithOpenFlags(statusFlags)
if !readable && !writable {
return nil, syserror.EINVAL
}
- vfd, err := vp.open(vfsd, vfsfd, flags)
- if err != nil {
- return nil, err
- }
+ fd := vp.newFD(mnt, vfsd, statusFlags)
+ // Named pipes have special blocking semantics during open:
+ //
+ // "Normally, opening the FIFO blocks until the other end is opened also. A
+ // process can open a FIFO in nonblocking mode. In this case, opening for
+ // read-only will succeed even if no-one has opened on the write side yet,
+ // opening for write-only will fail with ENXIO (no such device or address)
+ // unless the other end has already been opened. Under Linux, opening a
+ // FIFO for read and write will succeed both in blocking and nonblocking
+ // mode. POSIX leaves this behavior undefined. This can be used to open a
+ // FIFO for writing while there are no readers available." - fifo(7)
switch {
case readable && writable:
// Pipes opened for read-write always succeed without blocking.
@@ -89,23 +93,26 @@ func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, vfsd *vfs.Dentry, vfsfd *vf
case readable:
newHandleLocked(&vp.rWakeup)
- // If this pipe is being opened as nonblocking and there's no
+ // If this pipe is being opened as blocking and there's no
// writer, we have to wait for a writer to open the other end.
- if flags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
+ if vp.pipe.isNamed && statusFlags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
+ fd.DecRef()
return nil, syserror.EINTR
}
case writable:
newHandleLocked(&vp.wWakeup)
- if !vp.pipe.HasReaders() {
- // Nonblocking, write-only opens fail with ENXIO when
- // the read side isn't open yet.
- if flags&linux.O_NONBLOCK != 0 {
+ if vp.pipe.isNamed && !vp.pipe.HasReaders() {
+ // Non-blocking, write-only opens fail with ENXIO when the read
+ // side isn't open yet.
+ if statusFlags&linux.O_NONBLOCK != 0 {
+ fd.DecRef()
return nil, syserror.ENXIO
}
// Wait for a reader to open the other end.
if !waitFor(&vp.mu, &vp.rWakeup, ctx) {
+ fd.DecRef()
return nil, syserror.EINTR
}
}
@@ -114,96 +121,93 @@ func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, vfsd *vfs.Dentry, vfsfd *vf
panic("invalid pipe flags: must be readable, writable, or both")
}
- return vfd, nil
+ return fd, nil
}
// Preconditions: vp.mu must be held.
-func (vp *VFSPipe) open(vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
- var fd VFSPipeFD
- fd.flags = flags
- fd.readable = vfs.MayReadFileWithOpenFlags(flags)
- fd.writable = vfs.MayWriteFileWithOpenFlags(flags)
- fd.vfsfd = vfsfd
- fd.pipe = &vp.pipe
+func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) *vfs.FileDescription {
+ fd := &VFSPipeFD{
+ pipe: &vp.pipe,
+ }
+ fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ })
switch {
- case fd.readable && fd.writable:
+ case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable():
vp.pipe.rOpen()
vp.pipe.wOpen()
- case fd.readable:
+ case fd.vfsfd.IsReadable():
vp.pipe.rOpen()
- case fd.writable:
+ case fd.vfsfd.IsWritable():
vp.pipe.wOpen()
default:
panic("invalid pipe flags: must be readable, writable, or both")
}
- return &fd, nil
+ return &fd.vfsfd
}
-// VFSPipeFD implements a subset of vfs.FileDescriptionImpl for pipes. It is
-// expected that filesystesm will use this in a struct implementing
-// vfs.FileDescriptionImpl.
+// VFSPipeFD implements vfs.FileDescriptionImpl for pipes.
type VFSPipeFD struct {
- pipe *Pipe
- flags uint32
- readable bool
- writable bool
- vfsfd *vfs.FileDescription
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+
+ pipe *Pipe
}
// Release implements vfs.FileDescriptionImpl.Release.
func (fd *VFSPipeFD) Release() {
var event waiter.EventMask
- if fd.readable {
+ if fd.vfsfd.IsReadable() {
fd.pipe.rClose()
- event |= waiter.EventIn
+ event |= waiter.EventOut
}
- if fd.writable {
+ if fd.vfsfd.IsWritable() {
fd.pipe.wClose()
- event |= waiter.EventOut
+ event |= waiter.EventIn | waiter.EventHUp
}
if event == 0 {
panic("invalid pipe flags: must be readable, writable, or both")
}
- if fd.writable {
- fd.vfsfd.VirtualDentry().Mount().EndWrite()
- }
-
fd.pipe.Notify(event)
}
-// OnClose implements vfs.FileDescriptionImpl.OnClose.
-func (fd *VFSPipeFD) OnClose(_ context.Context) error {
- return nil
+// Readiness implements waiter.Waitable.Readiness.
+func (fd *VFSPipeFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ switch {
+ case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable():
+ return fd.pipe.rwReadiness()
+ case fd.vfsfd.IsReadable():
+ return fd.pipe.rReadiness()
+ case fd.vfsfd.IsWritable():
+ return fd.pipe.wReadiness()
+ default:
+ panic("pipe FD is neither readable nor writable")
+ }
}
-// PRead implements vfs.FileDescriptionImpl.PRead.
-func (fd *VFSPipeFD) PRead(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.ReadOptions) (int64, error) {
- return 0, syserror.ESPIPE
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *VFSPipeFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.pipe.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *VFSPipeFD) EventUnregister(e *waiter.Entry) {
+ fd.pipe.EventUnregister(e)
}
// Read implements vfs.FileDescriptionImpl.Read.
func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
- if !fd.readable {
- return 0, syserror.EINVAL
- }
-
return fd.pipe.Read(ctx, dst)
}
-// PWrite implements vfs.FileDescriptionImpl.PWrite.
-func (fd *VFSPipeFD) PWrite(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.WriteOptions) (int64, error) {
- return 0, syserror.ESPIPE
-}
-
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
- if !fd.writable {
- return 0, syserror.EINVAL
- }
-
return fd.pipe.Write(ctx, src)
}
@@ -211,3 +215,17 @@ func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.Wr
func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
return fd.pipe.Ioctl(ctx, uio, args)
}
+
+// PipeSize implements fcntl(F_GETPIPE_SZ).
+func (fd *VFSPipeFD) PipeSize() int64 {
+ // Inline Pipe.FifoSize() rather than calling it with nil Context and
+ // fs.File and ignoring the returned error (which is always nil).
+ fd.pipe.mu.Lock()
+ defer fd.pipe.mu.Unlock()
+ return fd.pipe.max
+}
+
+// SetPipeSize implements fcntl(F_SETPIPE_SZ).
+func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) {
+ return fd.pipe.SetFifoSize(size)
+}
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 35ad97d5d..e23e796ef 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -184,7 +184,6 @@ func (t *Task) CanTrace(target *Task, attach bool) bool {
if targetCreds.PermittedCaps&^callerCreds.PermittedCaps != 0 {
return false
}
- // TODO: Yama LSM
return true
}
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index ded95f532..18416643b 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -304,7 +304,7 @@ func (t *Task) rseqAddrInterrupt() {
}
var cs linux.RSeqCriticalSection
- if err := cs.CopyIn(t, critAddr); err != nil {
+ if _, err := cs.CopyIn(t, critAddr); err != nil {
t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index 1000f3287..c00fa1138 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -554,6 +554,7 @@ func (s *sem) wakeWaiters() {
for w := s.waiters.Front(); w != nil; {
if s.value < w.value {
// Still blocked, skip it.
+ w = w.Next()
continue
}
w.ch <- struct{}{}
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index 047b5214d..0e19286de 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -246,7 +246,7 @@ func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error {
var lastErr error
for tg := range tasks.Root.tgids {
- if tg.ProcessGroup() == pg {
+ if tg.processGroup == pg {
tg.signalHandlers.mu.Lock()
infoCopy := *info
if err := tg.leader.sendSignalLocked(&infoCopy, true /*group*/); err != nil {
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index 208569057..f66cfcc7f 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -461,7 +461,7 @@ func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.A
func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) {
s.mu.Lock()
defer s.mu.Unlock()
- // TODO(b/38173783): RemoveMapping may be called during task exit, when ctx
+ // RemoveMapping may be called during task exit, when ctx
// is context.Background. Gracefully handle missing clocks. Failing to
// update the detach time in these cases is ok, since no one can observe the
// omission.
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
index 93c4fe969..84156d5a1 100644
--- a/pkg/sentry/kernel/syscalls.go
+++ b/pkg/sentry/kernel/syscalls.go
@@ -209,65 +209,61 @@ type Stracer interface {
// SyscallEnter is called on syscall entry.
//
// The returned private data is passed to SyscallExit.
- //
- // TODO(gvisor.dev/issue/155): remove kernel imports from the strace
- // package so that the type can be used directly.
SyscallEnter(t *Task, sysno uintptr, args arch.SyscallArguments, flags uint32) interface{}
// SyscallExit is called on syscall exit.
SyscallExit(context interface{}, t *Task, sysno, rval uintptr, err error)
}
-// SyscallTable is a lookup table of system calls. Critically, a SyscallTable
-// is *immutable*. In order to make supporting suspend and resume sane, they
-// must be uniquely registered and may not change during operation.
+// SyscallTable is a lookup table of system calls.
//
-// +stateify savable
+// Note that a SyscallTable is not savable directly. Instead, they are saved as
+// an OS/Arch pair and lookup happens again on restore.
type SyscallTable struct {
// OS is the operating system that this syscall table implements.
- OS abi.OS `state:"wait"`
+ OS abi.OS
// Arch is the architecture that this syscall table targets.
- Arch arch.Arch `state:"wait"`
+ Arch arch.Arch
// The OS version that this syscall table implements.
- Version Version `state:"manual"`
+ Version Version
// AuditNumber is a numeric constant that represents the syscall table. If
// non-zero, auditNumber must be one of the AUDIT_ARCH_* values defined by
// linux/audit.h.
- AuditNumber uint32 `state:"manual"`
+ AuditNumber uint32
// Table is the collection of functions.
- Table map[uintptr]Syscall `state:"manual"`
+ Table map[uintptr]Syscall
// lookup is a fixed-size array that holds the syscalls (indexed by
// their numbers). It is used for fast look ups.
- lookup []SyscallFn `state:"manual"`
+ lookup []SyscallFn
// Emulate is a collection of instruction addresses to emulate. The
// keys are addresses, and the values are system call numbers.
- Emulate map[usermem.Addr]uintptr `state:"manual"`
+ Emulate map[usermem.Addr]uintptr
// The function to call in case of a missing system call.
- Missing MissingFn `state:"manual"`
+ Missing MissingFn
// Stracer traces this syscall table.
- Stracer Stracer `state:"manual"`
+ Stracer Stracer
// External is used to handle an external callback.
- External func(*Kernel) `state:"manual"`
+ External func(*Kernel)
// ExternalFilterBefore is called before External is called before the syscall is executed.
// External is not called if it returns false.
- ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"`
+ ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool
// ExternalFilterAfter is called before External is called after the syscall is executed.
// External is not called if it returns false.
- ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"`
+ ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool
// FeatureEnable stores the strace and one-shot enable bits.
- FeatureEnable SyscallFlagsTable `state:"manual"`
+ FeatureEnable SyscallFlagsTable
}
// allSyscallTables contains all known tables.
@@ -330,6 +326,13 @@ func RegisterSyscallTable(s *SyscallTable) {
allSyscallTables = append(allSyscallTables, s)
}
+// FlushSyscallTablesTestOnly flushes the syscall tables for tests. Used for
+// parameterized VFSv2 tests.
+// TODO(gvisor.dv/issue/1624): Remove when VFS1 is no longer supported.
+func FlushSyscallTablesTestOnly() {
+ allSyscallTables = nil
+}
+
// Lookup returns the syscall implementation, if one exists.
func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn {
if sysno < uintptr(len(s.lookup)) {
diff --git a/pkg/sentry/kernel/syscalls_state.go b/pkg/sentry/kernel/syscalls_state.go
index 00358326b..90f890495 100644
--- a/pkg/sentry/kernel/syscalls_state.go
+++ b/pkg/sentry/kernel/syscalls_state.go
@@ -14,16 +14,34 @@
package kernel
-import "fmt"
+import (
+ "fmt"
-// afterLoad is invoked by stateify.
-func (s *SyscallTable) afterLoad() {
- otherTable, ok := LookupSyscallTable(s.OS, s.Arch)
- if !ok {
- // Couldn't find a reference?
- panic(fmt.Sprintf("syscall table not found for OS %v Arch %v", s.OS, s.Arch))
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// syscallTableInfo is used to reload the SyscallTable.
+//
+// +stateify savable
+type syscallTableInfo struct {
+ OS abi.OS
+ Arch arch.Arch
+}
+
+// saveSt saves the SyscallTable.
+func (tc *TaskContext) saveSt() syscallTableInfo {
+ return syscallTableInfo{
+ OS: tc.st.OS,
+ Arch: tc.st.Arch,
}
+}
- // Copy the table.
- *s = *otherTable
+// loadSt loads the SyscallTable.
+func (tc *TaskContext) loadSt(sti syscallTableInfo) {
+ st, ok := LookupSyscallTable(sti.OS, sti.Arch)
+ if !ok {
+ panic(fmt.Sprintf("syscall table not found for OS %v, Arch %v", sti.OS, sti.Arch))
+ }
+ tc.st = st // Save the table reference.
}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 2cee2e6ed..e5d133d6c 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -37,6 +37,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -776,6 +777,15 @@ func (t *Task) NewFDs(fd int32, files []*fs.File, flags FDFlags) ([]int32, error
return t.fdTable.NewFDs(t, fd, files, flags)
}
+// NewFDsVFS2 is a convenience wrapper for t.FDTable().NewFDsVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDsVFS2(fd int32, files []*vfs.FileDescription, flags FDFlags) ([]int32, error) {
+ return t.fdTable.NewFDsVFS2(t, fd, files, flags)
+}
+
// NewFDFrom is a convenience wrapper for t.FDTable().NewFDs with a single file.
//
// This automatically passes the task as the context.
@@ -847,3 +857,30 @@ func (t *Task) AbstractSockets() *AbstractSocketNamespace {
func (t *Task) ContainerID() string {
return t.containerID
}
+
+// OOMScoreAdj gets the task's thread group's OOM score adjustment.
+func (t *Task) OOMScoreAdj() int32 {
+ return atomic.LoadInt32(&t.tg.oomScoreAdj)
+}
+
+// SetOOMScoreAdj sets the task's thread group's OOM score adjustment. The
+// value should be between -1000 and 1000 inclusive.
+func (t *Task) SetOOMScoreAdj(adj int32) error {
+ if adj > 1000 || adj < -1000 {
+ return syserror.EINVAL
+ }
+ atomic.StoreInt32(&t.tg.oomScoreAdj, adj)
+ return nil
+}
+
+// UID returns t's uid.
+// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
+func (t *Task) UID() uint32 {
+ return uint32(t.Credentials().EffectiveKUID)
+}
+
+// GID returns t's gid.
+// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
+func (t *Task) GID() uint32 {
+ return uint32(t.Credentials().EffectiveKGID)
+}
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 78866f280..e1ecca99e 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -15,6 +15,8 @@
package kernel
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/sentry/inet"
@@ -260,6 +262,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
sh = sh.Fork()
}
tg = t.k.NewThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy())
+ tg.oomScoreAdj = atomic.LoadInt32(&t.tg.oomScoreAdj)
rseqAddr = t.rseqAddr
rseqSignature = t.rseqSignature
}
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 0158b1788..9fa528384 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -49,7 +49,7 @@ type TaskContext struct {
fu *futex.Manager
// st is the task's syscall table.
- st *SyscallTable
+ st *SyscallTable `state:".(syscallTableInfo)"`
}
// release releases all resources held by the TaskContext. release is called by
@@ -58,7 +58,6 @@ func (tc *TaskContext) release() {
// Nil out pointers so that if the task is saved after release, it doesn't
// follow the pointers to possibly now-invalid objects.
if tc.MemoryManager != nil {
- // TODO(b/38173783)
tc.MemoryManager.DecUsers(context.Background())
tc.MemoryManager = nil
}
diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go
index ce3e6ef28..0325967e4 100644
--- a/pkg/sentry/kernel/task_identity.go
+++ b/pkg/sentry/kernel/task_identity.go
@@ -455,7 +455,7 @@ func (t *Task) SetKeepCaps(k bool) {
t.creds.Store(creds)
}
-// updateCredsForExec updates t.creds to reflect an execve().
+// updateCredsForExecLocked updates t.creds to reflect an execve().
//
// NOTE(b/30815691): We currently do not implement privileged executables
// (set-user/group-ID bits and file capabilities). This allows us to make a lot
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index 5568c91bc..2ba8d7e63 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -126,13 +126,39 @@ func (t *Task) doStop() {
}
}
+func (*runApp) handleCPUIDInstruction(t *Task) error {
+ if len(arch.CPUIDInstruction) == 0 {
+ // CPUID emulation isn't supported, but this code can be
+ // executed, because the ptrace platform returns
+ // ErrContextSignalCPUID on page faults too. Look at
+ // pkg/sentry/platform/ptrace/ptrace.go:context.Switch for more
+ // details.
+ return platform.ErrContextSignal
+ }
+ // Is this a CPUID instruction?
+ region := trace.StartRegion(t.traceContext, cpuidRegion)
+ expected := arch.CPUIDInstruction[:]
+ found := make([]byte, len(expected))
+ _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
+ if err == nil && bytes.Equal(expected, found) {
+ // Skip the cpuid instruction.
+ t.Arch().CPUIDEmulate(t)
+ t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
+ region.End()
+
+ return nil
+ }
+ region.End() // Not an actual CPUID, but required copy-in.
+ return platform.ErrContextSignal
+}
+
// The runApp state checks for interrupts before executing untrusted
// application code.
//
// +stateify savable
type runApp struct{}
-func (*runApp) execute(t *Task) taskRunState {
+func (app *runApp) execute(t *Task) taskRunState {
if t.interrupted() {
// Checkpointing instructs tasks to stop by sending an interrupt, so we
// must check for stops before entering runInterrupt (instead of
@@ -237,21 +263,10 @@ func (*runApp) execute(t *Task) taskRunState {
return (*runApp)(nil)
case platform.ErrContextSignalCPUID:
- // Is this a CPUID instruction?
- region := trace.StartRegion(t.traceContext, cpuidRegion)
- expected := arch.CPUIDInstruction[:]
- found := make([]byte, len(expected))
- _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
- if err == nil && bytes.Equal(expected, found) {
- // Skip the cpuid instruction.
- t.Arch().CPUIDEmulate(t)
- t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
- region.End()
-
+ if err := app.handleCPUIDInstruction(t); err == nil {
// Resume execution.
return (*runApp)(nil)
}
- region.End() // Not an actual CPUID, but required copy-in.
// The instruction at the given RIP was not a CPUID, and we
// fallthrough to the default signal deliver behavior below.
@@ -338,7 +353,7 @@ func (*runApp) execute(t *Task) taskRunState {
default:
// What happened? Can't continue.
t.Warningf("Unexpected SwitchToApp error: %v", err)
- t.PrepareExit(ExitStatus{Code: t.ExtractErrno(err, -1)})
+ t.PrepareExit(ExitStatus{Code: ExtractErrno(err, -1)})
return (*runExit)(nil)
}
}
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 8802db142..f07de2089 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -174,7 +174,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS
fallthrough
case (sre == ERESTARTSYS && !act.IsRestart()):
t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
- t.Arch().SetReturn(uintptr(-t.ExtractErrno(syserror.EINTR, -1)))
+ t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1)))
default:
t.Debugf("Restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
t.Arch().RestartSyscall()
@@ -513,8 +513,6 @@ func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool {
if t.stop != nil {
return false
}
- // - TODO(b/38173783): No special case for when t is also the sending task,
- // because the identity of the sender is unknown.
// - Do not choose tasks that have already been interrupted, as they may be
// busy handling another signal.
if len(t.interruptChan) != 0 {
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index d555d69a8..c9db78e06 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -194,6 +194,19 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
//
// The syscall path is very hot; avoid defer.
func (t *Task) doSyscall() taskRunState {
+ // Save value of the register which is clobbered in the following
+ // t.Arch().SetReturn(-ENOSYS) operation. This is dedicated to arm64.
+ //
+ // On x86, register rax was shared by syscall number and return
+ // value, and at the entry of the syscall handler, the rax was
+ // saved to regs.orig_rax which was exposed to user space.
+ // But on arm64, syscall number was passed through X8, and the X0
+ // was shared by the first syscall argument and return value. The
+ // X0 was saved to regs.orig_x0 which was not exposed to user space.
+ // So we have to do the same operation here to save the X0 value
+ // into the task context.
+ t.Arch().SyscallSaveOrig()
+
sysno := t.Arch().SyscallNo()
args := t.Arch().SyscallArgs()
@@ -269,6 +282,7 @@ func (*runSyscallAfterSyscallEnterStop) execute(t *Task) taskRunState {
return (*runSyscallExit)(nil)
}
args := t.Arch().SyscallArgs()
+
return t.doSyscallInvoke(sysno, args)
}
@@ -298,7 +312,7 @@ func (t *Task) doSyscallInvoke(sysno uintptr, args arch.SyscallArguments) taskRu
return ctrl.next
}
} else if err != nil {
- t.Arch().SetReturn(uintptr(-t.ExtractErrno(err, int(sysno))))
+ t.Arch().SetReturn(uintptr(-ExtractErrno(err, int(sysno))))
t.haveSyscallReturn = true
} else {
t.Arch().SetReturn(rval)
@@ -417,7 +431,7 @@ func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, calle
// A return is not emulated in this case.
return (*runApp)(nil)
}
- t.Arch().SetReturn(uintptr(-t.ExtractErrno(err, int(sysno))))
+ t.Arch().SetReturn(uintptr(-ExtractErrno(err, int(sysno))))
}
t.Arch().SetIP(t.Arch().Value(caller))
t.Arch().SetStack(t.Arch().Stack() + uintptr(t.Arch().Width()))
@@ -427,7 +441,7 @@ func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, calle
// ExtractErrno extracts an integer error number from the error.
// The syscall number is purely for context in the error case. Use -1 if
// syscall number is unknown.
-func (t *Task) ExtractErrno(err error, sysno int) int {
+func ExtractErrno(err error, sysno int) int {
switch err := err.(type) {
case nil:
return 0
@@ -441,11 +455,11 @@ func (t *Task) ExtractErrno(err error, sysno int) int {
// handled (and the SIGBUS is delivered).
return int(syscall.EFAULT)
case *os.PathError:
- return t.ExtractErrno(err.Err, sysno)
+ return ExtractErrno(err.Err, sysno)
case *os.LinkError:
- return t.ExtractErrno(err.Err, sysno)
+ return ExtractErrno(err.Err, sysno)
case *os.SyscallError:
- return t.ExtractErrno(err.Err, sysno)
+ return ExtractErrno(err.Err, sysno)
default:
if errno, ok := syserror.TranslateError(err); ok {
return int(errno)
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 268f62e9d..52849f5b3 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -254,6 +254,13 @@ type ThreadGroup struct {
//
// tty is protected by the signal mutex.
tty *TTY
+
+ // oomScoreAdj is the thread group's OOM score adjustment. This is
+ // currently not used but is maintained for consistency.
+ // TODO(gvisor.dev/issue/1967)
+ //
+ // oomScoreAdj is accessed using atomic memory operations.
+ oomScoreAdj int32
}
// NewThreadGroup returns a new, empty thread group in PID namespace pidns. The
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
index 706de83ef..e959700f2 100644
--- a/pkg/sentry/kernel/time/time.go
+++ b/pkg/sentry/kernel/time/time.go
@@ -245,7 +245,7 @@ type Clock interface {
type WallRateClock struct{}
// WallTimeUntil implements Clock.WallTimeUntil.
-func (WallRateClock) WallTimeUntil(t, now Time) time.Duration {
+func (*WallRateClock) WallTimeUntil(t, now Time) time.Duration {
return t.Sub(now)
}
@@ -254,16 +254,16 @@ func (WallRateClock) WallTimeUntil(t, now Time) time.Duration {
type NoClockEvents struct{}
// Readiness implements waiter.Waitable.Readiness.
-func (NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (*NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask {
return 0
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (*NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (NoClockEvents) EventUnregister(e *waiter.Entry) {
+func (*NoClockEvents) EventUnregister(e *waiter.Entry) {
}
// ClockEventsQueue implements waiter.Waitable by wrapping waiter.Queue and
@@ -273,7 +273,7 @@ type ClockEventsQueue struct {
}
// Readiness implements waiter.Waitable.Readiness.
-func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (*ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
return 0
}
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
index 0332fc71c..5c667117c 100644
--- a/pkg/sentry/mm/address_space.go
+++ b/pkg/sentry/mm/address_space.go
@@ -201,8 +201,10 @@ func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, pre
if pma.needCOW {
perms.Write = false
}
- if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil {
- return err
+ if perms.Any() { // MapFile precondition
+ if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil {
+ return err
+ }
}
pseg = pseg.NextSegment()
}
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index cb29d94b0..379148903 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -59,25 +59,27 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool {
}
a.contexts[id] = &AIOContext{
- done: make(chan struct{}, 1),
+ requestReady: make(chan struct{}, 1),
maxOutstanding: events,
}
return true
}
-// destroyAIOContext destroys an asynchronous I/O context.
+// destroyAIOContext destroys an asynchronous I/O context. It doesn't wait for
+// for pending requests to complete. Returns the destroyed AIOContext so it can
+// be drained.
//
-// False is returned if the context does not exist.
-func (a *aioManager) destroyAIOContext(id uint64) bool {
+// Nil is returned if the context does not exist.
+func (a *aioManager) destroyAIOContext(id uint64) *AIOContext {
a.mu.Lock()
defer a.mu.Unlock()
ctx, ok := a.contexts[id]
if !ok {
- return false
+ return nil
}
delete(a.contexts, id)
ctx.destroy()
- return true
+ return ctx
}
// lookupAIOContext looks up the given context.
@@ -102,8 +104,8 @@ type ioResult struct {
//
// +stateify savable
type AIOContext struct {
- // done is the notification channel used for all requests.
- done chan struct{} `state:"nosave"`
+ // requestReady is the notification channel used for all requests.
+ requestReady chan struct{} `state:"nosave"`
// mu protects below.
mu sync.Mutex `state:"nosave"`
@@ -129,8 +131,14 @@ func (ctx *AIOContext) destroy() {
ctx.mu.Lock()
defer ctx.mu.Unlock()
ctx.dead = true
- if ctx.outstanding == 0 {
- close(ctx.done)
+ ctx.checkForDone()
+}
+
+// Preconditions: ctx.mu must be held by caller.
+func (ctx *AIOContext) checkForDone() {
+ if ctx.dead && ctx.outstanding == 0 {
+ close(ctx.requestReady)
+ ctx.requestReady = nil
}
}
@@ -154,11 +162,12 @@ func (ctx *AIOContext) PopRequest() (interface{}, bool) {
// Is there anything ready?
if e := ctx.results.Front(); e != nil {
- ctx.results.Remove(e)
- ctx.outstanding--
- if ctx.outstanding == 0 && ctx.dead {
- close(ctx.done)
+ if ctx.outstanding == 0 {
+ panic("AIOContext outstanding is going negative")
}
+ ctx.outstanding--
+ ctx.results.Remove(e)
+ ctx.checkForDone()
return e.data, true
}
return nil, false
@@ -172,26 +181,58 @@ func (ctx *AIOContext) FinishRequest(data interface{}) {
// Push to the list and notify opportunistically. The channel notify
// here is guaranteed to be safe because outstanding must be non-zero.
- // The done channel is only closed when outstanding reaches zero.
+ // The requestReady channel is only closed when outstanding reaches zero.
ctx.results.PushBack(&ioResult{data: data})
select {
- case ctx.done <- struct{}{}:
+ case ctx.requestReady <- struct{}{}:
default:
}
}
// WaitChannel returns a channel that is notified when an AIO request is
-// completed.
-//
-// The boolean return value indicates whether or not the context is active.
-func (ctx *AIOContext) WaitChannel() (chan struct{}, bool) {
+// completed. Returns nil if the context is destroyed and there are no more
+// outstanding requests.
+func (ctx *AIOContext) WaitChannel() chan struct{} {
ctx.mu.Lock()
defer ctx.mu.Unlock()
- if ctx.outstanding == 0 && ctx.dead {
- return nil, false
+ return ctx.requestReady
+}
+
+// Dead returns true if the context has been destroyed.
+func (ctx *AIOContext) Dead() bool {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+ return ctx.dead
+}
+
+// CancelPendingRequest forgets about a request that hasn't yet completed.
+func (ctx *AIOContext) CancelPendingRequest() {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ if ctx.outstanding == 0 {
+ panic("AIOContext outstanding is going negative")
}
- return ctx.done, true
+ ctx.outstanding--
+ ctx.checkForDone()
+}
+
+// Drain drops all completed requests. Pending requests remain untouched.
+func (ctx *AIOContext) Drain() {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ if ctx.outstanding == 0 {
+ return
+ }
+ size := uint32(ctx.results.Len())
+ if ctx.outstanding < size {
+ panic("AIOContext outstanding is going negative")
+ }
+ ctx.outstanding -= size
+ ctx.results.Reset()
+ ctx.checkForDone()
}
// aioMappable implements memmap.MappingIdentity and memmap.Mappable for AIO
@@ -332,9 +373,9 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
Length: aioRingBufferSize,
MappingIdentity: m,
Mappable: m,
- // TODO(fvoznika): Linux does "do_mmap_pgoff(..., PROT_READ |
- // PROT_WRITE, ...)" in fs/aio.c:aio_setup_ring(); why do we make this
- // mapping read-only?
+ // Linux uses "do_mmap_pgoff(..., PROT_READ | PROT_WRITE, ...)" in
+ // fs/aio.c:aio_setup_ring(). Since we don't implement AIO_RING_MAGIC,
+ // user mode should not write to this page.
Perms: usermem.Read,
MaxPerms: usermem.Read,
})
@@ -349,11 +390,11 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
return id, nil
}
-// DestroyAIOContext destroys an asynchronous I/O context. It returns false if
-// the context does not exist.
-func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) bool {
+// DestroyAIOContext destroys an asynchronous I/O context. It returns the
+// destroyed context. nil if the context does not exist.
+func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext {
if _, ok := mm.LookupAIOContext(ctx, id); !ok {
- return false
+ return nil
}
// Only unmaps after it assured that the address is a valid aio context to
diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go
index c37fc9f7b..3dabac1af 100644
--- a/pkg/sentry/mm/aio_context_state.go
+++ b/pkg/sentry/mm/aio_context_state.go
@@ -16,5 +16,5 @@ package mm
// afterLoad is invoked by stateify.
func (a *AIOContext) afterLoad() {
- a.done = make(chan struct{}, 1)
+ a.requestReady = make(chan struct{}, 1)
}
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index e6269107c..aac56679b 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -84,7 +84,7 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
dumpability: mm.dumpability,
aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
sleepForActivation: mm.sleepForActivation,
- vdsoSigReturnAddr : mm.vdsoSigReturnAddr,
+ vdsoSigReturnAddr: mm.vdsoSigReturnAddr,
}
// Copy vmas.
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index 35cd55fef..4b23f7803 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -81,12 +81,6 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
// Save the death message, which will be thrown.
c.dieState.message = msg
- // Reload all registers to have an accurate stack trace when we return
- // to host mode. This means that the stack should be unwound correctly.
- if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 {
- throw(msg)
- }
-
// Setup the trampoline.
dieArchSetup(c, context, &c.dieState.guestRegs)
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index a63a6a071..99cac665d 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -31,6 +31,12 @@ import (
//
//go:nosplit
func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
+ // Reload all registers to have an accurate stack trace when we return
+ // to host mode. This means that the stack should be unwound correctly.
+ if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 {
+ throw(c.dieState.message)
+ }
+
// If the vCPU is in user mode, we set the stack to the stored stack
// value in the vCPU itself. We don't want to unwind the user stack.
if guestRegs.RFLAGS&ring0.UserFlagsSet == ring0.UserFlagsSet {
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
index c61700892..04efa0147 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.s
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -82,6 +82,8 @@ fallback:
// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
- // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64.
- MOVD R9, 8(RSP)
- BL ·dieHandler(SB)
+ // R0: Fake the old PC as caller
+ // R1: First argument (vCPU)
+ MOVD.P R1, 8(RSP) // R1: First argument (vCPU)
+ MOVD.P R0, 8(RSP) // R0: Fake the old PC as caller
+ B ·dieHandler(SB)
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index 2f02c03cf..eb5ed574e 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -17,10 +17,33 @@
package kvm
import (
+ "unsafe"
+
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
)
+// dieArchSetup initialies the state for dieTrampoline.
+//
+// The arm64 dieTrampoline requires the vCPU to be set in R1, and the last PC
+// to be in R0. The trampoline then simulates a call to dieHandler from the
+// provided PC.
+//
//go:nosplit
func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
- // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64.
+ // If the vCPU is in user mode, we set the stack to the stored stack
+ // value in the vCPU itself. We don't want to unwind the user stack.
+ if guestRegs.Regs.Pstate&ring0.PSR_MODE_MASK == ring0.PSR_MODE_EL0t {
+ regs := c.CPU.Registers()
+ context.Regs[0] = regs.Regs[0]
+ context.Sp = regs.Sp
+ context.Regs[29] = regs.Regs[29] // stack base address
+ } else {
+ context.Regs[0] = guestRegs.Regs.Pc
+ context.Sp = guestRegs.Regs.Sp
+ context.Regs[29] = guestRegs.Regs.Regs[29]
+ context.Pstate = guestRegs.Regs.Pstate
+ }
+ context.Regs[1] = uint64(uintptr(unsafe.Pointer(c)))
+ context.Pc = uint64(dieTrampolineAddr)
}
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index 972ba85c3..a9b4af43e 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -27,6 +27,38 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// userMemoryRegion is a region of physical memory.
+//
+// This mirrors kvm_memory_region.
+type userMemoryRegion struct {
+ slot uint32
+ flags uint32
+ guestPhysAddr uint64
+ memorySize uint64
+ userspaceAddr uint64
+}
+
+// runData is the run structure. This may be mapped for synchronous register
+// access (although that doesn't appear to be supported by my kernel at least).
+//
+// This mirrors kvm_run.
+type runData struct {
+ requestInterruptWindow uint8
+ _ [7]uint8
+
+ exitReason uint32
+ readyForInterruptInjection uint8
+ ifFlag uint8
+ _ [2]uint8
+
+ cr8 uint64
+ apicBase uint64
+
+ // This is the union data for exits. Interpretation depends entirely on
+ // the exitReason above (see vCPU code for more information).
+ data [32]uint64
+}
+
// KVM represents a lightweight VM context.
type KVM struct {
platform.NoCPUPreemptionDetection
diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go
index c5a6f9c7d..093497bc4 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64.go
@@ -21,17 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
)
-// userMemoryRegion is a region of physical memory.
-//
-// This mirrors kvm_memory_region.
-type userMemoryRegion struct {
- slot uint32
- flags uint32
- guestPhysAddr uint64
- memorySize uint64
- userspaceAddr uint64
-}
-
// userRegs represents KVM user registers.
//
// This mirrors kvm_regs.
@@ -169,27 +158,6 @@ type modelControlRegisters struct {
entries [16]modelControlRegister
}
-// runData is the run structure. This may be mapped for synchronous register
-// access (although that doesn't appear to be supported by my kernel at least).
-//
-// This mirrors kvm_run.
-type runData struct {
- requestInterruptWindow uint8
- _ [7]uint8
-
- exitReason uint32
- readyForInterruptInjection uint8
- ifFlag uint8
- _ [2]uint8
-
- cr8 uint64
- apicBase uint64
-
- // This is the union data for exits. Interpretation depends entirely on
- // the exitReason above (see vCPU code for more information).
- data [32]uint64
-}
-
// cpuidEntry is a single CPUID entry.
//
// This mirrors kvm_cpuid_entry2.
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
index 2319c86d3..716198712 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -18,18 +18,9 @@ package kvm
import (
"syscall"
-)
-// userMemoryRegion is a region of physical memory.
-//
-// This mirrors kvm_memory_region.
-type userMemoryRegion struct {
- slot uint32
- flags uint32
- guestPhysAddr uint64
- memorySize uint64
- userspaceAddr uint64
-}
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
type kvmOneReg struct {
id uint64
@@ -53,31 +44,10 @@ type userRegs struct {
fpRegs userFpsimdState
}
-// runData is the run structure. This may be mapped for synchronous register
-// access (although that doesn't appear to be supported by my kernel at least).
-//
-// This mirrors kvm_run.
-type runData struct {
- requestInterruptWindow uint8
- _ [7]uint8
-
- exitReason uint32
- readyForInterruptInjection uint8
- ifFlag uint8
- _ [2]uint8
-
- cr8 uint64
- apicBase uint64
-
- // This is the union data for exits. Interpretation depends entirely on
- // the exitReason above (see vCPU code for more information).
- data [32]uint64
-}
-
// updateGlobalOnce does global initialization. It has to be called only once.
func updateGlobalOnce(fd int) error {
physicalInit()
err := updateSystemValues(int(fd))
- updateVectorTable()
+ ring0.Init()
return err
}
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 1c8384e6b..3b35858ae 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -29,30 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// setMemoryRegion initializes a region.
-//
-// This may be called from bluepillHandler, and therefore returns an errno
-// directly (instead of wrapping in an error) to avoid allocations.
-//
-//go:nosplit
-func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno {
- userRegion := userMemoryRegion{
- slot: uint32(slot),
- flags: 0,
- guestPhysAddr: uint64(physical),
- memorySize: uint64(length),
- userspaceAddr: uint64(virtual),
- }
-
- // Set the region.
- _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(m.fd),
- _KVM_SET_USER_MEMORY_REGION,
- uintptr(unsafe.Pointer(&userRegion)))
- return errno
-}
-
type kvmVcpuInit struct {
target uint32
features [7]uint32
@@ -72,69 +48,6 @@ func (m *machine) initArchState() error {
return nil
}
-func getPageWithReflect(p uintptr) []byte {
- return (*(*[0xFFFFFF]byte)(unsafe.Pointer(p & ^uintptr(syscall.Getpagesize()-1))))[:syscall.Getpagesize()]
-}
-
-// Work around: move ring0.Vectors() into a specific address with 11-bits alignment.
-//
-// According to the design documentation of Arm64,
-// the start address of exception vector table should be 11-bits aligned.
-// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S
-// But, we can't align a function's start address to a specific address by using golang.
-// We have raised this question in golang community:
-// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
-// This function will be removed when golang supports this feature.
-//
-// There are 2 jobs were implemented in this function:
-// 1, move the start address of exception vector table into the specific address.
-// 2, modify the offset of each instruction.
-func updateVectorTable() {
- fromLocation := reflect.ValueOf(ring0.Vectors).Pointer()
- offset := fromLocation & (1<<11 - 1)
- if offset != 0 {
- offset = 1<<11 - offset
- }
-
- toLocation := fromLocation + offset
- page := getPageWithReflect(toLocation)
- if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil {
- panic(err)
- }
-
- page = getPageWithReflect(toLocation + 4096)
- if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil {
- panic(err)
- }
-
- // Move exception-vector-table into the specific address.
- var entry *uint32
- var entryFrom *uint32
- for i := 1; i <= 0x800; i++ {
- entry = (*uint32)(unsafe.Pointer(toLocation + 0x800 - uintptr(i)))
- entryFrom = (*uint32)(unsafe.Pointer(fromLocation + 0x800 - uintptr(i)))
- *entry = *entryFrom
- }
-
- // The offset from the address of each unconditionally branch is changed.
- // We should modify the offset of each instruction.
- nums := []uint32{0x0, 0x80, 0x100, 0x180, 0x200, 0x280, 0x300, 0x380, 0x400, 0x480, 0x500, 0x580, 0x600, 0x680, 0x700, 0x780}
- for _, num := range nums {
- entry = (*uint32)(unsafe.Pointer(toLocation + uintptr(num)))
- *entry = *entry - (uint32)(offset/4)
- }
-
- page = getPageWithReflect(toLocation)
- if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil {
- panic(err)
- }
-
- page = getPageWithReflect(toLocation + 4096)
- if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil {
- panic(err)
- }
-}
-
// initArchState initializes architecture-specific state.
func (c *vCPU) initArchState() error {
var (
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index 95abd321e..30402c2df 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -9,6 +9,7 @@ go_library(
"ptrace.go",
"ptrace_amd64.go",
"ptrace_arm64.go",
+ "ptrace_arm64_unsafe.go",
"ptrace_unsafe.go",
"stub_amd64.s",
"stub_arm64.s",
diff --git a/pkg/sentry/platform/ptrace/ptrace_amd64.go b/pkg/sentry/platform/ptrace/ptrace_amd64.go
index db0212538..24fc5dc62 100644
--- a/pkg/sentry/platform/ptrace/ptrace_amd64.go
+++ b/pkg/sentry/platform/ptrace/ptrace_amd64.go
@@ -31,3 +31,17 @@ func fpRegSet(useXsave bool) uintptr {
func stackPointer(r *syscall.PtraceRegs) uintptr {
return uintptr(r.Rsp)
}
+
+// x86 use the fs_base register to store the TLS pointer which can be
+// get/set in "func (t *thread) get/setRegs(regs *syscall.PtraceRegs)".
+// So both of the get/setTLS() operations are noop here.
+
+// getTLS gets the thread local storage register.
+func (t *thread) getTLS(tls *uint64) error {
+ return nil
+}
+
+// setTLS sets the thread local storage register.
+func (t *thread) setTLS(tls *uint64) error {
+ return nil
+}
diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go
new file mode 100644
index 000000000..32b8a6be9
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ptrace
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// getTLS gets the thread local storage register.
+func (t *thread) getTLS(tls *uint64) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(tls)),
+ Len: uint64(unsafe.Sizeof(*tls)),
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETREGSET,
+ uintptr(t.tid),
+ linux.NT_ARM_TLS,
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// setTLS sets the thread local storage register.
+func (t *thread) setTLS(tls *uint64) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(tls)),
+ Len: uint64(unsafe.Sizeof(*tls)),
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SETREGSET,
+ uintptr(t.tid),
+ linux.NT_ARM_TLS,
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 31b7cec53..a644609ef 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -506,6 +506,9 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
regs := &ac.StateData().Regs
t.resetSysemuRegs(regs)
+ // Extract TLS register
+ tls := uint64(ac.TLS())
+
// Check for interrupts, and ensure that future interrupts will signal t.
if !c.interrupt.Enable(t) {
// Pending interrupt; simulate.
@@ -526,6 +529,9 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
if err := t.setFPRegs(fpState, uint64(fpLen), useXsave); err != nil {
panic(fmt.Sprintf("ptrace set fpregs (%+v) failed: %v", fpState, err))
}
+ if err := t.setTLS(&tls); err != nil {
+ panic(fmt.Sprintf("ptrace set tls (%+v) failed: %v", tls, err))
+ }
for {
// Start running until the next system call.
@@ -555,6 +561,12 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
if err := t.getFPRegs(fpState, uint64(fpLen), useXsave); err != nil {
panic(fmt.Sprintf("ptrace get fpregs failed: %v", err))
}
+ if err := t.getTLS(&tls); err != nil {
+ panic(fmt.Sprintf("ptrace get tls failed: %v", err))
+ }
+ if !ac.SetTLS(uintptr(tls)) {
+ panic(fmt.Sprintf("tls value %v is invalid", tls))
+ }
// Is it a system call?
if sig == (syscallEvent | syscall.SIGTRAP) {
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index e99798c56..cd74945e7 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -21,6 +21,7 @@ import (
"strings"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -183,13 +184,76 @@ func enableCpuidFault() {
// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program.
// Ref attachedThread() for more detail.
-func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet {
- return append(rules, seccomp.RuleSet{
- Rules: seccomp.SyscallRules{
- syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
- {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet {
+ rules = append(rules,
+ // Rules for trapping vsyscall access.
+ seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_GETTIMEOFDAY: {},
+ syscall.SYS_TIME: {},
+ unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64.
},
- },
- Action: linux.SECCOMP_RET_ALLOW,
- })
+ Action: linux.SECCOMP_RET_TRAP,
+ Vsyscall: true,
+ })
+ if defaultAction != linux.SECCOMP_RET_ALLOW {
+ rules = append(rules,
+ seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
+ {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ })
+ }
+ return rules
+}
+
+// probeSeccomp returns true iff seccomp is run after ptrace notifications,
+// which is generally the case for kernel version >= 4.8. This check is dynamic
+// because kernels have be backported behavior.
+//
+// See createStub for more information.
+//
+// Precondition: the runtime OS thread must be locked.
+func probeSeccomp() bool {
+ // Create a completely new, destroyable process.
+ t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO)
+ if err != nil {
+ panic(fmt.Sprintf("seccomp probe failed: %v", err))
+ }
+ defer t.destroy()
+
+ // Set registers to the yield system call. This call is not allowed
+ // by the filters specified in the attachThread function.
+ regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD)
+ if err := t.setRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace set regs failed: %v", err))
+ }
+
+ for {
+ // Attempt an emulation.
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
+ }
+
+ sig := t.wait(stopped)
+ if sig == (syscallEvent | syscall.SIGTRAP) {
+ // Did the seccomp errno hook already run? This would
+ // indicate that seccomp is first in line and we're
+ // less than 4.8.
+ if err := t.getRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace get-regs failed: %v", err))
+ }
+ if _, err := syscallReturnValue(&regs); err == nil {
+ // The seccomp errno mode ran first, and reset
+ // the error in the registers.
+ return false
+ }
+ // The seccomp hook did not run yet, and therefore it
+ // is safe to use RET_KILL mode for dispatched calls.
+ return true
+ }
+ }
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index 7b975137f..7f5c393f0 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -160,6 +160,15 @@ func enableCpuidFault() {
// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program.
// Ref attachedThread() for more detail.
-func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet {
+func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet {
return rules
}
+
+// probeSeccomp returns true if seccomp is run after ptrace notifications,
+// which is generally the case for kernel version >= 4.8.
+//
+// On arm64, the support of PTRACE_SYSEMU was added in the 5.3 kernel, so
+// probeSeccomp can always return true.
+func probeSeccomp() bool {
+ return true
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index 74968dfdf..2ce528601 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -20,7 +20,6 @@ import (
"fmt"
"syscall"
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
@@ -30,54 +29,6 @@ import (
const syscallEvent syscall.Signal = 0x80
-// probeSeccomp returns true iff seccomp is run after ptrace notifications,
-// which is generally the case for kernel version >= 4.8. This check is dynamic
-// because kernels have be backported behavior.
-//
-// See createStub for more information.
-//
-// Precondition: the runtime OS thread must be locked.
-func probeSeccomp() bool {
- // Create a completely new, destroyable process.
- t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO)
- if err != nil {
- panic(fmt.Sprintf("seccomp probe failed: %v", err))
- }
- defer t.destroy()
-
- // Set registers to the yield system call. This call is not allowed
- // by the filters specified in the attachThread function.
- regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD)
- if err := t.setRegs(&regs); err != nil {
- panic(fmt.Sprintf("ptrace set regs failed: %v", err))
- }
-
- for {
- // Attempt an emulation.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
- panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
- }
-
- sig := t.wait(stopped)
- if sig == (syscallEvent | syscall.SIGTRAP) {
- // Did the seccomp errno hook already run? This would
- // indicate that seccomp is first in line and we're
- // less than 4.8.
- if err := t.getRegs(&regs); err != nil {
- panic(fmt.Sprintf("ptrace get-regs failed: %v", err))
- }
- if _, err := syscallReturnValue(&regs); err == nil {
- // The seccomp errno mode ran first, and reset
- // the error in the registers.
- return false
- }
- // The seccomp hook did not run yet, and therefore it
- // is safe to use RET_KILL mode for dispatched calls.
- return true
- }
- }
-}
-
// createStub creates a fresh stub processes.
//
// Precondition: the runtime OS thread must be locked.
@@ -123,18 +74,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// stub and all its children. This is used to create child stubs
// (below), so we must include the ability to fork, but otherwise lock
// down available calls only to what is needed.
- rules := []seccomp.RuleSet{
- // Rules for trapping vsyscall access.
- {
- Rules: seccomp.SyscallRules{
- syscall.SYS_GETTIMEOFDAY: {},
- syscall.SYS_TIME: {},
- unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64.
- },
- Action: linux.SECCOMP_RET_TRAP,
- Vsyscall: true,
- },
- }
+ rules := []seccomp.RuleSet{}
if defaultAction != linux.SECCOMP_RET_ALLOW {
rules = append(rules, seccomp.RuleSet{
Rules: seccomp.SyscallRules{
@@ -173,9 +113,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
},
Action: linux.SECCOMP_RET_ALLOW,
})
-
- rules = appendArchSeccompRules(rules)
}
+ rules = appendArchSeccompRules(rules, defaultAction)
instrs, err := seccomp.BuildProgram(rules, defaultAction)
if err != nil {
return nil, err
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index 934b6fbcd..b69520030 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -72,11 +72,13 @@ go_library(
"lib_amd64.s",
"lib_arm64.go",
"lib_arm64.s",
+ "lib_arm64_unsafe.go",
"ring0.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/cpuid",
+ "//pkg/safecopy",
"//pkg/sentry/platform/ring0/pagetables",
"//pkg/usermem",
],
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
index f6da41c27..8122ac6e2 100644
--- a/pkg/sentry/platform/ring0/aarch64.go
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -27,26 +27,27 @@ const (
_PTE_PGT_BASE = 0x7000
_PTE_PGT_SIZE = 0x1000
- _PSR_MODE_EL0t = 0x0
- _PSR_MODE_EL1t = 0x4
- _PSR_MODE_EL1h = 0x5
- _PSR_EL_MASK = 0xf
-
- _PSR_D_BIT = 0x200
- _PSR_A_BIT = 0x100
- _PSR_I_BIT = 0x80
- _PSR_F_BIT = 0x40
+ _PSR_D_BIT = 0x00000200
+ _PSR_A_BIT = 0x00000100
+ _PSR_I_BIT = 0x00000080
+ _PSR_F_BIT = 0x00000040
)
const (
+ // PSR bits
+ PSR_MODE_EL0t = 0x00000000
+ PSR_MODE_EL1t = 0x00000004
+ PSR_MODE_EL1h = 0x00000005
+ PSR_MODE_MASK = 0x0000000f
+
// KernelFlagsSet should always be set in the kernel.
- KernelFlagsSet = _PSR_MODE_EL1h
+ KernelFlagsSet = PSR_MODE_EL1h
// UserFlagsSet are always set in userspace.
- UserFlagsSet = _PSR_MODE_EL0t
+ UserFlagsSet = PSR_MODE_EL0t
- KernelFlagsClear = _PSR_EL_MASK
- UserFlagsClear = _PSR_EL_MASK
+ KernelFlagsClear = PSR_MODE_MASK
+ UserFlagsClear = PSR_MODE_MASK
PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT
)
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index d42eda37b..db6465663 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -394,6 +394,8 @@ TEXT ·Current(SB),NOSPLIT,$0-8
#define STACK_FRAME_SIZE 16
+// kernelExitToEl0 is the entrypoint for application in guest_el0.
+// Prepare the vcpu environment for container application.
TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
// Step1, save sentry context into memory.
REGISTERS_SAVE(RSV_REG, CPU_REGISTERS)
@@ -464,7 +466,23 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
ERET()
+// kernelExitToEl1 is the entrypoint for sentry in guest_el1.
+// Prepare the vcpu environment for sentry.
TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R1
+ WORD $0xd5184001 //MSR R1, SPSR_EL1
+
+ MOVD CPU_REGISTERS+PTRACE_PC(RSV_REG), R1
+ MSR R1, ELR_EL1
+
+ MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1
+ MOVD R1, RSP
+
+ REGISTERS_LOAD(RSV_REG, CPU_REGISTERS)
+ MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP
+
ERET()
// Start is the CPU entrypoint.
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index af075aae4..242b9305c 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -37,3 +37,10 @@ func SaveVRegs(*byte)
// LoadVRegs loads V0-V31 registers.
func LoadVRegs(*byte)
+
+// Init sets function pointers based on architectural features.
+//
+// This must be called prior to using ring0.
+func Init() {
+ rewriteVectors()
+}
diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
new file mode 100644
index 000000000..c05166fea
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
@@ -0,0 +1,108 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+import (
+ "reflect"
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ nopInstruction = 0xd503201f
+ instSize = unsafe.Sizeof(uint32(0))
+ vectorsRawLen = 0x800
+)
+
+func unsafeSlice(addr uintptr, length int) (slice []uint32) {
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
+ hdr.Data = addr
+ hdr.Len = length / int(instSize)
+ hdr.Cap = length / int(instSize)
+ return slice
+}
+
+// Work around: move ring0.Vectors() into a specific address with 11-bits alignment.
+//
+// According to the design documentation of Arm64,
+// the start address of exception vector table should be 11-bits aligned.
+// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S
+// But, we can't align a function's start address to a specific address by using golang.
+// We have raised this question in golang community:
+// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
+// This function will be removed when golang supports this feature.
+//
+// There are 2 jobs were implemented in this function:
+// 1, move the start address of exception vector table into the specific address.
+// 2, modify the offset of each instruction.
+func rewriteVectors() {
+ vectorsBegin := reflect.ValueOf(Vectors).Pointer()
+
+ // The exception-vector-table is required to be 11-bits aligned.
+ // And the size is 0x800.
+ // Please see the documentation as reference:
+ // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table
+ //
+ // But, golang does not allow to set a function's address to a specific value.
+ // So, for gvisor, I defined the size of exception-vector-table as 4K,
+ // filled the 2nd 2K part with NOP-s.
+ // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
+ //
+ // So, the prerequisite for this function to work correctly is:
+ // vectorsSafeLen >= 0x1000
+ // vectorsRawLen = 0x800
+ vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin)
+ if vectorsSafeLen < 2*vectorsRawLen {
+ panic("Can't update vectors")
+ }
+
+ vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32
+ vectorsRawLen32 := vectorsRawLen / int(instSize)
+
+ offset := vectorsBegin & (1<<11 - 1)
+ if offset != 0 {
+ offset = 1<<11 - offset
+ }
+
+ pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1)
+
+ _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC))
+ if errno != 0 {
+ panic(errno.Error())
+ }
+
+ offset = offset / instSize // By index, not bytes.
+ // Move exception-vector-table into the specific address, should uses memmove here.
+ for i := 1; i <= vectorsRawLen32; i++ {
+ vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i]
+ }
+
+ // Adjust branch since instruction was moved forward.
+ for i := 0; i < vectorsRawLen32; i++ {
+ if vectorsSafeTable[int(offset)+i] != nopInstruction {
+ vectorsSafeTable[int(offset)+i] -= uint32(offset)
+ }
+ }
+
+ _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC))
+ if errno != 0 {
+ panic(errno.Error())
+ }
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 4f2406ce3..581841555 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -80,7 +80,7 @@ go_library(
"pagetables_amd64.go",
"pagetables_arm64.go",
"pagetables_x86.go",
- "pcids_x86.go",
+ "pcids.go",
"walker_amd64.go",
"walker_arm64.go",
"walker_empty.go",
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
index dcf061df9..157438d9b 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package pagetables
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids.go
index e199bae18..9206030bf 100644
--- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
+++ b/pkg/sentry/platform/ring0/pagetables/pcids.go
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
-
package pagetables
import (
diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go
index 5f80d64e8..9da0ea685 100644
--- a/pkg/sentry/platform/ring0/x86.go
+++ b/pkg/sentry/platform/ring0/x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package ring0
diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go
index ba1f9043d..83195d5a1 100644
--- a/pkg/sentry/sighandling/sighandling.go
+++ b/pkg/sentry/sighandling/sighandling.go
@@ -85,6 +85,11 @@ func StartSignalForwarding(handler func(linux.Signal)) func() {
for sig := 1; sig <= numSignals+1; sig++ {
sigchan := make(chan os.Signal, 1)
sigchans = append(sigchans, sigchan)
+
+ // SIGURG is used by Go's runtime scheduler.
+ if sig == int(linux.SIGURG) {
+ continue
+ }
signal.Notify(sigchan, syscall.Signal(sig))
}
// Start up our listener.
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 611fa22c3..c40c6d673 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 7cd2ce55b..721094bbf 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"extensions.go",
"netfilter.go",
+ "owner_matcher.go",
"targets.go",
"tcp_matcher.go",
"udp_matcher.go",
@@ -22,7 +23,6 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/usermem",
],
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index b4b244abf..0336a32d8 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -19,7 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -37,12 +37,12 @@ type matchMaker interface {
// name is the matcher name as stored in the xt_entry_match struct.
name() string
- // marshal converts from an iptables.Matcher to an ABI struct.
- marshal(matcher iptables.Matcher) []byte
+ // marshal converts from an stack.Matcher to an ABI struct.
+ marshal(matcher stack.Matcher) []byte
// unmarshal converts from the ABI matcher struct to an
- // iptables.Matcher.
- unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error)
+ // stack.Matcher.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error)
}
// matchMakers maps the name of supported matchers to the matchMaker that
@@ -58,7 +58,7 @@ func registerMatchMaker(mm matchMaker) {
matchMakers[mm.name()] = mm
}
-func marshalMatcher(matcher iptables.Matcher) []byte {
+func marshalMatcher(matcher stack.Matcher) []byte {
matchMaker, ok := matchMakers[matcher.Name()]
if !ok {
panic(fmt.Sprintf("Unknown matcher of type %T.", matcher))
@@ -86,7 +86,7 @@ func marshalEntryMatch(name string, data []byte) []byte {
return append(buf, make([]byte, size-len(buf))...)
}
-func unmarshalMatcher(match linux.XTEntryMatch, filter iptables.IPHeaderFilter, buf []byte) (iptables.Matcher, error) {
+func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
matchMaker, ok := matchMakers[match.Name.String()]
if !ok {
return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String())
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 2ec11f6ac..878f81fd5 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -26,7 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -35,6 +35,11 @@ import (
// shouldn't be reached - an error has occurred if we fall through to one.
const errorTargetName = "ERROR"
+// redirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port/destination IP for packets.
+const redirectTargetName = "REDIRECT"
+
// Metadata is used to verify that we are correctly serializing and
// deserializing iptables into structs consumable by the iptables tool. We save
// a metadata struct when the tables are written, and when they are read out we
@@ -123,19 +128,19 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
return entries, nil
}
-func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, error) {
- ipt := stack.IPTables()
+func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) {
+ ipt := stk.IPTables()
table, ok := ipt.Tables[tablename.String()]
if !ok {
- return iptables.Table{}, fmt.Errorf("couldn't find table %q", tablename)
+ return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename)
}
return table, nil
}
// FillDefaultIPTables sets stack's IPTables to the default tables and
// populates them with metadata.
-func FillDefaultIPTables(stack *stack.Stack) {
- ipt := iptables.DefaultTables()
+func FillDefaultIPTables(stk *stack.Stack) {
+ ipt := stack.DefaultTables()
// In order to fill in the metadata, we have to translate ipt from its
// netstack format to Linux's giant-binary-blob format.
@@ -148,14 +153,14 @@ func FillDefaultIPTables(stack *stack.Stack) {
ipt.Tables[name] = table
}
- stack.SetIPTables(ipt)
+ stk.SetIPTables(ipt)
}
// convertNetstackToBinary converts the iptables as stored in netstack to the
// format expected by the iptables tool. Linux stores each table as a binary
// blob that can only be traversed by parsing a bit, reading some offsets,
// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, error) {
+func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) {
// Return values.
var entries linux.KernelIPTGetEntries
var meta metadata
@@ -228,18 +233,20 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern
return entries, meta, nil
}
-func marshalTarget(target iptables.Target) []byte {
+func marshalTarget(target stack.Target) []byte {
switch tg := target.(type) {
- case iptables.AcceptTarget:
- return marshalStandardTarget(iptables.RuleAccept)
- case iptables.DropTarget:
- return marshalStandardTarget(iptables.RuleDrop)
- case iptables.ErrorTarget:
+ case stack.AcceptTarget:
+ return marshalStandardTarget(stack.RuleAccept)
+ case stack.DropTarget:
+ return marshalStandardTarget(stack.RuleDrop)
+ case stack.ErrorTarget:
return marshalErrorTarget(errorTargetName)
- case iptables.UserChainTarget:
+ case stack.UserChainTarget:
return marshalErrorTarget(tg.Name)
- case iptables.ReturnTarget:
- return marshalStandardTarget(iptables.RuleReturn)
+ case stack.ReturnTarget:
+ return marshalStandardTarget(stack.RuleReturn)
+ case stack.RedirectTarget:
+ return marshalRedirectTarget()
case JumpTarget:
return marshalJumpTarget(tg)
default:
@@ -247,7 +254,7 @@ func marshalTarget(target iptables.Target) []byte {
}
}
-func marshalStandardTarget(verdict iptables.RuleVerdict) []byte {
+func marshalStandardTarget(verdict stack.RuleVerdict) []byte {
nflog("convert to binary: marshalling standard target")
// The target's name will be the empty string.
@@ -276,6 +283,19 @@ func marshalErrorTarget(errorName string) []byte {
return binary.Marshal(ret, usermem.ByteOrder, target)
}
+func marshalRedirectTarget() []byte {
+ // This is a redirect target named redirect
+ target := linux.XTRedirectTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTRedirectTarget,
+ },
+ }
+ copy(target.Target.Name[:], redirectTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
func marshalJumpTarget(jt JumpTarget) []byte {
nflog("convert to binary: marshalling jump target")
@@ -295,13 +315,13 @@ func marshalJumpTarget(jt JumpTarget) []byte {
// translateFromStandardVerdict translates verdicts the same way as the iptables
// tool.
-func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 {
+func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 {
switch verdict {
- case iptables.RuleAccept:
+ case stack.RuleAccept:
return -linux.NF_ACCEPT - 1
- case iptables.RuleDrop:
+ case stack.RuleDrop:
return -linux.NF_DROP - 1
- case iptables.RuleReturn:
+ case stack.RuleReturn:
return linux.NF_RETURN
default:
// TODO(gvisor.dev/issue/170): Support Jump.
@@ -310,18 +330,18 @@ func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 {
}
// translateToStandardTarget translates from the value in a
-// linux.XTStandardTarget to an iptables.Verdict.
-func translateToStandardTarget(val int32) (iptables.Target, error) {
+// linux.XTStandardTarget to an stack.Verdict.
+func translateToStandardTarget(val int32) (stack.Target, error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return iptables.AcceptTarget{}, nil
+ return stack.AcceptTarget{}, nil
case -linux.NF_DROP - 1:
- return iptables.DropTarget{}, nil
+ return stack.DropTarget{}, nil
case -linux.NF_QUEUE - 1:
return nil, errors.New("unsupported iptables verdict QUEUE")
case linux.NF_RETURN:
- return iptables.ReturnTarget{}, nil
+ return stack.ReturnTarget{}, nil
default:
return nil, fmt.Errorf("unknown iptables verdict %d", val)
}
@@ -329,7 +349,7 @@ func translateToStandardTarget(val int32) (iptables.Target, error) {
// SetEntries sets iptables rules for a single table. See
// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
-func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
+func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Get the basic rules data (struct ipt_replace).
if len(optVal) < linux.SizeOfIPTReplace {
nflog("optVal has insufficient size for replace %d", len(optVal))
@@ -341,10 +361,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
// TODO(gvisor.dev/issue/170): Support other tables.
- var table iptables.Table
+ var table stack.Table
switch replace.Name.String() {
- case iptables.TablenameFilter:
- table = iptables.EmptyFilterTable()
+ case stack.TablenameFilter:
+ table = stack.EmptyFilterTable()
+ case stack.TablenameNat:
+ table = stack.EmptyNatTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
return syserr.ErrInvalidArgument
@@ -404,14 +426,14 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
return syserr.ErrInvalidArgument
}
- target, err := parseTarget(optVal[:targetSize])
+ target, err := parseTarget(filter, optVal[:targetSize])
if err != nil {
nflog("failed to parse target: %v", err)
return syserr.ErrInvalidArgument
}
optVal = optVal[targetSize:]
- table.Rules = append(table.Rules, iptables.Rule{
+ table.Rules = append(table.Rules, stack.Rule{
Filter: filter,
Target: target,
Matchers: matchers,
@@ -442,11 +464,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
table.Underflows[hk] = ruleIdx
}
}
- if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset {
+ if ruleIdx := table.BuiltinChains[hk]; ruleIdx == stack.HookUnset {
nflog("hook %v is unset.", hk)
return syserr.ErrInvalidArgument
}
- if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset {
+ if ruleIdx := table.Underflows[hk]; ruleIdx == stack.HookUnset {
nflog("underflow %v is unset.", hk)
return syserr.ErrInvalidArgument
}
@@ -455,7 +477,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// Add the user chains.
for ruleIdx, rule := range table.Rules {
- target, ok := rule.Target.(iptables.UserChainTarget)
+ target, ok := rule.Target.(stack.UserChainTarget)
if !ok {
continue
}
@@ -495,11 +517,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
}
// TODO(gvisor.dev/issue/170): Support other chains.
- // Since we only support modifying the INPUT chain right now, make sure
- // all other chains point to ACCEPT rules.
+ // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
+ // make sure all other chains point to ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook != iptables.Input {
- if _, ok := table.Rules[ruleIdx].Target.(iptables.AcceptTarget); !ok {
+ if hook == stack.Forward || hook == stack.Postrouting {
+ if _, ok := table.Rules[ruleIdx].Target.(stack.AcceptTarget); !ok {
nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
}
@@ -511,7 +533,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- ipt := stack.IPTables()
+ ipt := stk.IPTables()
table.SetMetadata(metadata{
HookEntry: replace.HookEntry,
Underflow: replace.Underflow,
@@ -519,16 +541,16 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
Size: replace.Size,
})
ipt.Tables[replace.Name.String()] = table
- stack.SetIPTables(ipt)
+ stk.SetIPTables(ipt)
return nil
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
// only the matchers.
-func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, error) {
+func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) {
nflog("set entries: parsing matchers of size %d", len(optVal))
- var matchers []iptables.Matcher
+ var matchers []stack.Matcher
for len(optVal) > 0 {
nflog("set entries: optVal has len %d", len(optVal))
@@ -570,7 +592,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma
// parseTarget parses a target from optVal. optVal should contain only the
// target.
-func parseTarget(optVal []byte) (iptables.Target, error) {
+func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) {
nflog("set entries: parsing target of size %d", len(optVal))
if len(optVal) < linux.SizeOfXTEntryTarget {
return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal))
@@ -614,67 +636,125 @@ func parseTarget(optVal []byte) (iptables.Target, error) {
switch name := errorTarget.Name.String(); name {
case errorTargetName:
nflog("set entries: error target")
- return iptables.ErrorTarget{}, nil
+ return stack.ErrorTarget{}, nil
default:
// User defined chain.
nflog("set entries: user-defined target %q", name)
- return iptables.UserChainTarget{Name: name}, nil
+ return stack.UserChainTarget{Name: name}, nil
+ }
+
+ case redirectTargetName:
+ // Redirect target.
+ if len(optVal) < linux.SizeOfXTRedirectTarget {
+ return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal))
+ }
+
+ if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
}
+
+ var redirectTarget linux.XTRedirectTarget
+ buf = optVal[:linux.SizeOfXTRedirectTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+
+ // Copy linux.XTRedirectTarget to stack.RedirectTarget.
+ var target stack.RedirectTarget
+ nfRange := redirectTarget.NfRange
+
+ // RangeSize should be 1.
+ if nfRange.RangeSize != 1 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // TODO(gvisor.dev/issue/170): Check if the flags are valid.
+ // Also check if we need to map ports or IP.
+ // For now, redirect target only supports destination port change.
+ // Port range and IP range are not supported yet.
+ if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+ target.RangeProtoSpecified = true
+
+ target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:])
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // Convert port from big endian to little endian.
+ port := make([]byte, 2)
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort)
+ target.MinPort = binary.LittleEndian.Uint16(port)
+
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort)
+ target.MaxPort = binary.LittleEndian.Uint16(port)
+ return target, nil
}
// Unknown target.
return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
}
-func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, error) {
+func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
if containsUnsupportedFields(iptip) {
- return iptables.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ }
+ if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
}
- return iptables.IPHeaderFilter{
- Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ return stack.IPHeaderFilter{
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ Dst: tcpip.Address(iptip.Dst[:]),
+ DstMask: tcpip.Address(iptip.DstMask[:]),
+ DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
}, nil
}
func containsUnsupportedFields(iptip linux.IPTIP) bool {
- // Currently we check that everything except protocol is zeroed.
+ // The following features are supported:
+ // - Protocol
+ // - Dst and DstMask
+ // - The inverse destination IP check flag
var emptyInetAddr = linux.InetAddr{}
var emptyInterface = [linux.IFNAMSIZ]byte{}
- return iptip.Dst != emptyInetAddr ||
- iptip.Src != emptyInetAddr ||
+ // Disable any supported inverse flags.
+ inverseMask := uint8(linux.IPT_INV_DSTIP)
+ return iptip.Src != emptyInetAddr ||
iptip.SrcMask != emptyInetAddr ||
- iptip.DstMask != emptyInetAddr ||
iptip.InputInterface != emptyInterface ||
iptip.OutputInterface != emptyInterface ||
iptip.InputInterfaceMask != emptyInterface ||
iptip.OutputInterfaceMask != emptyInterface ||
iptip.Flags != 0 ||
- iptip.InverseFlags != 0
+ iptip.InverseFlags&^inverseMask != 0
}
-func validUnderflow(rule iptables.Rule) bool {
+func validUnderflow(rule stack.Rule) bool {
if len(rule.Matchers) != 0 {
return false
}
switch rule.Target.(type) {
- case iptables.AcceptTarget, iptables.DropTarget:
+ case stack.AcceptTarget, stack.DropTarget:
return true
default:
return false
}
}
-func hookFromLinux(hook int) iptables.Hook {
+func hookFromLinux(hook int) stack.Hook {
switch hook {
case linux.NF_INET_PRE_ROUTING:
- return iptables.Prerouting
+ return stack.Prerouting
case linux.NF_INET_LOCAL_IN:
- return iptables.Input
+ return stack.Input
case linux.NF_INET_FORWARD:
- return iptables.Forward
+ return stack.Forward
case linux.NF_INET_LOCAL_OUT:
- return iptables.Output
+ return stack.Output
case linux.NF_INET_POST_ROUTING:
- return iptables.Postrouting
+ return stack.Postrouting
}
panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
}
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
new file mode 100644
index 000000000..5949a7c29
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -0,0 +1,128 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameOwner = "owner"
+
+func init() {
+ registerMatchMaker(ownerMarshaler{})
+}
+
+// ownerMarshaler implements matchMaker for owner matching.
+type ownerMarshaler struct{}
+
+// name implements matchMaker.name.
+func (ownerMarshaler) name() string {
+ return matcherNameOwner
+}
+
+// marshal implements matchMaker.marshal.
+func (ownerMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*OwnerMatcher)
+ iptOwnerInfo := linux.IPTOwnerInfo{
+ UID: matcher.uid,
+ GID: matcher.gid,
+ }
+
+ // Support for UID match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if matcher.matchUID {
+ iptOwnerInfo.Match = linux.XT_OWNER_UID
+ } else if matcher.matchGID {
+ panic("GID match is not supported.")
+ } else {
+ panic("UID match is not set.")
+ }
+
+ buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo)
+ return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfIPTOwnerInfo {
+ return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.IPTOwnerInfo
+ binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
+
+ if matchData.Invert != 0 {
+ return nil, fmt.Errorf("invert flag is not supported for owner match")
+ }
+
+ // Support for UID match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if matchData.Match&linux.XT_OWNER_UID != linux.XT_OWNER_UID {
+ return nil, fmt.Errorf("owner match is only supported for uid")
+ }
+
+ // Check Flags.
+ var owner OwnerMatcher
+ owner.uid = matchData.UID
+ owner.gid = matchData.GID
+ owner.matchUID = true
+
+ return &owner, nil
+}
+
+type OwnerMatcher struct {
+ uid uint32
+ gid uint32
+ matchUID bool
+ matchGID bool
+ invert uint8
+}
+
+// Name implements Matcher.Name.
+func (*OwnerMatcher) Name() string {
+ return matcherNameOwner
+}
+
+// Match implements Matcher.Match.
+func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
+ // Support only for OUTPUT chain.
+ // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
+ if hook != stack.Output {
+ return false, true
+ }
+
+ // If the packet owner is not set, drop the packet.
+ // Support for uid match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if pkt.Owner == nil || !om.matchUID {
+ return false, true
+ }
+
+ // TODO(gvisor.dev/issue/170): Need to add tests to verify
+ // drop rule when packet UID does not match owner matcher UID.
+ if pkt.Owner.UID() != om.uid {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index c421b87cf..c948de876 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -15,11 +15,10 @@
package netfilter
import (
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// JumpTarget implements iptables.Target.
+// JumpTarget implements stack.Target.
type JumpTarget struct {
// Offset is the byte offset of the rule to jump to. It is used for
// marshaling and unmarshaling.
@@ -29,7 +28,7 @@ type JumpTarget struct {
RuleNum int
}
-// Action implements iptables.Target.Action.
-func (jt JumpTarget) Action(tcpip.PacketBuffer) (iptables.RuleVerdict, int) {
- return iptables.RuleJump, jt.RuleNum
+// Action implements stack.Target.Action.
+func (jt JumpTarget) Action(stack.PacketBuffer) (stack.RuleVerdict, int) {
+ return stack.RuleJump, jt.RuleNum
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index f9945e214..ff1cfd8f6 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -19,9 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,7 +39,7 @@ func (tcpMarshaler) name() string {
}
// marshal implements matchMaker.marshal.
-func (tcpMarshaler) marshal(mr iptables.Matcher) []byte {
+func (tcpMarshaler) marshal(mr stack.Matcher) []byte {
matcher := mr.(*TCPMatcher)
xttcp := linux.XTTCP{
SourcePortStart: matcher.sourcePortStart,
@@ -53,7 +52,7 @@ func (tcpMarshaler) marshal(mr iptables.Matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (tcpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTTCP {
return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf))
}
@@ -97,7 +96,7 @@ func (*TCPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
if netHeader.TransportProtocol() != header.TCPProtocolNumber {
@@ -115,7 +114,7 @@ func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac
// Now we need the transport header. However, this may not have been set
// yet.
// TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the iptables.Check codepath as matchers are
+ // ultimately be moved into the stack.Check codepath as matchers are
// added.
var tcpHeader header.TCP
if pkt.TransportHeader != nil {
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 86aa11696..3359418c1 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -19,9 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,7 +39,7 @@ func (udpMarshaler) name() string {
}
// marshal implements matchMaker.marshal.
-func (udpMarshaler) marshal(mr iptables.Matcher) []byte {
+func (udpMarshaler) marshal(mr stack.Matcher) []byte {
matcher := mr.(*UDPMatcher)
xtudp := linux.XTUDP{
SourcePortStart: matcher.sourcePortStart,
@@ -53,7 +52,7 @@ func (udpMarshaler) marshal(mr iptables.Matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (udpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTUDP {
return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf))
}
@@ -94,11 +93,11 @@ func (*UDPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
- // into the iptables.Check codepath as matchers are added.
+ // into the stack.Check codepath as matchers are added.
if netHeader.TransportProtocol() != header.UDPProtocolNumber {
return false, false
}
@@ -114,7 +113,7 @@ func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac
// Now we need the transport header. However, this may not have been set
// yet.
// TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the iptables.Check codepath as matchers are
+ // ultimately be moved into the stack.Check codepath as matchers are
// added.
var udpHeader header.UDP
if pkt.TransportHeader != nil {
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index ab01cb4fa..cbf46b1e9 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -38,7 +38,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index e187276c5..7ac38764d 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -29,6 +29,7 @@ import (
"io"
"math"
"reflect"
+ "sync/atomic"
"syscall"
"time"
@@ -264,6 +265,12 @@ type SocketOperations struct {
skType linux.SockType
protocol int
+ // readViewHasData is 1 iff readView has data to be read, 0 otherwise.
+ // Must be accessed using atomic operations. It must only be written
+ // with readMu held but can be read without holding readMu. The latter
+ // is required to avoid deadlocks in epoll Readiness checks.
+ readViewHasData uint32
+
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
// readView contains the remaining payload from the last packet.
@@ -293,7 +300,7 @@ type SocketOperations struct {
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil {
+ if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
}
@@ -410,21 +417,24 @@ func (s *SocketOperations) isPacketBased() bool {
// fetchReadView updates the readView field of the socket if it's currently
// empty. It assumes that the socket is locked.
+//
+// Precondition: s.readMu must be held.
func (s *SocketOperations) fetchReadView() *syserr.Error {
if len(s.readView) > 0 {
return nil
}
-
s.readView = nil
s.sender = tcpip.FullAddress{}
v, cms, err := s.Endpoint.Read(&s.sender)
if err != nil {
+ atomic.StoreUint32(&s.readViewHasData, 0)
return syserr.TranslateNetstackError(err)
}
s.readView = v
s.readCM = cms
+ atomic.StoreUint32(&s.readViewHasData, 1)
return nil
}
@@ -525,7 +535,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
if resCh != nil {
- t := ctx.(*kernel.Task)
+ t := kernel.TaskFromContext(ctx)
if err := t.Block(resCh); err != nil {
return 0, syserr.FromError(err).ToError()
}
@@ -598,7 +608,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
}
if resCh != nil {
- t := ctx.(*kernel.Task)
+ t := kernel.TaskFromContext(ctx)
if err := t.Block(resCh); err != nil {
return 0, syserr.FromError(err).ToError()
}
@@ -623,11 +633,9 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Check our cached value iff the caller asked for readability and the
// endpoint itself is currently not readable.
if (mask & ^r & waiter.EventIn) != 0 {
- s.readMu.Lock()
- if len(s.readView) > 0 {
+ if atomic.LoadUint32(&s.readViewHasData) == 1 {
r |= waiter.EventIn
}
- s.readMu.Unlock()
}
return r
@@ -655,7 +663,7 @@ func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error
// This is a hack to work around the fact that both IPv4 and IPv6 ANY are
// represented by the empty string.
//
-// TODO(gvisor.dev/issues/1556): remove this function.
+// TODO(gvisor.dev/issue/1556): remove this function.
func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress {
if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET {
addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
@@ -712,14 +720,44 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, family, err := AddressAndFamily(sockaddr)
- if err != nil {
- return err
+ if len(sockaddr) < 2 {
+ return syserr.ErrInvalidArgument
}
- if err := s.checkFamily(family, true /* exact */); err != nil {
- return err
+
+ family := usermem.ByteOrder.Uint16(sockaddr)
+ var addr tcpip.FullAddress
+
+ // Bind for AF_PACKET requires only family, protocol and ifindex.
+ // In function AddressAndFamily, we check the address length which is
+ // not needed for AF_PACKET bind.
+ if family == linux.AF_PACKET {
+ var a linux.SockAddrLink
+ if len(sockaddr) < sockAddrLinkSize {
+ return syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+
+ if a.Protocol != uint16(s.protocol) {
+ return syserr.ErrInvalidArgument
+ }
+
+ addr = tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }
+ } else {
+ var err *syserr.Error
+ addr, family, err = AddressAndFamily(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ if err = s.checkFamily(family, true /* exact */); err != nil {
+ return err
+ }
+
+ addr = s.mapFamily(addr, family)
}
- addr = s.mapFamily(addr, family)
// Issue the bind request to the endpoint.
return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
@@ -902,7 +940,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -927,8 +965,15 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int,
return nil, syserr.ErrProtocolNotAvailable
}
+func boolToInt32(v bool) int32 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -960,12 +1005,11 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.PasscredOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.PasscredOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1004,24 +1048,22 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.ReuseAddressOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.ReusePortOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.ReusePortOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1051,24 +1093,22 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.BroadcastOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.BroadcastOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.KeepaliveEnabledOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
@@ -1118,47 +1158,41 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptInt(tcpip.DelayOption)
+ v, err := ep.GetSockOptBool(tcpip.DelayOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- if v == 0 {
- return int32(1), nil
- }
- return int32(0), nil
+ return boolToInt32(!v), nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.CorkOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.CorkOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.QuickAckOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MaxSegOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.MaxSegOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1290,11 +1324,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1304,8 +1334,8 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if outLen == 0 {
return make([]byte, 0), nil
}
- var v tcpip.IPv6TrafficClassOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1327,12 +1357,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
default:
emitUnimplementedEventIPv6(t, name)
@@ -1348,8 +1373,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.TTLOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.TTLOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1365,8 +1390,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MulticastTTLOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1391,23 +1416,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MulticastLoopOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- if v {
- return int32(1), nil
- }
- return int32(0), nil
+ return boolToInt32(v), nil
case linux.IP_TOS:
// Length handling for parity with Linux.
if outLen == 0 {
return []byte(nil), nil
}
- var v tcpip.IPv4TOSOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
if outLen < sizeOfInt32 {
@@ -1424,11 +1445,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
@@ -1439,11 +1456,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
default:
emitUnimplementedEventIP(t, name)
@@ -1503,7 +1516,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
// SetSockOpt can be used to implement the linux syscall setsockopt(2) for
// sockets backed by a commonEndpoint.
-func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
+func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
switch level {
case linux.SOL_SOCKET:
return setSockOptSocket(t, s, ep, name, optVal)
@@ -1530,7 +1543,7 @@ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, n
}
// setSockOptSocket implements SetSockOpt when level is SOL_SOCKET.
-func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
switch name {
case linux.SO_SNDBUF:
if len(optVal) < sizeOfInt32 {
@@ -1554,7 +1567,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0))
case linux.SO_REUSEPORT:
if len(optVal) < sizeOfInt32 {
@@ -1562,7 +1575,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0))
case linux.SO_BINDTODEVICE:
n := bytes.IndexByte(optVal, 0)
@@ -1590,7 +1603,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0))
case linux.SO_PASSCRED:
if len(optVal) < sizeOfInt32 {
@@ -1598,7 +1611,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0))
case linux.SO_KEEPALIVE:
if len(optVal) < sizeOfInt32 {
@@ -1606,7 +1619,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0))
case linux.SO_SNDTIMEO:
if len(optVal) < linux.SizeOfTimeval {
@@ -1678,11 +1691,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- var o int
- if v == 0 {
- o = 1
- }
- return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -1690,7 +1699,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
case linux.TCP_QUICKACK:
if len(optVal) < sizeOfInt32 {
@@ -1698,7 +1707,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
case linux.TCP_MAXSEG:
if len(optVal) < sizeOfInt32 {
@@ -1706,7 +1715,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v)))
case linux.TCP_KEEPIDLE:
if len(optVal) < sizeOfInt32 {
@@ -1817,7 +1826,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
if v == -1 {
v = 0
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v)))
case linux.IPV6_RECVTCLASS:
v, err := parseIntOrChar(optVal)
@@ -1902,7 +1911,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
if v < 0 || v > 255 {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v)))
case linux.IP_ADD_MEMBERSHIP:
req, err := copyInMulticastRequest(optVal, false /* allowAddr */)
@@ -1949,9 +1958,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(
- tcpip.MulticastLoopOption(v != 0),
- ))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
case linux.MCAST_JOIN_GROUP:
// FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
@@ -1970,7 +1977,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
} else if v < 1 || v > 255 {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v)))
case linux.IP_TOS:
if len(optVal) == 0 {
@@ -1980,7 +1987,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v)))
case linux.IP_RECVTOS:
v, err := parseIntOrChar(optVal)
@@ -2304,6 +2311,10 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
}
copied += n
s.readView.TrimFront(n)
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
dst = dst.DropFirst(n)
if e != nil {
err = syserr.FromError(e)
@@ -2350,9 +2361,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
// caller-supplied buffer.
s.readMu.Lock()
n, err := s.coalescingRead(ctx, dst, trunc)
- s.readMu.Unlock()
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
+ s.readMu.Unlock()
return n, 0, nil, 0, cmsg, err
}
@@ -2426,6 +2437,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readView.TrimFront(int(n))
}
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
var flags int
if msgLen > int(n) {
flags |= linux.MSG_TRUNC
@@ -2637,7 +2652,9 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
}
// Add bytes removed from the endpoint but not yet sent to the caller.
+ s.readMu.Lock()
v += len(s.readView)
+ s.readMu.Unlock()
if v > math.MaxInt32 {
v = math.MaxInt32
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 5f181f017..c3f04b613 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -62,10 +62,6 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
}
case linux.SOCK_RAW:
- // TODO(b/142504697): "In order to create a raw socket, a
- // process must have the CAP_NET_RAW capability in the user
- // namespace that governs its network namespace." - raw(7)
-
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
@@ -126,6 +122,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
} else {
ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+
+ // Assign task to PacketOwner interface to get the UID and GID for
+ // iptables owner matching.
+ if e == nil {
+ ep.SetOwner(t)
+ }
}
if e != nil {
return nil, syserr.TranslateNetstackError(e)
@@ -135,10 +137,6 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
}
func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- // TODO(b/142504697): "In order to create a packet socket, a process
- // must have the CAP_NET_RAW capability in the user namespace that
- // governs its network namespace." - packet(7)
-
// Packet sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(t)
if !creds.HasCapability(linux.CAP_NET_RAW) {
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 0692482e9..f5fa18136 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -200,36 +199,66 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat interface{}, arg string) error {
switch stats := stat.(type) {
+ case *inet.StatDev:
+ for _, ni := range s.Stack.NICInfo() {
+ if ni.Name != arg {
+ continue
+ }
+ // TODO(gvisor.dev/issue/2103) Support stubbed stats.
+ *stats = inet.StatDev{
+ // Receive section.
+ ni.Stats.Rx.Bytes.Value(), // bytes.
+ ni.Stats.Rx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // frame.
+ 0, // compressed.
+ 0, // multicast.
+ // Transmit section.
+ ni.Stats.Tx.Bytes.Value(), // bytes.
+ ni.Stats.Tx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // colls.
+ 0, // carrier.
+ 0, // compressed.
+ }
+ break
+ }
case *inet.StatSNMPIP:
ip := Metrics.IP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPIP{
- 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL.
+ 0, // Ip/Forwarding.
+ 0, // Ip/DefaultTTL.
ip.PacketsReceived.Value(), // InReceives.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors.
+ 0, // Ip/InHdrErrors.
ip.InvalidDestinationAddressesReceived.Value(), // InAddrErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards.
+ 0, // Ip/ForwDatagrams.
+ 0, // Ip/InUnknownProtos.
+ 0, // Ip/InDiscards.
ip.PacketsDelivered.Value(), // InDelivers.
ip.PacketsSent.Value(), // OutRequests.
ip.OutgoingPacketErrors.Value(), // OutDiscards.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates.
+ 0, // Ip/OutNoRoutes.
+ 0, // Support Ip/ReasmTimeout.
+ 0, // Support Ip/ReasmReqds.
+ 0, // Support Ip/ReasmOKs.
+ 0, // Support Ip/ReasmFails.
+ 0, // Support Ip/FragOKs.
+ 0, // Support Ip/FragFails.
+ 0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs.
+ 0, // Icmp/InMsgs.
Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors.
+ 0, // Icmp/InCsumErrors.
in.DstUnreachable.Value(), // InDestUnreachs.
in.TimeExceeded.Value(), // InTimeExcds.
in.ParamProblem.Value(), // InParmProbs.
@@ -241,7 +270,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
in.TimestampReply.Value(), // InTimestampReps.
in.InfoRequest.Value(), // InAddrMasks.
in.InfoReply.Value(), // InAddrMaskReps.
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs.
+ 0, // Icmp/OutMsgs.
Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
out.DstUnreachable.Value(), // OutDestUnreachs.
out.TimeExceeded.Value(), // OutTimeExcds.
@@ -277,15 +306,16 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
}
case *inet.StatSNMPUDP:
udp := Metrics.UDP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPUDP{
udp.PacketsReceived.Value(), // InDatagrams.
udp.UnknownPortErrors.Value(), // NoPorts.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors.
+ 0, // Udp/InErrors.
udp.PacketsSent.Value(), // OutDatagrams.
udp.ReceiveBufferErrors.Value(), // RcvbufErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti.
+ 0, // Udp/SndbufErrors.
+ 0, // Udp/InCsumErrors.
+ 0, // Udp/IgnoredMulti.
}
default:
return syserr.ErrEndpointOperation.ToError()
@@ -332,7 +362,7 @@ func (s *Stack) RouteTable() []inet.Route {
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() (iptables.IPTables, error) {
+func (s *Stack) IPTables() (stack.IPTables, error) {
return s.Stack.IPTables(), nil
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 50d9744e6..6580bd6e9 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
@@ -48,11 +49,25 @@ func (c *ControlMessages) Release() {
c.Unix.Release()
}
-// Socket is the interface containing socket syscalls used by the syscall layer
-// to redirect them to the appropriate implementation.
+// Socket is an interface combining fs.FileOperations and SocketOps,
+// representing a VFS1 socket file.
type Socket interface {
fs.FileOperations
+ SocketOps
+}
+
+// SocketVFS2 is an interface combining vfs.FileDescription and SocketOps,
+// representing a VFS2 socket file.
+type SocketVFS2 interface {
+ vfs.FileDescriptionImpl
+ SocketOps
+}
+// SocketOps is the interface containing socket syscalls used by the syscall
+// layer to redirect them to the appropriate implementation.
+//
+// It is implemented by both Socket and SocketVFS2.
+type SocketOps interface {
// Connect implements the connect(2) linux syscall.
Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error
@@ -153,6 +168,8 @@ var families = make(map[int][]Provider)
// RegisterProvider registers the provider of a given address family so that
// sockets of that type can be created via socket() and/or socketpair()
// syscalls.
+//
+// This should only be called during the initialization of the address family.
func RegisterProvider(family int, provider Provider) {
families[family] = append(families[family], provider)
}
@@ -216,6 +233,74 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
return fs.NewDirent(ctx, inode, fmt.Sprintf("socket:[%d]", ino))
}
+// ProviderVFS2 is the vfs2 interface implemented by providers of sockets for
+// specific address families (e.g., AF_INET).
+type ProviderVFS2 interface {
+ // Socket creates a new socket.
+ //
+ // If a nil Socket _and_ a nil error is returned, it means that the
+ // protocol is not supported. A non-nil error should only be returned
+ // if the protocol is supported, but an error occurs during creation.
+ Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error)
+
+ // Pair creates a pair of connected sockets.
+ //
+ // See Socket for error information.
+ Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error)
+}
+
+// familiesVFS2 holds a map of all known address families and their providers.
+var familiesVFS2 = make(map[int][]ProviderVFS2)
+
+// RegisterProviderVFS2 registers the provider of a given address family so that
+// sockets of that type can be created via socket() and/or socketpair()
+// syscalls.
+//
+// This should only be called during the initialization of the address family.
+func RegisterProviderVFS2(family int, provider ProviderVFS2) {
+ familiesVFS2[family] = append(familiesVFS2[family], provider)
+}
+
+// NewVFS2 creates a new socket with the given family, type and protocol.
+func NewVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ for _, p := range familiesVFS2[family] {
+ s, err := p.Socket(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+ if s != nil {
+ t.Kernel().RecordSocketVFS2(s)
+ return s, nil
+ }
+ }
+
+ return nil, syserr.ErrAddressFamilyNotSupported
+}
+
+// PairVFS2 creates a new connected socket pair with the given family, type and
+// protocol.
+func PairVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ providers, ok := familiesVFS2[family]
+ if !ok {
+ return nil, nil, syserr.ErrAddressFamilyNotSupported
+ }
+
+ for _, p := range providers {
+ s1, s2, err := p.Pair(t, stype, protocol)
+ if err != nil {
+ return nil, nil, err
+ }
+ if s1 != nil && s2 != nil {
+ k := t.Kernel()
+ k.RecordSocketVFS2(s1)
+ k.RecordSocketVFS2(s2)
+ return s1, s2, nil
+ }
+ }
+
+ return nil, nil, syserr.ErrSocketNotSupported
+}
+
// SendReceiveTimeout stores timeouts for send and receive calls.
//
// It is meant to be embedded into Socket implementations to help satisfy the
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index 08743deba..de2cc4bdf 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -8,23 +8,27 @@ go_library(
"device.go",
"io.go",
"unix.go",
+ "unix_vfs2.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/fspath",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/netstack",
"//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 74bcd6300..c708b6030 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/ilist",
+ "//pkg/log",
"//pkg/refs",
"//pkg/sync",
"//pkg/syserr",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 2ef654235..2f1b127df 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -838,24 +839,43 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option. Currently not supported.
func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
- case tcpip.PasscredOption:
- e.setPasscred(v != 0)
- return nil
- }
return nil
}
func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.BroadcastOption:
+ case tcpip.PasscredOption:
+ e.setPasscred(v)
+ case tcpip.ReuseAddressOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
return nil
}
func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
return nil
}
func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.PasscredOption:
+ return e.Passcred(), nil
+
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ return false, tcpip.ErrUnknownProtocolOption
+ }
}
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
@@ -914,29 +934,19 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return int(v), nil
default:
+ log.Warningf("Unsupported socket option: %d", opt)
return -1, tcpip.ErrUnknownProtocolOption
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
+ switch opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
- }
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
default:
+ log.Warningf("Unsupported socket option: %T", opt)
return tcpip.ErrUnknownProtocolOption
}
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4d30aa714..7c64f30fa 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -33,6 +34,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -52,11 +54,8 @@ type SocketOperations struct {
fsutil.FileNoSplice `state:"nosave"`
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- refs.AtomicRefCount
- socket.SendReceiveTimeout
- ep transport.Endpoint
- stype linux.SockType
+ socketOpsCommon
}
// New creates a new unix socket.
@@ -75,16 +74,29 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
}
s := SocketOperations{
- ep: ep,
- stype: stype,
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
}
s.EnableLeakCheck("unix.SocketOperations")
return fs.NewFile(ctx, d, flags, &s)
}
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ refs.AtomicRefCount
+ socket.SendReceiveTimeout
+
+ ep transport.Endpoint
+ stype linux.SockType
+}
+
// DecRef implements RefCounter.DecRef.
-func (s *SocketOperations) DecRef() {
+func (s *socketOpsCommon) DecRef() {
s.DecRefWithDestructor(func() {
s.ep.Close()
})
@@ -97,7 +109,7 @@ func (s *SocketOperations) Release() {
s.DecRef()
}
-func (s *SocketOperations) isPacket() bool {
+func (s *socketOpsCommon) isPacket() bool {
switch s.stype {
case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
return true
@@ -110,7 +122,7 @@ func (s *SocketOperations) isPacket() bool {
}
// Endpoint extracts the transport.Endpoint.
-func (s *SocketOperations) Endpoint() transport.Endpoint {
+func (s *socketOpsCommon) Endpoint() transport.Endpoint {
return s.ep
}
@@ -143,7 +155,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) {
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -155,7 +167,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32,
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -178,7 +190,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// Listen implements the linux syscall listen(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
return s.ep.Listen(backlog)
}
@@ -310,6 +322,8 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Create the socket.
+ //
+ // TODO(gvisor.dev/issue/2324): Correctly set file permissions.
childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}})
if err != nil {
return syserr.ErrPortInUse
@@ -345,6 +359,31 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
return ep, nil
}
+ if kernel.VFS2Enabled {
+ p := fspath.Parse(path)
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ relPath := !p.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: p,
+ FollowFinalSymlink: true,
+ }
+ ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop)
+ root.DecRef()
+ if relPath {
+ start.DecRef()
+ }
+ if e != nil {
+ return nil, syserr.FromError(e)
+ }
+ return ep, nil
+ }
+
// Find the node in the filesystem.
root := t.FSContext().RootDirectory()
cwd := t.FSContext().WorkingDirectory()
@@ -363,12 +402,11 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
// No socket!
return nil, syserr.ErrConnectionRefused
}
-
return ep, nil
}
// Connect implements the linux syscall connect(2) for unix sockets.
-func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
ep, err := extractEndpoint(t, sockaddr)
if err != nil {
return err
@@ -379,7 +417,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
return s.ep.Connect(t, ep)
}
-// Writev implements fs.FileOperations.Write.
+// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
t := kernel.TaskFromContext(ctx)
ctrl := control.New(t, s.ep, nil)
@@ -399,7 +437,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
w := EndpointWriter{
Ctx: t,
Endpoint: s.ep,
@@ -453,27 +491,27 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
}
// Passcred implements transport.Credentialer.Passcred.
-func (s *SocketOperations) Passcred() bool {
+func (s *socketOpsCommon) Passcred() bool {
return s.ep.Passcred()
}
// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
-func (s *SocketOperations) ConnectedPasscred() bool {
+func (s *socketOpsCommon) ConnectedPasscred() bool {
return s.ep.ConnectedPasscred()
}
// Readiness implements waiter.Waitable.Readiness.
-func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.ep.Readiness(mask)
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.ep.EventRegister(e, mask)
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *SocketOperations) EventUnregister(e *waiter.Entry) {
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.ep.EventUnregister(e)
}
@@ -485,7 +523,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
f, err := netstack.ConvertShutdown(how)
if err != nil {
return err
@@ -511,7 +549,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -648,12 +686,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
// State implements socket.Socket.State.
-func (s *SocketOperations) State() uint32 {
+func (s *socketOpsCommon) State() uint32 {
return s.ep.State()
}
// Type implements socket.Socket.Type.
-func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
// Unix domain sockets always have a protocol of 0.
return linux.AF_UNIX, s.stype, 0
}
@@ -706,4 +744,5 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F
func init() {
socket.RegisterProvider(linux.AF_UNIX, &provider{})
+ socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{})
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
new file mode 100644
index 000000000..3e54d49c4
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -0,0 +1,348 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketVFS2 implements socket.SocketVFS2 (and by extension,
+// vfs.FileDescriptionImpl) for Unix sockets.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+
+ socketOpsCommon
+}
+
+// NewVFS2File creates and returns a new vfs.FileDescription for a unix socket.
+func NewVFS2File(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
+ sock := NewFDImpl(ep, stype)
+ vfsfd := &sock.vfsfd
+ if err := sockfs.InitSocket(sock, vfsfd, t.Kernel().SocketMount(), t.Credentials()); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// NewFDImpl creates and returns a new SocketVFS2.
+func NewFDImpl(ep transport.Endpoint, stype linux.SockType) *SocketVFS2 {
+ // You can create AF_UNIX, SOCK_RAW sockets. They're the same as
+ // SOCK_DGRAM and don't require CAP_NET_RAW.
+ if stype == linux.SOCK_RAW {
+ stype = linux.SOCK_DGRAM
+ }
+
+ return &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
+ }
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accept, it will block until one becomes ready.
+func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.socketOpsCommon.EventRegister(&e, waiter.EventIn)
+ defer s.socketOpsCommon.EventUnregister(&e)
+
+ // Try to accept the connection; if it fails, then wait until we get a
+ // notification.
+ for {
+ if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ return ep, err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, err := s.ep.Accept()
+ if err != nil {
+ if err != syserr.ErrWouldBlock || !blocking {
+ return 0, nil, 0, err
+ }
+
+ var err *syserr.Error
+ ep, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ // We expect this to be a FileDescription here.
+ ns, err := NewVFS2File(t, ep, s.stype)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK)
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer.
+ var err *syserr.Error
+ addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+ if e != nil {
+ return 0, nil, 0, syserr.FromError(e)
+ }
+
+ t.Kernel().RecordSocketVFS2(ns)
+ return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ p, e := extractPath(sockaddr)
+ if e != nil {
+ return e
+ }
+
+ bep, ok := s.ep.(transport.BoundEndpoint)
+ if !ok {
+ // This socket can't be bound.
+ return syserr.ErrInvalidArgument
+ }
+
+ return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error {
+ // Is it abstract?
+ if p[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return syserr.ErrInvalidEndpointState
+ }
+ if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ // syserr.ErrPortInUse corresponds to EADDRINUSE.
+ return syserr.ErrPortInUse
+ }
+ } else {
+ path := fspath.Parse(p)
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ relPath := !path.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ }
+ err := t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{
+ // TODO(gvisor.dev/issue/2324): The file permissions should be taken
+ // from s and t.FSContext().Umask() (see net/unix/af_unix.c:unix_bind),
+ // but VFS1 just always uses 0400. Resolve this inconsistency.
+ Mode: linux.S_IFSOCK | 0400,
+ Endpoint: bep,
+ })
+ if err == syserror.EEXIST {
+ return syserr.ErrAddressInUse
+ }
+ return syserr.FromError(err)
+ }
+
+ return nil
+ })
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return netstack.Ioctl(ctx, s.ep, uio, args)
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &EndpointReader{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ NumRights: 0,
+ Peek: false,
+ From: nil,
+ })
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ t := kernel.TaskFromContext(ctx)
+ ctrl := control.New(t, s.ep, nil)
+
+ if src.NumBytes() == 0 {
+ nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
+ return int64(nInt), err.ToError()
+ }
+
+ return src.CopyInTo(ctx, &EndpointWriter{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ Control: ctrl,
+ To: nil,
+ })
+}
+
+// Release implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Release() {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef()
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// providerVFS2 is a unix domain socket provider for VFS2.
+type providerVFS2 struct{}
+
+func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // Create the endpoint and socket.
+ var ep transport.Endpoint
+ switch stype {
+ case linux.SOCK_DGRAM, linux.SOCK_RAW:
+ ep = transport.NewConnectionless(t)
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
+ ep = transport.NewConnectioned(t, stype, t.Kernel())
+ default:
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ f, err := NewVFS2File(t, ep, stype)
+ if err != nil {
+ ep.Close()
+ return nil, err
+ }
+ return f, nil
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, nil, syserr.ErrProtocolNotSupported
+ }
+
+ switch stype {
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
+ // Ok
+ default:
+ return nil, nil, syserr.ErrInvalidArgument
+ }
+
+ // Create the endpoints and sockets.
+ ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
+ s1, err := NewVFS2File(t, ep1, stype)
+ if err != nil {
+ ep1.Close()
+ ep2.Close()
+ return nil, nil, err
+ }
+ s2, err := NewVFS2File(t, ep2, stype)
+ if err != nil {
+ s1.DecRef()
+ ep2.Close()
+ return nil, nil, err
+ }
+
+ return s1, s2, nil
+}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 77655558e..68ca537c8 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -719,7 +719,7 @@ func (s SyscallMap) SyscallEnter(t *kernel.Task, sysno uintptr, args arch.Syscal
// SyscallExit implements kernel.Stracer.SyscallExit. It logs the syscall
// exit trace.
func (s SyscallMap) SyscallExit(context interface{}, t *kernel.Task, sysno, rval uintptr, err error) {
- errno := t.ExtractErrno(err, int(sysno))
+ errno := kernel.ExtractErrno(err, int(sysno))
c := context.(*syscallContext)
elapsed := time.Since(c.start)
@@ -778,9 +778,6 @@ func (s SyscallMap) Name(sysno uintptr) string {
//
// N.B. This is not in an init function because we can't be sure all syscall
// tables are registered with the kernel when init runs.
-//
-// TODO(gvisor.dev/issue/155): remove kernel package dependencies from this
-// package and have the kernel package self-initialize all syscall tables.
func Initialize() {
for _, table := range kernel.SyscallTables() {
// Is this known?
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index c7883e68e..0d24fd3c4 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -42,6 +42,8 @@ go_library(
"sys_socket.go",
"sys_splice.go",
"sys_stat.go",
+ "sys_stat_amd64.go",
+ "sys_stat_arm64.go",
"sys_sync.go",
"sys_sysinfo.go",
"sys_syslog.go",
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index b401978db..d781d6a04 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -114,14 +114,28 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
func IoDestroy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
id := args[0].Uint64()
- // Destroy the given context.
- if !t.MemoryManager().DestroyAIOContext(t, id) {
+ ctx := t.MemoryManager().DestroyAIOContext(t, id)
+ if ctx == nil {
// Does not exist.
return 0, nil, syserror.EINVAL
}
- // FIXME(fvoznika): Linux blocks until all AIO to the destroyed context is
- // done.
- return 0, nil, nil
+
+ // Drain completed requests amd wait for pending requests until there are no
+ // more.
+ for {
+ ctx.Drain()
+
+ ch := ctx.WaitChannel()
+ if ch == nil {
+ // No more requests, we're done.
+ return 0, nil, nil
+ }
+ // The task cannot be interrupted during the wait. Equivalent to
+ // TASK_UNINTERRUPTIBLE in Linux.
+ t.UninterruptibleSleepStart(true /* deactivate */)
+ <-ch
+ t.UninterruptibleSleepFinish(true /* activate */)
+ }
}
// IoGetevents implements linux syscall io_getevents(2).
@@ -200,13 +214,13 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadline ktime.Time) (interface{}, error) {
for {
if v, ok := ctx.PopRequest(); ok {
- // Request was readly available. Just return it.
+ // Request was readily available. Just return it.
return v, nil
}
// Need to wait for request completion.
- done, active := ctx.WaitChannel()
- if !active {
+ done := ctx.WaitChannel()
+ if done == nil {
// Context has been destroyed.
return nil, syserror.EINVAL
}
@@ -248,6 +262,10 @@ func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) {
}
func performCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *ioCallback, ioseq usermem.IOSequence, ctx *mm.AIOContext, eventFile *fs.File) {
+ if ctx.Dead() {
+ ctx.CancelPendingRequest()
+ return
+ }
ev := &ioEvent{
Data: cb.Data,
Obj: uint64(cbAddr),
@@ -272,7 +290,7 @@ func performCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *ioC
// Update the result.
if err != nil {
err = handleIOError(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", file)
- ev.Result = -int64(t.ExtractErrno(err, 0))
+ ev.Result = -int64(kernel.ExtractErrno(err, 0))
}
file.DecRef()
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index d10a9bed8..35a98212a 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -514,7 +514,7 @@ func (ac accessContext) Value(key interface{}) interface{} {
}
}
-func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, resolve bool, mode uint) error {
+func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode uint) error {
const rOK = 4
const wOK = 2
const xOK = 1
@@ -529,7 +529,7 @@ func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, resolve bool, mode
return syserror.EINVAL
}
- return fileOpOn(t, dirFD, path, resolve, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return fileOpOn(t, dirFD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
// access(2) and faccessat(2) check permissions using real
// UID/GID, not effective UID/GID.
//
@@ -564,17 +564,23 @@ func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
addr := args[0].Pointer()
mode := args[1].ModeT()
- return 0, nil, accessAt(t, linux.AT_FDCWD, addr, true, mode)
+ return 0, nil, accessAt(t, linux.AT_FDCWD, addr, mode)
}
// Faccessat implements linux syscall faccessat(2).
+//
+// Note that the faccessat() system call does not take a flags argument:
+// "The raw faccessat() system call takes only the first three arguments. The
+// AT_EACCESS and AT_SYMLINK_NOFOLLOW flags are actually implemented within
+// the glibc wrapper function for faccessat(). If either of these flags is
+// specified, then the wrapper function employs fstatat(2) to determine access
+// permissions." - faccessat(2)
func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
dirFD := args[0].Int()
addr := args[1].Pointer()
mode := args[2].ModeT()
- flags := args[3].Int()
- return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode)
+ return 0, nil, accessAt(t, dirFD, addr, mode)
}
// LINT.ThenChange(vfs2/filesystem.go)
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index 798344042..43c510930 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -24,6 +24,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// pipe2 implements the actual system call with flags.
func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
@@ -45,10 +47,12 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
}
if _, err := t.CopyOut(addr, fds); err != nil {
- // The files are not closed in this case, the exact semantics
- // of this error case are not well defined, but they could have
- // already been observed by user space.
- return 0, syserror.EFAULT
+ for _, fd := range fds {
+ if file, _ := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
+ return 0, err
}
return 0, nil
}
@@ -69,3 +73,5 @@ func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
n, err := pipe2(t, addr, flags)
return n, nil, err
}
+
+// LINT.ThenChange(vfs2/pipe.go)
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index 9c6728530..f92bf8096 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -161,8 +161,8 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if args[1].Int() != 1 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 {
return 0, nil, syserror.EINVAL
}
- // no_new_privs is assumed to always be set. See
- // kernel.Task.updateCredsForExec.
+ // PR_SET_NO_NEW_PRIVS is assumed to always be set.
+ // See kernel.Task.updateCredsForExecLocked.
return 0, nil, nil
case linux.PR_GET_NO_NEW_PRIVS:
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 78a2cb750..071b4bacc 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -96,8 +96,8 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, syserror.EINVAL
}
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
@@ -120,8 +120,8 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
defer file.DecRef()
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 2919228d0..0760af77b 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -31,6 +31,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// minListenBacklog is the minimum reasonable backlog for listening sockets.
const minListenBacklog = 8
@@ -244,7 +246,11 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// Copy the file descriptors out.
if _, err := t.CopyOut(socks, fds); err != nil {
- // Note that we don't close files here; see pipe(2) also.
+ for _, fd := range fds {
+ if file, _ := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
return 0, nil, err
}
@@ -1128,3 +1134,5 @@ func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen)
return n, nil, err
}
+
+// LINT.ThenChange(./vfs2/socket.go)
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index fd642834b..df0d0f461 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -25,10 +26,15 @@ import (
// doSplice implements a blocking splice operation.
func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonBlocking bool) (int64, error) {
- if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 {
+ log.Infof("NLAC: doSplice opts: %+v", opts)
+ if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 || (opts.SrcStart+opts.Length < 0) {
return 0, syserror.EINVAL
}
+ if opts.Length > int64(kernel.MAX_RW_COUNT) {
+ opts.Length = int64(kernel.MAX_RW_COUNT)
+ }
+
var (
total int64
n int64
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index 701b27b4a..46ebf27a2 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -25,24 +25,6 @@ import (
// LINT.IfChange
-func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
- return linux.Stat{
- Dev: sattr.DeviceID,
- Ino: sattr.InodeID,
- Nlink: uattr.Links,
- Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
- UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
- GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
- Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
- Size: uattr.Size,
- Blksize: sattr.BlockSize,
- Blocks: uattr.Usage / 512,
- ATime: uattr.AccessTime.Timespec(),
- MTime: uattr.ModificationTime.Timespec(),
- CTime: uattr.StatusChangeTime.Timespec(),
- }
-}
-
// Stat implements linux syscall stat(2).
func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -133,7 +115,8 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err
return err
}
s := statFromAttrs(t, d.Inode.StableAttr, uattr)
- return s.CopyOut(t, statAddr)
+ _, err = s.CopyOut(t, statAddr)
+ return err
}
// fstat implements fstat for the given *fs.File.
@@ -143,7 +126,8 @@ func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error {
return err
}
s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr)
- return s.CopyOut(t, statAddr)
+ _, err = s.CopyOut(t, statAddr)
+ return err
}
// Statx implements linux syscall statx(2).
@@ -154,7 +138,10 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
mask := args[3].Uint()
statxAddr := args[4].Pointer()
- if mask&linux.STATX__RESERVED > 0 {
+ if mask&linux.STATX__RESERVED != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&^(linux.AT_SYMLINK_NOFOLLOW|linux.AT_EMPTY_PATH|linux.AT_STATX_SYNC_TYPE) != 0 {
return 0, nil, syserror.EINVAL
}
if flags&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE {
diff --git a/pkg/sentry/syscalls/linux/sys_stat_amd64.go b/pkg/sentry/syscalls/linux/sys_stat_amd64.go
new file mode 100644
index 000000000..0a04a6113
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_stat_amd64.go
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
+ return linux.Stat{
+ Dev: sattr.DeviceID,
+ Ino: sattr.InodeID,
+ Nlink: uattr.Links,
+ Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
+ Size: uattr.Size,
+ Blksize: sattr.BlockSize,
+ Blocks: uattr.Usage / 512,
+ ATime: uattr.AccessTime.Timespec(),
+ MTime: uattr.ModificationTime.Timespec(),
+ CTime: uattr.StatusChangeTime.Timespec(),
+ }
+}
+
+// LINT.ThenChange(vfs2/stat_amd64.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat_arm64.go b/pkg/sentry/syscalls/linux/sys_stat_arm64.go
new file mode 100644
index 000000000..5a3b1bfad
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_stat_arm64.go
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
+ return linux.Stat{
+ Dev: sattr.DeviceID,
+ Ino: sattr.InodeID,
+ Nlink: uint32(uattr.Links),
+ Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
+ Size: uattr.Size,
+ Blksize: int32(sattr.BlockSize),
+ Blocks: uattr.Usage / 512,
+ ATime: uattr.AccessTime.Timespec(),
+ MTime: uattr.ModificationTime.Timespec(),
+ CTime: uattr.StatusChangeTime.Timespec(),
+ }
+}
+
+// LINT.ThenChange(vfs2/stat_arm64.go)
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index 506ee54ce..6ec0de96e 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -87,8 +87,8 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
defer file.DecRef()
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index f51761e81..6ff2d84d2 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -18,31 +18,44 @@ go_library(
"linux64_override_arm64.go",
"mmap.go",
"path.go",
+ "pipe.go",
"poll.go",
"read_write.go",
"setstat.go",
+ "socket.go",
"stat.go",
+ "stat_amd64.go",
+ "stat_arm64.go",
"sync.go",
+ "sys_timerfd.go",
"xattr.go",
],
marshal = True,
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/bits",
"//pkg/fspath",
"//pkg/gohacks",
"//pkg/sentry/arch",
"//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/pipefs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/loader",
"//pkg/sentry/memmap",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/syscalls",
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/vfs",
"//pkg/sync",
+ "//pkg/syserr",
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
index d6cb0e79a..5a938cee2 100644
--- a/pkg/sentry/syscalls/linux/vfs2/epoll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -101,14 +101,14 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
var event linux.EpollEvent
switch op {
case linux.EPOLL_CTL_ADD:
- if err := event.CopyIn(t, eventAddr); err != nil {
+ if _, err := event.CopyIn(t, eventAddr); err != nil {
return 0, nil, err
}
return 0, nil, ep.AddInterest(file, fd, event)
case linux.EPOLL_CTL_DEL:
return 0, nil, ep.DeleteInterest(file, fd)
case linux.EPOLL_CTL_MOD:
- if err := event.CopyIn(t, eventAddr); err != nil {
+ if _, err := event.CopyIn(t, eventAddr); err != nil {
return 0, nil, err
}
return 0, nil, ep.ModifyInterest(file, fd, event)
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 3afcea665..8181d80f4 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -140,6 +141,22 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(file.StatusFlags()), nil, nil
case linux.F_SETFL:
return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint())
+ case linux.F_SETPIPE_SZ:
+ pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
+ if !ok {
+ return 0, nil, syserror.EBADF
+ }
+ n, err := pipefile.SetPipeSize(int64(args[2].Int()))
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+ case linux.F_GETPIPE_SZ:
+ pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
+ if !ok {
+ return 0, nil, syserror.EBADF
+ }
+ return uintptr(pipefile.PipeSize()), nil, nil
default:
// TODO(gvisor.dev/issue/1623): Everything else is not yet supported.
return 0, nil, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
index fc5ceea4c..46d3e189c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -172,7 +172,7 @@ func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mo
defer tpop.Release()
file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{
- Flags: flags,
+ Flags: flags | linux.O_LARGEFILE,
Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()),
})
if err != nil {
@@ -250,7 +250,7 @@ func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
if err != nil {
return err
}
- tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
if err != nil {
return err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
index ddc140b65..62e98817d 100644
--- a/pkg/sentry/syscalls/linux/vfs2/getdents.go
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -97,7 +97,8 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
// char d_name[]; /* Filename (null-terminated) */
// };
size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
- if size < cb.remaining {
+ size = (size + 7) &^ 7 // round up to multiple of 8
+ if size > cb.remaining {
return syserror.EINVAL
}
buf = cb.t.CopyScratchBuffer(size)
@@ -106,7 +107,12 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
buf[18] = dirent.Type
copy(buf[19:], dirent.Name)
- buf[size-1] = 0 // NUL terminator
+ // Zero out all remaining bytes in buf, including the NUL terminator
+ // after dirent.Name.
+ bufTail := buf[19+len(dirent.Name):]
+ for i := range bufTail {
+ bufTail[i] = 0
+ }
} else {
// struct linux_dirent {
// unsigned long d_ino; /* Inode number */
@@ -125,7 +131,8 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width()))
}
size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name)
- if size < cb.remaining {
+ size = (size + 7) &^ 7 // round up to multiple of sizeof(long)
+ if size > cb.remaining {
return syserror.EINVAL
}
buf = cb.t.CopyScratchBuffer(size)
@@ -133,9 +140,14 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
copy(buf[18:], dirent.Name)
- buf[size-3] = 0 // NUL terminator
- buf[size-2] = 0 // zero padding byte
- buf[size-1] = dirent.Type
+ // Zero out all remaining bytes in buf, including the NUL terminator
+ // after dirent.Name and the zero padding byte between the name and
+ // dirent type.
+ bufTail := buf[18+len(dirent.Name):]
+ for i := range bufTail {
+ bufTail[i] = 0
+ }
+ bufTail[2] = dirent.Type
}
n, err := cb.t.CopyOutBytes(cb.addr, buf)
if err != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
index 7d220bc20..21eb98444 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
@@ -39,26 +39,27 @@ func Override(table map[uintptr]kernel.Syscall) {
table[19] = syscalls.Supported("readv", Readv)
table[20] = syscalls.Supported("writev", Writev)
table[21] = syscalls.Supported("access", Access)
- delete(table, 22) // pipe
+ table[22] = syscalls.Supported("pipe", Pipe)
table[23] = syscalls.Supported("select", Select)
table[32] = syscalls.Supported("dup", Dup)
table[33] = syscalls.Supported("dup2", Dup2)
delete(table, 40) // sendfile
- delete(table, 41) // socket
- delete(table, 42) // connect
- delete(table, 43) // accept
- delete(table, 44) // sendto
- delete(table, 45) // recvfrom
- delete(table, 46) // sendmsg
- delete(table, 47) // recvmsg
- delete(table, 48) // shutdown
- delete(table, 49) // bind
- delete(table, 50) // listen
- delete(table, 51) // getsockname
- delete(table, 52) // getpeername
- delete(table, 53) // socketpair
- delete(table, 54) // setsockopt
- delete(table, 55) // getsockopt
+ // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2.
+ table[41] = syscalls.PartiallySupported("socket", Socket, "In process of porting socket syscalls to VFS2.", nil)
+ table[42] = syscalls.PartiallySupported("connect", Connect, "In process of porting socket syscalls to VFS2.", nil)
+ table[43] = syscalls.PartiallySupported("accept", Accept, "In process of porting socket syscalls to VFS2.", nil)
+ table[44] = syscalls.PartiallySupported("sendto", SendTo, "In process of porting socket syscalls to VFS2.", nil)
+ table[45] = syscalls.PartiallySupported("recvfrom", RecvFrom, "In process of porting socket syscalls to VFS2.", nil)
+ table[46] = syscalls.PartiallySupported("sendmsg", SendMsg, "In process of porting socket syscalls to VFS2.", nil)
+ table[47] = syscalls.PartiallySupported("recvmsg", RecvMsg, "In process of porting socket syscalls to VFS2.", nil)
+ table[48] = syscalls.PartiallySupported("shutdown", Shutdown, "In process of porting socket syscalls to VFS2.", nil)
+ table[49] = syscalls.PartiallySupported("bind", Bind, "In process of porting socket syscalls to VFS2.", nil)
+ table[50] = syscalls.PartiallySupported("listen", Listen, "In process of porting socket syscalls to VFS2.", nil)
+ table[51] = syscalls.PartiallySupported("getsockname", GetSockName, "In process of porting socket syscalls to VFS2.", nil)
+ table[52] = syscalls.PartiallySupported("getpeername", GetPeerName, "In process of porting socket syscalls to VFS2.", nil)
+ table[53] = syscalls.PartiallySupported("socketpair", SocketPair, "In process of porting socket syscalls to VFS2.", nil)
+ table[54] = syscalls.PartiallySupported("getsockopt", GetSockOpt, "In process of porting socket syscalls to VFS2.", nil)
+ table[55] = syscalls.PartiallySupported("setsockopt", SetSockOpt, "In process of porting socket syscalls to VFS2.", nil)
table[59] = syscalls.Supported("execve", Execve)
table[72] = syscalls.Supported("fcntl", Fcntl)
delete(table, 73) // flock
@@ -139,23 +140,26 @@ func Override(table map[uintptr]kernel.Syscall) {
table[280] = syscalls.Supported("utimensat", Utimensat)
table[281] = syscalls.Supported("epoll_pwait", EpollPwait)
delete(table, 282) // signalfd
- delete(table, 283) // timerfd_create
+ table[283] = syscalls.Supported("timerfd_create", TimerfdCreate)
delete(table, 284) // eventfd
delete(table, 285) // fallocate
- delete(table, 286) // timerfd_settime
- delete(table, 287) // timerfd_gettime
- delete(table, 288) // accept4
+ table[286] = syscalls.Supported("timerfd_settime", TimerfdSettime)
+ table[287] = syscalls.Supported("timerfd_gettime", TimerfdGettime)
+ // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2.
+ table[288] = syscalls.PartiallySupported("accept4", Accept4, "In process of porting socket syscalls to VFS2.", nil)
delete(table, 289) // signalfd4
delete(table, 290) // eventfd2
table[291] = syscalls.Supported("epoll_create1", EpollCreate1)
table[292] = syscalls.Supported("dup3", Dup3)
- delete(table, 293) // pipe2
+ table[293] = syscalls.Supported("pipe2", Pipe2)
delete(table, 294) // inotify_init1
table[295] = syscalls.Supported("preadv", Preadv)
table[296] = syscalls.Supported("pwritev", Pwritev)
- delete(table, 299) // recvmmsg
+ // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2.
+ table[299] = syscalls.PartiallySupported("recvmmsg", RecvMMsg, "In process of porting socket syscalls to VFS2.", nil)
table[306] = syscalls.Supported("syncfs", Syncfs)
- delete(table, 307) // sendmmsg
+ // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2.
+ table[307] = syscalls.PartiallySupported("sendmmsg", SendMMsg, "In process of porting socket syscalls to VFS2.", nil)
table[316] = syscalls.Supported("renameat2", Renameat2)
delete(table, 319) // memfd_create
table[322] = syscalls.Supported("execveat", Execveat)
diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go
new file mode 100644
index 000000000..4a01e4209
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Pipe implements Linux syscall pipe(2).
+func Pipe(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ return 0, nil, pipe2(t, addr, 0)
+}
+
+// Pipe2 implements Linux syscall pipe2(2).
+func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Int()
+ return 0, nil, pipe2(t, addr, flags)
+}
+
+func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error {
+ if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
+ return syserror.EINVAL
+ }
+ r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK))
+ defer r.DecRef()
+ defer w.DecRef()
+
+ fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{r, w}, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return err
+ }
+ if _, err := t.CopyOut(addr, fds); err != nil {
+ for _, fd := range fds {
+ if _, file := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
+ return err
+ }
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
index dbf4882da..ff1b25d7b 100644
--- a/pkg/sentry/syscalls/linux/vfs2/poll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -374,7 +374,8 @@ func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.D
}
remaining := timeoutRemaining(t, startNs, timeout)
tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
- return tsRemaining.CopyOut(t, timespecAddr)
+ _, err := tsRemaining.CopyOut(t, timespecAddr)
+ return err
}
// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
@@ -386,7 +387,8 @@ func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Du
}
remaining := timeoutRemaining(t, startNs, timeout)
tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
- return tvRemaining.CopyOut(t, timevalAddr)
+ _, err := tvRemaining.CopyOut(t, timevalAddr)
+ return err
}
// pollRestartBlock encapsulates the state required to restart poll(2) via
@@ -477,7 +479,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
timeout := time.Duration(-1)
if timevalAddr != 0 {
var timeval linux.Timeval
- if err := timeval.CopyIn(t, timevalAddr); err != nil {
+ if _, err := timeval.CopyIn(t, timevalAddr); err != nil {
return 0, nil, err
}
if timeval.Sec < 0 || timeval.Usec < 0 {
@@ -519,7 +521,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width()))
}
var maskStruct sigSetWithSize
- if err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
+ if _, err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
return 0, nil, err
}
if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
@@ -554,7 +556,7 @@ func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.D
timeout := time.Duration(-1)
if timespecAddr != 0 {
var timespec linux.Timespec
- if err := timespec.CopyIn(t, timespecAddr); err != nil {
+ if _, err := timespec.CopyIn(t, timespecAddr); err != nil {
return 0, err
}
if !timespec.Valid() {
@@ -573,7 +575,7 @@ func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) erro
return syserror.EINVAL
}
var mask linux.SignalSet
- if err := mask.CopyIn(t, maskAddr); err != nil {
+ if _, err := mask.CopyIn(t, maskAddr); err != nil {
return err
}
mask &^= kernel.UnblockableSignals
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
index 35f6308d6..6c6998f45 100644
--- a/pkg/sentry/syscalls/linux/vfs2/read_write.go
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -103,7 +103,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt
// Issue the request and break out if it completes with anything other than
// "would block".
- n, err := file.Read(t, dst, opts)
+ n, err = file.Read(t, dst, opts)
total += n
if err != syserror.ErrWouldBlock {
break
@@ -130,8 +130,8 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
defer file.DecRef()
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
@@ -248,7 +248,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of
// Issue the request and break out if it completes with anything other than
// "would block".
- n, err := file.PRead(t, dst, offset+total, opts)
+ n, err = file.PRead(t, dst, offset+total, opts)
total += n
if err != syserror.ErrWouldBlock {
break
@@ -335,7 +335,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op
// Issue the request and break out if it completes with anything other than
// "would block".
- n, err := file.Write(t, src, opts)
+ n, err = file.Write(t, src, opts)
total += n
if err != syserror.ErrWouldBlock {
break
@@ -362,8 +362,8 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
defer file.DecRef()
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
@@ -480,7 +480,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o
// Issue the request and break out if it completes with anything other than
// "would block".
- n, err := file.PWrite(t, src, offset+total, opts)
+ n, err = file.PWrite(t, src, offset+total, opts)
total += n
if err != syserror.ErrWouldBlock {
break
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 9250659ff..4e61f1452 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -173,12 +173,13 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, err
}
- return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ err = setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
Stat: linux.Statx{
Mask: linux.STATX_SIZE,
Size: uint64(length),
},
})
+ return 0, nil, handleSetSizeError(t, err)
}
// Ftruncate implements Linux syscall ftruncate(2).
@@ -196,12 +197,13 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
defer file.DecRef()
- return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ err := file.SetStat(t, vfs.SetStatOptions{
Stat: linux.Statx{
Mask: linux.STATX_SIZE,
Size: uint64(length),
},
})
+ return 0, nil, handleSetSizeError(t, err)
}
// Utime implements Linux syscall utime(2).
@@ -224,7 +226,7 @@ func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
opts.Stat.Mtime.Nsec = linux.UTIME_NOW
} else {
var times linux.Utime
- if err := times.CopyIn(t, timesAddr); err != nil {
+ if _, err := times.CopyIn(t, timesAddr); err != nil {
return 0, nil, err
}
opts.Stat.Atime.Sec = times.Actime
@@ -378,3 +380,12 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa
FollowFinalSymlink: bool(shouldFollowFinalSymlink),
}, opts)
}
+
+func handleSetSizeError(t *kernel.Task, err error) error {
+ if err == syserror.ErrExceedsFileSizeLimit {
+ // Convert error to EFBIG and send a SIGXFSZ per setrlimit(2).
+ t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t))
+ return syserror.EFBIG
+ }
+ return err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
new file mode 100644
index 000000000..b1ede32f0
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -0,0 +1,1139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// minListenBacklog is the minimum reasonable backlog for listening sockets.
+const minListenBacklog = 8
+
+// maxListenBacklog is the maximum allowed backlog for listening sockets.
+const maxListenBacklog = 1024
+
+// maxAddrLen is the maximum socket address length we're willing to accept.
+const maxAddrLen = 200
+
+// maxOptLen is the maximum sockopt parameter length we're willing to accept.
+const maxOptLen = 1024 * 8
+
+// maxControlLen is the maximum length of the msghdr.msg_control buffer we're
+// willing to accept. Note that this limit is smaller than Linux, which allows
+// buffers upto INT_MAX.
+const maxControlLen = 10 * 1024 * 1024
+
+// nameLenOffset is the offset from the start of the MessageHeader64 struct to
+// the NameLen field.
+const nameLenOffset = 8
+
+// controlLenOffset is the offset form the start of the MessageHeader64 struct
+// to the ControlLen field.
+const controlLenOffset = 40
+
+// flagsOffset is the offset form the start of the MessageHeader64 struct
+// to the Flags field.
+const flagsOffset = 48
+
+const sizeOfInt32 = 4
+
+// messageHeader64Len is the length of a MessageHeader64 struct.
+var messageHeader64Len = uint64(binary.Size(MessageHeader64{}))
+
+// multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct.
+var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{}))
+
+// baseRecvFlags are the flags that are accepted across recvmsg(2),
+// recvmmsg(2), and recvfrom(2).
+const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | linux.MSG_NOSIGNAL | linux.MSG_WAITALL | linux.MSG_TRUNC | linux.MSG_CTRUNC
+
+// MessageHeader64 is the 64-bit representation of the msghdr struct used in
+// the recvmsg and sendmsg syscalls.
+type MessageHeader64 struct {
+ // Name is the optional pointer to a network address buffer.
+ Name uint64
+
+ // NameLen is the length of the buffer pointed to by Name.
+ NameLen uint32
+ _ uint32
+
+ // Iov is a pointer to an array of io vectors that describe the memory
+ // locations involved in the io operation.
+ Iov uint64
+
+ // IovLen is the length of the array pointed to by Iov.
+ IovLen uint64
+
+ // Control is the optional pointer to ancillary control data.
+ Control uint64
+
+ // ControlLen is the length of the data pointed to by Control.
+ ControlLen uint64
+
+ // Flags on the sent/received message.
+ Flags int32
+ _ int32
+}
+
+// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
+// the recvmmsg and sendmmsg syscalls.
+type multipleMessageHeader64 struct {
+ msgHdr MessageHeader64
+ msgLen uint32
+ _ int32
+}
+
+// CopyInMessageHeader64 copies a message header from user to kernel memory.
+func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error {
+ b := t.CopyScratchBuffer(52)
+ if _, err := t.CopyInBytes(addr, b); err != nil {
+ return err
+ }
+
+ msg.Name = usermem.ByteOrder.Uint64(b[0:])
+ msg.NameLen = usermem.ByteOrder.Uint32(b[8:])
+ msg.Iov = usermem.ByteOrder.Uint64(b[16:])
+ msg.IovLen = usermem.ByteOrder.Uint64(b[24:])
+ msg.Control = usermem.ByteOrder.Uint64(b[32:])
+ msg.ControlLen = usermem.ByteOrder.Uint64(b[40:])
+ msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:]))
+
+ return nil
+}
+
+// CaptureAddress allocates memory for and copies a socket address structure
+// from the untrusted address space range.
+func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
+ if addrlen > maxAddrLen {
+ return nil, syserror.EINVAL
+ }
+
+ addrBuf := make([]byte, addrlen)
+ if _, err := t.CopyInBytes(addr, addrBuf); err != nil {
+ return nil, err
+ }
+
+ return addrBuf, nil
+}
+
+// writeAddress writes a sockaddr structure and its length to an output buffer
+// in the unstrusted address space range. If the address is bigger than the
+// buffer, it is truncated.
+func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+ // Get the buffer length.
+ var bufLen uint32
+ if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil {
+ return err
+ }
+
+ if int32(bufLen) < 0 {
+ return syserror.EINVAL
+ }
+
+ // Write the length unconditionally.
+ if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil {
+ return err
+ }
+
+ if addr == nil {
+ return nil
+ }
+
+ if bufLen > addrLen {
+ bufLen = addrLen
+ }
+
+ // Copy as much of the address as will fit in the buffer.
+ encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr)
+ if bufLen > uint32(len(encodedAddr)) {
+ bufLen = uint32(len(encodedAddr))
+ }
+ _, err := t.CopyOutBytes(addrPtr, encodedAddr[:int(bufLen)])
+ return err
+}
+
+// Socket implements the linux syscall socket(2).
+func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+
+ // Check and initialize the flags.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create the new socket.
+ s, e := socket.NewVFS2(t, domain, linux.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ defer s.DecRef()
+
+ if err := s.SetStatusFlags(t, t.Credentials(), uint32(stype&linux.SOCK_NONBLOCK)); err != nil {
+ return 0, nil, err
+ }
+
+ fd, err := t.NewFDFromVFS2(0, s, kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// SocketPair implements the linux syscall socketpair(2).
+func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+ addr := args[3].Pointer()
+
+ // Check and initialize the flags.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create the socket pair.
+ s1, s2, e := socket.PairVFS2(t, domain, linux.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ // Adding to the FD table will cause an extra reference to be acquired.
+ defer s1.DecRef()
+ defer s2.DecRef()
+
+ nonblocking := uint32(stype & linux.SOCK_NONBLOCK)
+ if err := s1.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil {
+ return 0, nil, err
+ }
+ if err := s2.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil {
+ return 0, nil, err
+ }
+
+ // Create the FDs for the sockets.
+ flags := kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ }
+ fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{s1, s2}, flags)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if _, err := t.CopyOut(addr, fds); err != nil {
+ for _, fd := range fds {
+ if _, file := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// Connect implements the linux syscall connect(2).
+func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Uint()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0
+ return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS)
+}
+
+// accept is the implementation of the accept syscall. It is called by accept
+// and accept4 syscall handlers.
+func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) {
+ // Check that no unsupported flags are passed in.
+ if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ // Call the syscall implementation for this socket, then copy the
+ // output address if one is specified.
+ blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0
+
+ peerRequested := addrLen != 0
+ nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ if peerRequested {
+ // NOTE(magi): Linux does not give you an error if it can't
+ // write the data back out so neither do we.
+ if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL {
+ return 0, err
+ }
+ }
+ return uintptr(nfd), nil
+}
+
+// Accept4 implements the linux syscall accept4(2).
+func Accept4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+ flags := int(args[3].Int())
+
+ n, err := accept(t, fd, addr, addrlen, flags)
+ return n, nil, err
+}
+
+// Accept implements the linux syscall accept(2).
+func Accept(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ n, err := accept(t, fd, addr, addrlen, 0)
+ return n, nil, err
+}
+
+// Bind implements the linux syscall bind(2).
+func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Uint()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, s.Bind(t, a).ToError()
+}
+
+// Listen implements the linux syscall listen(2).
+func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ backlog := args[1].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Per Linux, the backlog is silently capped to reasonable values.
+ if backlog <= 0 {
+ backlog = minListenBacklog
+ }
+ if backlog > maxListenBacklog {
+ backlog = maxListenBacklog
+ }
+
+ return 0, nil, s.Listen(t, int(backlog)).ToError()
+}
+
+// Shutdown implements the linux syscall shutdown(2).
+func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ how := args[1].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Validate how, then call syscall implementation.
+ switch how {
+ case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR:
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, s.Shutdown(t, int(how)).ToError()
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2).
+func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ level := args[1].Int()
+ name := args[2].Int()
+ optValAddr := args[3].Pointer()
+ optLenAddr := args[4].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Read the length. Reject negative values.
+ optLen := int32(0)
+ if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+ return 0, nil, err
+ }
+ if optLen < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Call syscall implementation then copy both value and value len out.
+ v, e := getSockOpt(t, s, int(level), int(name), optValAddr, int(optLen))
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+
+ vLen := int32(binary.Size(v))
+ if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
+ return 0, nil, err
+ }
+
+ if v != nil {
+ if _, err := t.CopyOut(optValAddr, v); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ return 0, nil, nil
+}
+
+// getSockOpt tries to handle common socket options, or dispatches to a specific
+// socket implementation.
+func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+ if level == linux.SOL_SOCKET {
+ switch name {
+ case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
+ if len < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ }
+
+ switch name {
+ case linux.SO_TYPE:
+ _, skType, _ := s.Type()
+ return int32(skType), nil
+ case linux.SO_DOMAIN:
+ family, _, _ := s.Type()
+ return int32(family), nil
+ case linux.SO_PROTOCOL:
+ _, _, protocol := s.Type()
+ return int32(protocol), nil
+ }
+ }
+
+ return s.GetSockOpt(t, level, name, optValAddr, len)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2).
+//
+// Note that unlike Linux, enabling SO_PASSCRED does not autobind the socket.
+func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ level := args[1].Int()
+ name := args[2].Int()
+ optValAddr := args[3].Pointer()
+ optLen := args[4].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ if optLen < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if optLen > maxOptLen {
+ return 0, nil, syserror.EINVAL
+ }
+ buf := t.CopyScratchBuffer(int(optLen))
+ if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ return 0, nil, err
+ }
+
+ // Call syscall implementation.
+ if err := s.SetSockOpt(t, int(level), int(name), buf); err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, nil
+}
+
+// GetSockName implements the linux syscall getsockname(2).
+func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Get the socket name and copy it to the caller.
+ v, vl, err := s.GetSockName(t)
+ if err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, writeAddress(t, v, vl, addr, addrlen)
+}
+
+// GetPeerName implements the linux syscall getpeername(2).
+func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Get the socket peer name and copy it to the caller.
+ v, vl, err := s.GetPeerName(t)
+ if err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, writeAddress(t, v, vl, addr, addrlen)
+}
+
+// RecvMsg implements the linux syscall recvmsg(2).
+func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, err := recvSingleMsg(t, s, msgPtr, flags, haveDeadline, deadline)
+ return n, nil, err
+}
+
+// RecvMMsg implements the linux syscall recvmmsg(2).
+func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+ toPtr := args[4].Pointer()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if toPtr != 0 {
+ var ts linux.Timespec
+ if _, err := ts.CopyIn(t, toPtr); err != nil {
+ return 0, nil, err
+ }
+ if !ts.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+ deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration())
+ haveDeadline = true
+ }
+
+ if !haveDeadline {
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ var n uintptr
+ if n, err = recvSingleMsg(t, s, mp, flags, haveDeadline, deadline); err != nil {
+ break
+ }
+
+ // Copy the received length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
+ // Capture the message header and io vectors.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
+
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syserror.EMSGSIZE
+ }
+ dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ // FIXME(b/63594852): Pretend we have an empty error queue.
+ if flags&linux.MSG_ERRQUEUE != 0 {
+ return 0, syserror.EAGAIN
+ }
+
+ // Fast path when no control message nor name buffers are provided.
+ if msg.ControlLen == 0 && msg.NameLen == 0 {
+ n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
+ if err != nil {
+ return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS)
+ }
+ if !cms.Unix.Empty() {
+ mflags |= linux.MSG_CTRUNC
+ cms.Release()
+ }
+
+ if int(msg.Flags) != mflags {
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+ }
+
+ return uintptr(n), nil
+ }
+
+ if msg.ControlLen > maxControlLen {
+ return 0, syserror.ENOBUFS
+ }
+ n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ defer cms.Release()
+
+ controlData := make([]byte, 0, msg.ControlLen)
+ controlData = control.PackControlMessages(t, cms, controlData)
+
+ if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
+ creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
+ controlData, mflags = control.PackCredentials(t, creds, controlData, mflags)
+ }
+
+ if cms.Unix.Rights != nil {
+ controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags)
+ }
+
+ // Copy the address to the caller.
+ if msg.NameLen != 0 {
+ if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy the control data to the caller.
+ if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
+ return 0, err
+ }
+ if len(controlData) > 0 {
+ if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+
+ return uintptr(n), nil
+}
+
+// recvFrom is the implementation of the recvfrom syscall. It is called by
+// recvfrom and recv syscall handlers.
+func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) {
+ if int(bufLen) < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ dst, err := t.SingleIOSequence(bufPtr, int(bufLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
+ cm.Release()
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+
+ // Copy the address to the caller.
+ if nameLenPtr != 0 {
+ if err := writeAddress(t, sender, senderLen, namePtr, nameLenPtr); err != nil {
+ return 0, err
+ }
+ }
+
+ return uintptr(n), nil
+}
+
+// RecvFrom implements the linux syscall recvfrom(2).
+func RecvFrom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufPtr := args[1].Pointer()
+ bufLen := args[2].Uint64()
+ flags := args[3].Int()
+ namePtr := args[4].Pointer()
+ nameLenPtr := args[5].Pointer()
+
+ n, err := recvFrom(t, fd, bufPtr, bufLen, flags, namePtr, nameLenPtr)
+ return n, nil, err
+}
+
+// SendMsg implements the linux syscall sendmsg(2).
+func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, err := sendSingleMsg(t, s, file, msgPtr, flags)
+ return n, nil, err
+}
+
+// SendMMsg implements the linux syscall sendmmsg(2).
+func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ var n uintptr
+ if n, err = sendSingleMsg(t, s, file, mp, flags); err != nil {
+ break
+ }
+
+ // Copy the received length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) {
+ // Capture the message header.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
+
+ var controlData []byte
+ if msg.ControlLen > 0 {
+ // Put an upper bound to prevent large allocations.
+ if msg.ControlLen > maxControlLen {
+ return 0, syserror.ENOBUFS
+ }
+ controlData = make([]byte, msg.ControlLen)
+ if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Read the destination address if one is specified.
+ var to []byte
+ if msg.NameLen != 0 {
+ var err error
+ to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ // Read data then call the sendmsg implementation.
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syserror.EMSGSIZE
+ }
+ src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ controlMessages, err := control.Parse(t, s, controlData)
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Call the syscall implementation.
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
+ err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
+ if err != nil {
+ controlMessages.Release()
+ }
+ return uintptr(n), err
+}
+
+// sendTo is the implementation of the sendto syscall. It is called by sendto
+// and send syscall handlers.
+func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) {
+ bl := int(bufLen)
+ if bl < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Read the destination address if one is specified.
+ var to []byte
+ var err error
+ if namePtr != 0 {
+ to, err = CaptureAddress(t, namePtr, nameLen)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ src, err := t.SingleIOSequence(bufPtr, bl, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Call the syscall implementation.
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)})
+ return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file)
+}
+
+// SendTo implements the linux syscall sendto(2).
+func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufPtr := args[1].Pointer()
+ bufLen := args[2].Uint64()
+ flags := args[3].Int()
+ namePtr := args[4].Pointer()
+ nameLen := args[5].Uint()
+
+ n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen)
+ return n, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
index dca8d7011..bb1d5cac4 100644
--- a/pkg/sentry/syscalls/linux/vfs2/stat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -90,7 +91,8 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags
}
var stat linux.Stat
convertStatxToUserStat(t, &statx, &stat)
- return stat.CopyOut(t, statAddr)
+ _, err = stat.CopyOut(t, statAddr)
+ return err
}
start = dirfile.VirtualDentry()
start.IncRef()
@@ -110,30 +112,8 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags
}
var stat linux.Stat
convertStatxToUserStat(t, &statx, &stat)
- return stat.CopyOut(t, statAddr)
-}
-
-// This takes both input and output as pointer arguments to avoid copying large
-// structs.
-func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
- // Linux just copies fields from struct kstat without regard to struct
- // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
- userns := t.UserNamespace()
- *stat = linux.Stat{
- Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
- Ino: statx.Ino,
- Nlink: uint64(statx.Nlink),
- Mode: uint32(statx.Mode),
- UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
- GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
- Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
- Size: int64(statx.Size),
- Blksize: int64(statx.Blksize),
- Blocks: int64(statx.Blocks),
- ATime: timespecFromStatxTimestamp(statx.Atime),
- MTime: timespecFromStatxTimestamp(statx.Mtime),
- CTime: timespecFromStatxTimestamp(statx.Ctime),
- }
+ _, err = stat.CopyOut(t, statAddr)
+ return err
}
func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec {
@@ -162,7 +142,8 @@ func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
var stat linux.Stat
convertStatxToUserStat(t, &statx, &stat)
- return 0, nil, stat.CopyOut(t, statAddr)
+ _, err = stat.CopyOut(t, statAddr)
+ return 0, nil, err
}
// Statx implements Linux syscall statx(2).
@@ -173,7 +154,15 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
mask := args[3].Uint()
statxAddr := args[4].Pointer()
- if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW|linux.AT_STATX_SYNC_TYPE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ // Make sure that only one sync type option is set.
+ syncType := uint32(flags & linux.AT_STATX_SYNC_TYPE)
+ if syncType != 0 && !bits.IsPowerOfTwo32(syncType) {
+ return 0, nil, syserror.EINVAL
+ }
+ if mask&linux.STATX__RESERVED != 0 {
return 0, nil, syserror.EINVAL
}
@@ -213,7 +202,8 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
userifyStatx(t, &statx)
- return 0, nil, statx.CopyOut(t, statxAddr)
+ _, err = statx.CopyOut(t, statxAddr)
+ return 0, nil, err
}
start = dirfile.VirtualDentry()
start.IncRef()
@@ -232,7 +222,8 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
userifyStatx(t, &statx)
- return 0, nil, statx.CopyOut(t, statxAddr)
+ _, err = statx.CopyOut(t, statxAddr)
+ return 0, nil, err
}
func userifyStatx(t *kernel.Task, statx *linux.Statx) {
@@ -251,14 +242,65 @@ func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Access implements Linux syscall access(2).
func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- // FIXME(jamieliu): actually implement
- return 0, nil, nil
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+
+ return 0, nil, accessAt(t, linux.AT_FDCWD, addr, mode)
}
-// Faccessat implements Linux syscall access(2).
+// Faccessat implements Linux syscall faccessat(2).
+//
+// Note that the faccessat() system call does not take a flags argument:
+// "The raw faccessat() system call takes only the first three arguments. The
+// AT_EACCESS and AT_SYMLINK_NOFOLLOW flags are actually implemented within
+// the glibc wrapper function for faccessat(). If either of these flags is
+// specified, then the wrapper function employs fstatat(2) to determine access
+// permissions." - faccessat(2)
func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- // FIXME(jamieliu): actually implement
- return 0, nil, nil
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+
+ return 0, nil, accessAt(t, dirfd, addr, mode)
+}
+
+func accessAt(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+ const rOK = 4
+ const wOK = 2
+ const xOK = 1
+
+ // Sanity check the mode.
+ if mode&^(rOK|wOK|xOK) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ // access(2) and faccessat(2) check permissions using real
+ // UID/GID, not effective UID/GID.
+ //
+ // "access() needs to use the real uid/gid, not the effective
+ // uid/gid. We do this by temporarily clearing all FS-related
+ // capabilities and switching the fsuid/fsgid around to the
+ // real ones." -fs/open.c:faccessat
+ creds := t.Credentials().Fork()
+ creds.EffectiveKUID = creds.RealKUID
+ creds.EffectiveKGID = creds.RealKGID
+ if creds.EffectiveKUID.In(creds.UserNamespace) == auth.RootUID {
+ creds.EffectiveCaps = creds.PermittedCaps
+ } else {
+ creds.EffectiveCaps = 0
+ }
+
+ return t.Kernel().VFS().AccessAt(t, creds, vfs.AccessTypes(mode), &tpop.pop)
}
// Readlinkat implements Linux syscall mknodat(2).
@@ -322,8 +364,8 @@ func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
-
- return 0, nil, statfs.CopyOut(t, bufAddr)
+ _, err = statfs.CopyOut(t, bufAddr)
+ return 0, nil, err
}
// Fstatfs implements Linux syscall fstatfs(2).
@@ -341,6 +383,6 @@ func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if err != nil {
return 0, nil, err
}
-
- return 0, nil, statfs.CopyOut(t, bufAddr)
+ _, err = statfs.CopyOut(t, bufAddr)
+ return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go
new file mode 100644
index 000000000..2da538fc6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint64(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int64(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go
new file mode 100644
index 000000000..88b9c7627
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint32(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int32(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sys_timerfd.go b/pkg/sentry/syscalls/linux/vfs2/sys_timerfd.go
new file mode 100644
index 000000000..7938a5249
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/sys_timerfd.go
@@ -0,0 +1,123 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// TimerfdCreate implements Linux syscall timerfd_create(2).
+func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := args[0].Int()
+ flags := args[1].Int()
+
+ if flags&^(linux.TFD_CLOEXEC|linux.TFD_NONBLOCK) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var fileFlags uint32
+ if flags&linux.TFD_NONBLOCK != 0 {
+ fileFlags = linux.O_NONBLOCK
+ }
+
+ var clock ktime.Clock
+ switch clockID {
+ case linux.CLOCK_REALTIME:
+ clock = t.Kernel().RealtimeClock()
+ case linux.CLOCK_MONOTONIC, linux.CLOCK_BOOTTIME:
+ clock = t.Kernel().MonotonicClock()
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+ file, err := t.Kernel().VFS().NewTimerFD(clock, fileFlags)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.TFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// TimerfdSettime implements Linux syscall timerfd_settime(2).
+func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ flags := args[1].Int()
+ newValAddr := args[2].Pointer()
+ oldValAddr := args[3].Pointer()
+
+ if flags&^(linux.TFD_TIMER_ABSTIME) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ tfd, ok := file.Impl().(*vfs.TimerFileDescription)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var newVal linux.Itimerspec
+ if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ return 0, nil, err
+ }
+ newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tfd.Clock())
+ if err != nil {
+ return 0, nil, err
+ }
+ tm, oldS := tfd.SetTime(newS)
+ if oldValAddr != 0 {
+ oldVal := ktime.ItimerspecFromSetting(tm, oldS)
+ if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// TimerfdGettime implements Linux syscall timerfd_gettime(2).
+func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ curValAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ tfd, ok := file.Impl().(*vfs.TimerFileDescription)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ tm, s := tfd.GetTime()
+ curVal := ktime.ItimerspecFromSetting(tm, s)
+ _, err := t.CopyOut(curValAddr, &curVal)
+ return 0, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
index 89e9ff4d7..af455d5c1 100644
--- a/pkg/sentry/syscalls/linux/vfs2/xattr.go
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -51,7 +51,7 @@ func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSyml
}
defer tpop.Release()
- names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop)
+ names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size))
if err != nil {
return 0, nil, err
}
@@ -74,7 +74,7 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
defer file.DecRef()
- names, err := file.Listxattr(t)
+ names, err := file.Listxattr(t, uint64(size))
if err != nil {
return 0, nil, err
}
@@ -116,7 +116,10 @@ func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli
return 0, nil, err
}
- value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, name)
+ value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetxattrOptions{
+ Name: name,
+ Size: uint64(size),
+ })
if err != nil {
return 0, nil, err
}
@@ -145,7 +148,7 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, err
}
- value, err := file.Getxattr(t, name)
+ value, err := file.Getxattr(t, &vfs.GetxattrOptions{Name: name, Size: uint64(size)})
if err != nil {
return 0, nil, err
}
@@ -230,7 +233,7 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, err
}
- return 0, nil, file.Setxattr(t, vfs.SetxattrOptions{
+ return 0, nil, file.Setxattr(t, &vfs.SetxattrOptions{
Name: name,
Value: value,
Flags: uint32(flags),
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 07c8383e6..9aeb83fb0 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -36,23 +36,31 @@ go_library(
"pathname.go",
"permissions.go",
"resolving_path.go",
+ "timerfd.go",
"vfs.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/fd",
"//pkg/fspath",
"//pkg/gohacks",
"//pkg/log",
+ "//pkg/safemem",
"//pkg/sentry/arch",
+ "//pkg/sentry/fs",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
"//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index 2db25be49..a64d86122 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -41,7 +42,27 @@ func (vfs *VirtualFilesystem) NewAnonVirtualDentry(name string) VirtualDentry {
}
}
-const anonfsBlockSize = usermem.PageSize // via fs/libfs.c:pseudo_fs_fill_super()
+const (
+ anonfsBlockSize = usermem.PageSize // via fs/libfs.c:pseudo_fs_fill_super()
+
+ // Mode, UID, and GID for a generic anonfs file.
+ anonFileMode = 0600 // no type is correct
+ anonFileUID = auth.RootKUID
+ anonFileGID = auth.RootKGID
+)
+
+// anonFilesystemType implements FilesystemType.
+type anonFilesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (anonFilesystemType) GetFilesystem(context.Context, *VirtualFilesystem, *auth.Credentials, string, GetFilesystemOptions) (*Filesystem, *Dentry, error) {
+ panic("cannot instaniate an anon filesystem")
+}
+
+// Name implemenents FilesystemType.Name.
+func (anonFilesystemType) Name() string {
+ return "none"
+}
// anonFilesystem is the implementation of FilesystemImpl that backs
// VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry().
@@ -69,6 +90,16 @@ func (fs *anonFilesystem) Sync(ctx context.Context) error {
return nil
}
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+//
+// TODO(gvisor.dev/issue/1965): Implement access permissions.
+func (fs *anonFilesystem) AccessAt(ctx context.Context, rp *ResolvingPath, creds *auth.Credentials, ats AccessTypes) error {
+ if !rp.Done() {
+ return syserror.ENOTDIR
+ }
+ return GenericCheckPermissions(creds, ats, anonFileMode, anonFileUID, anonFileGID)
+}
+
// GetDentryAt implements FilesystemImpl.GetDentryAt.
func (fs *anonFilesystem) GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error) {
if !rp.Done() {
@@ -167,9 +198,9 @@ func (fs *anonFilesystem) StatAt(ctx context.Context, rp *ResolvingPath, opts St
Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS,
Blksize: anonfsBlockSize,
Nlink: 1,
- UID: uint32(auth.RootKUID),
- GID: uint32(auth.RootKGID),
- Mode: 0600, // no type is correct
+ UID: uint32(anonFileUID),
+ GID: uint32(anonFileGID),
+ Mode: anonFileMode,
Ino: 1,
Size: 0,
Blocks: 0,
@@ -205,8 +236,16 @@ func (fs *anonFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) error
return syserror.EPERM
}
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath) (transport.BoundEndpoint, error) {
+ if !rp.Final() {
+ return nil, syserror.ENOTDIR
+ }
+ return nil, syserror.ECONNREFUSED
+}
+
// ListxattrAt implements FilesystemImpl.ListxattrAt.
-func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) {
+func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) {
if !rp.Done() {
return nil, syserror.ENOTDIR
}
@@ -214,7 +253,7 @@ func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([
}
// GetxattrAt implements FilesystemImpl.GetxattrAt.
-func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) {
+func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) {
if !rp.Done() {
return "", syserror.ENOTDIR
}
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 3da45d744..8e0b40841 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -99,6 +99,8 @@ func (vfs *VirtualFilesystem) NewEpollInstanceFD() (*FileDescription, error) {
interest: make(map[epollInterestKey]*epollInterest),
}
if err := ep.vfsfd.Init(ep, linux.O_RDWR, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
UseDentryMetadata: true,
}); err != nil {
return nil, err
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 9a1ad630c..5976b5ccd 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -84,6 +84,17 @@ type FileDescriptionOptions struct {
// usually only the case if O_DIRECT would actually have an effect.
AllowDirectIO bool
+ // If DenyPRead is true, calls to FileDescription.PRead() return ESPIPE.
+ DenyPRead bool
+
+ // If DenyPWrite is true, calls to FileDescription.PWrite() return
+ // ESPIPE.
+ DenyPWrite bool
+
+ // if InvalidWrite is true, calls to FileDescription.Write() return
+ // EINVAL.
+ InvalidWrite bool
+
// If UseDentryMetadata is true, calls to FileDescription methods that
// interact with file and filesystem metadata (Stat, SetStat, StatFS,
// Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling
@@ -111,7 +122,7 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mn
}
fd.refs = 1
- fd.statusFlags = statusFlags | linux.O_LARGEFILE
+ fd.statusFlags = statusFlags
fd.vd = VirtualDentry{
mount: mnt,
dentry: d,
@@ -175,6 +186,12 @@ func (fd *FileDescription) DecRef() {
}
}
+// Refs returns the current number of references. The returned count
+// is inherently racy and is unsafe to use without external synchronization.
+func (fd *FileDescription) Refs() int64 {
+ return atomic.LoadInt64(&fd.refs)
+}
+
// Mount returns the mount on which fd was opened. It does not take a reference
// on the returned Mount.
func (fd *FileDescription) Mount() *Mount {
@@ -286,7 +303,8 @@ type FileDescriptionImpl interface {
Stat(ctx context.Context, opts StatOptions) (linux.Statx, error)
// SetStat updates metadata for the file represented by the
- // FileDescription.
+ // FileDescription. Implementations are responsible for checking if the
+ // operation can be performed (see vfs.CheckSetStat() for common checks).
SetStat(ctx context.Context, opts SetStatOptions) error
// StatFS returns metadata for the filesystem containing the file
@@ -305,6 +323,7 @@ type FileDescriptionImpl interface {
// - If opts.Flags specifies unsupported options, PRead returns EOPNOTSUPP.
//
// Preconditions: The FileDescription was opened for reading.
+ // FileDescriptionOptions.DenyPRead == false.
PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error)
// Read is similar to PRead, but does not specify an offset.
@@ -336,6 +355,7 @@ type FileDescriptionImpl interface {
// EOPNOTSUPP.
//
// Preconditions: The FileDescription was opened for writing.
+ // FileDescriptionOptions.DenyPWrite == false.
PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error)
// Write is similar to PWrite, but does not specify an offset, which is
@@ -381,11 +401,11 @@ type FileDescriptionImpl interface {
Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error)
// Listxattr returns all extended attribute names for the file.
- Listxattr(ctx context.Context) ([]string, error)
+ Listxattr(ctx context.Context, size uint64) ([]string, error)
// Getxattr returns the value associated with the given extended attribute
// for the file.
- Getxattr(ctx context.Context, name string) (string, error)
+ Getxattr(ctx context.Context, opts GetxattrOptions) (string, error)
// Setxattr changes the value associated with the given extended attribute
// for the file.
@@ -514,6 +534,9 @@ func (fd *FileDescription) EventUnregister(e *waiter.Entry) {
// offset, and returns the number of bytes read. PRead is permitted to return
// partial reads with a nil error.
func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ if fd.opts.DenyPRead {
+ return 0, syserror.ESPIPE
+ }
if !fd.readable {
return 0, syserror.EBADF
}
@@ -532,6 +555,9 @@ func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opt
// offset, and returns the number of bytes written. PWrite is permitted to
// return partial writes with a nil error.
func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ if fd.opts.DenyPWrite {
+ return 0, syserror.ESPIPE
+ }
if !fd.writable {
return 0, syserror.EBADF
}
@@ -540,6 +566,9 @@ func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, o
// Write is similar to PWrite, but does not specify an offset.
func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ if fd.opts.InvalidWrite {
+ return 0, syserror.EINVAL
+ }
if !fd.writable {
return 0, syserror.EBADF
}
@@ -576,18 +605,23 @@ func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
// Listxattr returns all extended attribute names for the file represented by
// fd.
-func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) {
+//
+// If the size of the list (including a NUL terminating byte after every entry)
+// would exceed size, ERANGE may be returned. Note that implementations
+// are free to ignore size entirely and return without error). In all cases,
+// if size is 0, the list should be returned without error, regardless of size.
+func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
if fd.opts.UseDentryMetadata {
vfsObj := fd.vd.mount.vfs
rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
Root: fd.vd,
Start: fd.vd,
})
- names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp)
+ names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size)
vfsObj.putResolvingPath(rp)
return names, err
}
- names, err := fd.impl.Listxattr(ctx)
+ names, err := fd.impl.Listxattr(ctx, size)
if err == syserror.ENOTSUP {
// Linux doesn't actually return ENOTSUP in this case; instead,
// fs/xattr.c:vfs_listxattr() falls back to allowing the security
@@ -600,34 +634,39 @@ func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) {
// Getxattr returns the value associated with the given extended attribute for
// the file represented by fd.
-func (fd *FileDescription) Getxattr(ctx context.Context, name string) (string, error) {
+//
+// If the size of the return value exceeds opts.Size, ERANGE may be returned
+// (note that implementations are free to ignore opts.Size entirely and return
+// without error). In all cases, if opts.Size is 0, the value should be
+// returned without error, regardless of size.
+func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions) (string, error) {
if fd.opts.UseDentryMetadata {
vfsObj := fd.vd.mount.vfs
rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
Root: fd.vd,
Start: fd.vd,
})
- val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, name)
+ val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
vfsObj.putResolvingPath(rp)
return val, err
}
- return fd.impl.Getxattr(ctx, name)
+ return fd.impl.Getxattr(ctx, *opts)
}
// Setxattr changes the value associated with the given extended attribute for
// the file represented by fd.
-func (fd *FileDescription) Setxattr(ctx context.Context, opts SetxattrOptions) error {
+func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions) error {
if fd.opts.UseDentryMetadata {
vfsObj := fd.vd.mount.vfs
rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
Root: fd.vd,
Start: fd.vd,
})
- err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, opts)
+ err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
vfsObj.putResolvingPath(rp)
return err
}
- return fd.impl.Setxattr(ctx, opts)
+ return fd.impl.Setxattr(ctx, *opts)
}
// Removexattr removes the given extended attribute from the file represented
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index c2a52ec1b..f4c111926 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -33,8 +33,8 @@ import (
// implementations to adapt:
// - Have a local fileDescription struct (containing FileDescription) which
// embeds FileDescriptionDefaultImpl and overrides the default methods
-// which are common to all fd implementations for that for that filesystem
-// like StatusFlags, SetStatusFlags, Stat, SetStat, StatFS, etc.
+// which are common to all fd implementations for that filesystem like
+// StatusFlags, SetStatusFlags, Stat, SetStat, StatFS, etc.
// - This should be embedded in all file description implementations as the
// first field by value.
// - Directory FDs would also embed DirectoryFileDescriptionDefaultImpl.
@@ -130,14 +130,14 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg
// Listxattr implements FileDescriptionImpl.Listxattr analogously to
// inode_operations::listxattr == NULL in Linux.
-func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context) ([]string, error) {
+func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context, size uint64) ([]string, error) {
// This isn't exactly accurate; see FileDescription.Listxattr.
return nil, syserror.ENOTSUP
}
// Getxattr implements FileDescriptionImpl.Getxattr analogously to
// inode::i_opflags & IOP_XATTR == 0 in Linux.
-func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, name string) (string, error) {
+func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) {
return "", syserror.ENOTSUP
}
@@ -339,6 +339,11 @@ func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src
if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
return 0, syserror.EOPNOTSUPP
}
+ limit, err := CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(limit)
writable, ok := fd.data.(WritableDynamicBytesSource)
if !ok {
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 556976d0b..a537a29d1 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -20,6 +20,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
)
// A Filesystem is a tree of nodes represented by Dentries, which forms part of
@@ -40,21 +42,30 @@ type Filesystem struct {
// immutable.
vfs *VirtualFilesystem
+ // fsType is the FilesystemType of this Filesystem.
+ fsType FilesystemType
+
// impl is the FilesystemImpl associated with this Filesystem. impl is
// immutable. This should be the last field in Dentry.
impl FilesystemImpl
}
// Init must be called before first use of fs.
-func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, impl FilesystemImpl) {
+func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, fsType FilesystemType, impl FilesystemImpl) {
fs.refs = 1
fs.vfs = vfsObj
+ fs.fsType = fsType
fs.impl = impl
vfsObj.filesystemsMu.Lock()
vfsObj.filesystems[fs] = struct{}{}
vfsObj.filesystemsMu.Unlock()
}
+// FilesystemType returns the FilesystemType for this Filesystem.
+func (fs *Filesystem) FilesystemType() FilesystemType {
+ return fs.fsType
+}
+
// VirtualFilesystem returns the containing VirtualFilesystem.
func (fs *Filesystem) VirtualFilesystem() *VirtualFilesystem {
return fs.vfs
@@ -144,6 +155,9 @@ type FilesystemImpl interface {
// file data to be written to the underlying [filesystem]", as by syncfs(2).
Sync(ctx context.Context) error
+ // AccessAt checks whether a user with creds can access the file at rp.
+ AccessAt(ctx context.Context, rp *ResolvingPath, creds *auth.Credentials, ats AccessTypes) error
+
// GetDentryAt returns a Dentry representing the file at rp. A reference is
// taken on the returned Dentry.
//
@@ -362,7 +376,9 @@ type FilesystemImpl interface {
// ResolvingPath.Resolve*(), then !rp.Done().
RmdirAt(ctx context.Context, rp *ResolvingPath) error
- // SetStatAt updates metadata for the file at the given path.
+ // SetStatAt updates metadata for the file at the given path. Implementations
+ // are responsible for checking if the operation can be performed
+ // (see vfs.CheckSetStat() for common checks).
//
// Errors:
//
@@ -426,7 +442,13 @@ type FilesystemImpl interface {
// - If extended attributes are not supported by the filesystem,
// ListxattrAt returns nil. (See FileDescription.Listxattr for an
// explanation.)
- ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error)
+ //
+ // - If the size of the list (including a NUL terminating byte after every
+ // entry) would exceed size, ERANGE may be returned. Note that
+ // implementations are free to ignore size entirely and return without
+ // error). In all cases, if size is 0, the list should be returned without
+ // error, regardless of size.
+ ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error)
// GetxattrAt returns the value associated with the given extended
// attribute for the file at rp.
@@ -435,7 +457,15 @@ type FilesystemImpl interface {
//
// - If extended attributes are not supported by the filesystem, GetxattrAt
// returns ENOTSUP.
- GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error)
+ //
+ // - If an extended attribute named opts.Name does not exist, ENODATA is
+ // returned.
+ //
+ // - If the size of the return value exceeds opts.Size, ERANGE may be
+ // returned (note that implementations are free to ignore opts.Size entirely
+ // and return without error). In all cases, if opts.Size is 0, the value
+ // should be returned without error, regardless of size.
+ GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error)
// SetxattrAt changes the value associated with the given extended
// attribute for the file at rp.
@@ -444,6 +474,10 @@ type FilesystemImpl interface {
//
// - If extended attributes are not supported by the filesystem, SetxattrAt
// returns ENOTSUP.
+ //
+ // - If XATTR_CREATE is set in opts.Flag and opts.Name already exists,
+ // EEXIST is returned. If XATTR_REPLACE is set and opts.Name does not exist,
+ // ENODATA is returned.
SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error
// RemovexattrAt removes the given extended attribute from the file at rp.
@@ -452,8 +486,15 @@ type FilesystemImpl interface {
//
// - If extended attributes are not supported by the filesystem,
// RemovexattrAt returns ENOTSUP.
+ //
+ // - If name does not exist, ENODATA is returned.
RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error
+ // BoundEndpointAt returns the Unix socket endpoint bound at the path rp.
+ //
+ // - If a non-socket file exists at rp, then BoundEndpointAt returns ECONNREFUSED.
+ BoundEndpointAt(ctx context.Context, rp *ResolvingPath) (transport.BoundEndpoint, error)
+
// PrependPath prepends a path from vd to vd.Mount().Root() to b.
//
// If vfsroot.Ok(), it is the contextual VFS root; if it is encountered
@@ -476,7 +517,7 @@ type FilesystemImpl interface {
// Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl.
PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error
- // TODO: inotify_add_watch(); bind()
+ // TODO(gvisor.dev/issue/1479): inotify_add_watch()
}
// PrependPathAtVFSRootError is returned by implementations of
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
index bb9cada81..f2298f7f6 100644
--- a/pkg/sentry/vfs/filesystem_type.go
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -30,6 +30,9 @@ type FilesystemType interface {
// along with its mount root. A reference is taken on the returned
// Filesystem and Dentry.
GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error)
+
+ // Name returns the name of this FilesystemType.
+ Name() string
}
// GetFilesystemOptions contains options to FilesystemType.GetFilesystem.
diff --git a/pkg/sentry/vfs/memxattr/BUILD b/pkg/sentry/vfs/memxattr/BUILD
new file mode 100644
index 000000000..d8c4d27b9
--- /dev/null
+++ b/pkg/sentry/vfs/memxattr/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "memxattr",
+ srcs = ["xattr.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go
new file mode 100644
index 000000000..cc1e7d764
--- /dev/null
+++ b/pkg/sentry/vfs/memxattr/xattr.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memxattr provides a default, in-memory extended attribute
+// implementation.
+package memxattr
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// SimpleExtendedAttributes implements extended attributes using a map of
+// names to values.
+//
+// +stateify savable
+type SimpleExtendedAttributes struct {
+ // mu protects the below fields.
+ mu sync.RWMutex `state:"nosave"`
+ xattrs map[string]string
+}
+
+// Getxattr returns the value at 'name'.
+func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, error) {
+ x.mu.RLock()
+ value, ok := x.xattrs[opts.Name]
+ x.mu.RUnlock()
+ if !ok {
+ return "", syserror.ENODATA
+ }
+ // Check that the size of the buffer provided in getxattr(2) is large enough
+ // to contain the value.
+ if opts.Size != 0 && uint64(len(value)) > opts.Size {
+ return "", syserror.ERANGE
+ }
+ return value, nil
+}
+
+// Setxattr sets 'value' at 'name'.
+func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error {
+ x.mu.Lock()
+ defer x.mu.Unlock()
+ if x.xattrs == nil {
+ if opts.Flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+ x.xattrs = make(map[string]string)
+ }
+
+ _, ok := x.xattrs[opts.Name]
+ if ok && opts.Flags&linux.XATTR_CREATE != 0 {
+ return syserror.EEXIST
+ }
+ if !ok && opts.Flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+
+ x.xattrs[opts.Name] = opts.Value
+ return nil
+}
+
+// Listxattr returns all names in xattrs.
+func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) {
+ // Keep track of the size of the buffer needed in listxattr(2) for the list.
+ listSize := 0
+ x.mu.RLock()
+ names := make([]string, 0, len(x.xattrs))
+ for n := range x.xattrs {
+ names = append(names, n)
+ // Add one byte per null terminator.
+ listSize += len(n) + 1
+ }
+ x.mu.RUnlock()
+ if size != 0 && uint64(listSize) > size {
+ return nil, syserror.ERANGE
+ }
+ return names, nil
+}
+
+// Removexattr removes the xattr at 'name'.
+func (x *SimpleExtendedAttributes) Removexattr(name string) error {
+ x.mu.Lock()
+ defer x.mu.Unlock()
+ if _, ok := x.xattrs[name]; !ok {
+ return syserror.ENODATA
+ }
+ delete(x.xattrs, name)
+ return nil
+}
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 31a4e5480..f06946103 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -15,7 +15,11 @@
package vfs
import (
+ "bytes"
+ "fmt"
"math"
+ "sort"
+ "strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -24,6 +28,9 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// lastMountID is used to allocate mount ids. Must be accessed atomically.
+var lastMountID uint64
+
// A Mount is a replacement of a Dentry (Mount.key.point) from one Filesystem
// (Mount.key.parent.fs) with a Dentry (Mount.root) from another Filesystem
// (Mount.fs), which applies to path resolution in the context of a particular
@@ -41,13 +48,16 @@ import (
//
// +stateify savable
type Mount struct {
- // vfs, fs, and root are immutable. References are held on fs and root.
+ // vfs, fs, root are immutable. References are held on fs and root.
//
// Invariant: root belongs to fs.
vfs *VirtualFilesystem
fs *Filesystem
root *Dentry
+ // ID is the immutable mount ID.
+ ID uint64
+
// key is protected by VirtualFilesystem.mountMu and
// VirtualFilesystem.mounts.seq, and may be nil. References are held on
// key.parent and key.point if they are not nil.
@@ -74,6 +84,10 @@ type Mount struct {
// umounted is true. umounted is protected by VirtualFilesystem.mountMu.
umounted bool
+ // flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except
+ // for MS_RDONLY which is tracked in "writers".
+ flags MountFlags
+
// The lower 63 bits of writers is the number of calls to
// Mount.CheckBeginWrite() that have not yet been paired with a call to
// Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect.
@@ -81,6 +95,22 @@ type Mount struct {
writers int64
}
+func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *MountNamespace, opts *MountOptions) *Mount {
+ mnt := &Mount{
+ ID: atomic.AddUint64(&lastMountID, 1),
+ vfs: vfs,
+ fs: fs,
+ root: root,
+ flags: opts.Flags,
+ ns: mntns,
+ refs: 1,
+ }
+ if opts.ReadOnly {
+ mnt.setReadOnlyLocked(true)
+ }
+ return mnt
+}
+
// A MountNamespace is a collection of Mounts.
//
// MountNamespaces are reference-counted. Unless otherwise specified, all
@@ -129,13 +159,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
refs: 1,
mountpoints: make(map[*Dentry]uint32),
}
- mntns.root = &Mount{
- vfs: vfs,
- fs: fs,
- root: root,
- ns: mntns,
- refs: 1,
- }
+ mntns.root = newMount(vfs, fs, root, mntns, &MountOptions{})
return mntns, nil
}
@@ -148,12 +172,7 @@ func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry,
if root != nil {
root.IncRef()
}
- return &Mount{
- vfs: vfs,
- fs: fs,
- root: root,
- refs: 1,
- }, nil
+ return newMount(vfs, fs, root, nil /* mntns */, opts), nil
}
// MountAt creates and mounts a Filesystem configured by the given arguments.
@@ -214,17 +233,11 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia
}
vd.dentry.mu.Lock()
}
- // TODO: Linux requires that either both the mount point and the mount root
- // are directories, or neither are, and returns ENOTDIR if this is not the
- // case.
+ // TODO(gvisor.dev/issue/1035): Linux requires that either both the mount
+ // point and the mount root are directories, or neither are, and returns
+ // ENOTDIR if this is not the case.
mntns := vd.mount.ns
- mnt := &Mount{
- vfs: vfs,
- fs: fs,
- root: root,
- ns: mntns,
- refs: 1,
- }
+ mnt := newMount(vfs, fs, root, mntns, opts)
vfs.mounts.seq.BeginWrite()
vfs.connectLocked(mnt, vd, mntns)
vfs.mounts.seq.EndWrite()
@@ -261,9 +274,9 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti
}
}
- // TODO(jamieliu): Linux special-cases umount of the caller's root, which
- // we don't implement yet (we'll just fail it since the caller holds a
- // reference on it).
+ // TODO(gvisor.dev/issue/1035): Linux special-cases umount of the caller's
+ // root, which we don't implement yet (we'll just fail it since the caller
+ // holds a reference on it).
vfs.mounts.seq.BeginWrite()
if opts.Flags&linux.MNT_DETACH == 0 {
@@ -630,12 +643,28 @@ func (mnt *Mount) setReadOnlyLocked(ro bool) error {
return nil
}
+func (mnt *Mount) readOnly() bool {
+ return atomic.LoadInt64(&mnt.writers) < 0
+}
+
// Filesystem returns the mounted Filesystem. It does not take a reference on
// the returned Filesystem.
func (mnt *Mount) Filesystem() *Filesystem {
return mnt.fs
}
+// submountsLocked returns this Mount and all Mounts that are descendents of
+// it.
+//
+// Precondition: mnt.vfs.mountMu must be held.
+func (mnt *Mount) submountsLocked() []*Mount {
+ mounts := []*Mount{mnt}
+ for m := range mnt.children {
+ mounts = append(mounts, m.submountsLocked()...)
+ }
+ return mounts
+}
+
// Root returns mntns' root. A reference is taken on the returned
// VirtualDentry.
func (mntns *MountNamespace) Root() VirtualDentry {
@@ -646,3 +675,174 @@ func (mntns *MountNamespace) Root() VirtualDentry {
vd.IncRef()
return vd
}
+
+// GenerateProcMounts emits the contents of /proc/[pid]/mounts for vfs to buf.
+//
+// Preconditions: taskRootDir.Ok().
+func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) {
+ vfs.mountMu.Lock()
+ defer vfs.mountMu.Unlock()
+ rootMnt := taskRootDir.mount
+ mounts := rootMnt.submountsLocked()
+ sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID })
+ for _, mnt := range mounts {
+ // Get the path to this mount relative to task root.
+ mntRootVD := VirtualDentry{
+ mount: mnt,
+ dentry: mnt.root,
+ }
+ path, err := vfs.PathnameReachable(ctx, taskRootDir, mntRootVD)
+ if err != nil {
+ // For some reason we didn't get a path. Log a warning
+ // and run with empty path.
+ ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err)
+ path = ""
+ }
+ if path == "" {
+ // Either an error occurred, or path is not reachable
+ // from root.
+ break
+ }
+
+ opts := "rw"
+ if mnt.readOnly() {
+ opts = "ro"
+ }
+ if mnt.flags.NoExec {
+ opts += ",noexec"
+ }
+
+ // Format:
+ // <special device or remote filesystem> <mount point> <filesystem type> <mount options> <needs dump> <fsck order>
+ //
+ // The "needs dump" and "fsck order" flags are always 0, which
+ // is allowed.
+ fmt.Fprintf(buf, "%s %s %s %s %d %d\n", "none", path, mnt.fs.FilesystemType().Name(), opts, 0, 0)
+ }
+}
+
+// GenerateProcMountInfo emits the contents of /proc/[pid]/mountinfo for vfs to
+// buf.
+//
+// Preconditions: taskRootDir.Ok().
+func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) {
+ vfs.mountMu.Lock()
+ defer vfs.mountMu.Unlock()
+ rootMnt := taskRootDir.mount
+ mounts := rootMnt.submountsLocked()
+ sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID })
+ for _, mnt := range mounts {
+ // Get the path to this mount relative to task root.
+ mntRootVD := VirtualDentry{
+ mount: mnt,
+ dentry: mnt.root,
+ }
+ path, err := vfs.PathnameReachable(ctx, taskRootDir, mntRootVD)
+ if err != nil {
+ // For some reason we didn't get a path. Log a warning
+ // and run with empty path.
+ ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err)
+ path = ""
+ }
+ if path == "" {
+ // Either an error occurred, or path is not reachable
+ // from root.
+ break
+ }
+ // Stat the mount root to get the major/minor device numbers.
+ pop := &PathOperation{
+ Root: mntRootVD,
+ Start: mntRootVD,
+ }
+ statx, err := vfs.StatAt(ctx, auth.NewAnonymousCredentials(), pop, &StatOptions{})
+ if err != nil {
+ // Well that's not good. Ignore this mount.
+ break
+ }
+
+ // Format:
+ // 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
+ // (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
+
+ // (1) Mount ID.
+ fmt.Fprintf(buf, "%d ", mnt.ID)
+
+ // (2) Parent ID (or this ID if there is no parent).
+ pID := mnt.ID
+ if p := mnt.parent(); p != nil {
+ pID = p.ID
+ }
+ fmt.Fprintf(buf, "%d ", pID)
+
+ // (3) Major:Minor device ID. We don't have a superblock, so we
+ // just use the root inode device number.
+ fmt.Fprintf(buf, "%d:%d ", statx.DevMajor, statx.DevMinor)
+
+ // (4) Root: the pathname of the directory in the filesystem
+ // which forms the root of this mount.
+ //
+ // NOTE(b/78135857): This will always be "/" until we implement
+ // bind mounts.
+ fmt.Fprintf(buf, "/ ")
+
+ // (5) Mount point (relative to process root).
+ fmt.Fprintf(buf, "%s ", manglePath(path))
+
+ // (6) Mount options.
+ opts := "rw"
+ if mnt.readOnly() {
+ opts = "ro"
+ }
+ if mnt.flags.NoExec {
+ opts += ",noexec"
+ }
+ // TODO(gvisor.dev/issue/1193): Add "noatime" if MS_NOATIME is
+ // set.
+ fmt.Fprintf(buf, "%s ", opts)
+
+ // (7) Optional fields: zero or more fields of the form "tag[:value]".
+ // (8) Separator: the end of the optional fields is marked by a single hyphen.
+ fmt.Fprintf(buf, "- ")
+
+ // (9) Filesystem type.
+ fmt.Fprintf(buf, "%s ", mnt.fs.FilesystemType().Name())
+
+ // (10) Mount source: filesystem-specific information or "none".
+ fmt.Fprintf(buf, "none ")
+
+ // (11) Superblock options, and final newline.
+ fmt.Fprintf(buf, "%s\n", superBlockOpts(path, mnt))
+ }
+}
+
+// manglePath replaces ' ', '\t', '\n', and '\\' with their octal equivalents.
+// See Linux fs/seq_file.c:mangle_path.
+func manglePath(p string) string {
+ r := strings.NewReplacer(" ", "\\040", "\t", "\\011", "\n", "\\012", "\\", "\\134")
+ return r.Replace(p)
+}
+
+// superBlockOpts returns the super block options string for the the mount at
+// the given path.
+func superBlockOpts(mountPath string, mnt *Mount) string {
+ // gVisor doesn't (yet) have a concept of super block options, so we
+ // use the ro/rw bit from the mount flag.
+ opts := "rw"
+ if mnt.readOnly() {
+ opts = "ro"
+ }
+
+ // NOTE(b/147673608): If the mount is a cgroup, we also need to include
+ // the cgroup name in the options. For now we just read that from the
+ // path.
+ //
+ // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
+ // should get this value from the cgroup itself, and not rely on the
+ // path.
+ if mnt.fs.FilesystemType().Name() == "cgroup" {
+ splitPath := strings.Split(mountPath, "/")
+ cgroupType := splitPath[len(splitPath)-1]
+ opts += "," + cgroupType
+ }
+ return opts
+}
diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go
index 3b933468d..3335e4057 100644
--- a/pkg/sentry/vfs/mount_test.go
+++ b/pkg/sentry/vfs/mount_test.go
@@ -55,7 +55,7 @@ func TestMountTableInsertLookup(t *testing.T) {
}
}
-// TODO: concurrent lookup/insertion/removal
+// TODO(gvisor.dev/issue/1035): concurrent lookup/insertion/removal.
// must be powers of 2
var benchNumMounts = []int{1 << 2, 1 << 5, 1 << 8}
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 6af7fdac1..534528ce6 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -16,6 +16,7 @@ package vfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
)
// GetDentryOptions contains options to VirtualFilesystem.GetDentryAt() and
@@ -44,10 +45,27 @@ type MknodOptions struct {
// DevMinor are the major and minor device numbers for the created device.
DevMajor uint32
DevMinor uint32
+
+ // Endpoint is the endpoint to bind to the created file, if a socket file is
+ // being created for bind(2) on a Unix domain socket.
+ Endpoint transport.BoundEndpoint
+}
+
+// MountFlags contains flags as specified for mount(2), e.g. MS_NOEXEC.
+// MS_RDONLY is not part of MountFlags because it's tracked in Mount.writers.
+type MountFlags struct {
+ // NoExec is equivalent to MS_NOEXEC.
+ NoExec bool
}
// MountOptions contains options to VirtualFilesystem.MountAt().
type MountOptions struct {
+ // Flags contains flags as specified for mount(2), e.g. MS_NOEXEC.
+ Flags MountFlags
+
+ // ReadOnly is equivalent to MS_RDONLY.
+ ReadOnly bool
+
// GetFilesystemOptions contains options to FilesystemType.GetFilesystem().
GetFilesystemOptions GetFilesystemOptions
@@ -75,7 +93,8 @@ type OpenOptions struct {
// FileExec is set when the file is being opened to be executed.
// VirtualFilesystem.OpenAt() checks that the caller has execute permissions
- // on the file, and that the file is a regular file.
+ // on the file, that the file is a regular file, and that the mount doesn't
+ // have MS_NOEXEC set.
FileExec bool
}
@@ -113,6 +132,20 @@ type SetStatOptions struct {
Stat linux.Statx
}
+// GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(),
+// FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and
+// FileDescriptionImpl.Getxattr().
+type GetxattrOptions struct {
+ // Name is the name of the extended attribute to retrieve.
+ Name string
+
+ // Size is the maximum value size that the caller will tolerate. If the value
+ // is larger than size, getxattr methods may return ERANGE, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ Size uint64
+}
+
// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(),
// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and
// FileDescriptionImpl.Setxattr().
diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go
index b318c681a..f21a88034 100644
--- a/pkg/sentry/vfs/pathname.go
+++ b/pkg/sentry/vfs/pathname.go
@@ -90,6 +90,49 @@ loop:
return b.String(), nil
}
+// PathnameReachable returns an absolute pathname to vd, consistent with
+// Linux's __d_path() (as used by seq_path_root()). If vfsroot.Ok() and vd is
+// not reachable from vfsroot, such that seq_path_root() would return SEQ_SKIP
+// (causing the entire containing entry to be skipped), PathnameReachable
+// returns ("", nil).
+func (vfs *VirtualFilesystem) PathnameReachable(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
+ b := getFSPathBuilder()
+ defer putFSPathBuilder(b)
+ haveRef := false
+ defer func() {
+ if haveRef {
+ vd.DecRef()
+ }
+ }()
+loop:
+ for {
+ err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b)
+ switch err.(type) {
+ case nil:
+ if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
+ break loop
+ }
+ nextVD := vfs.getMountpointAt(vd.mount, vfsroot)
+ if !nextVD.Ok() {
+ return "", nil
+ }
+ if haveRef {
+ vd.DecRef()
+ }
+ vd = nextVD
+ haveRef = true
+ case PrependPathAtVFSRootError:
+ break loop
+ case PrependPathAtNonMountRootError, PrependPathSyntheticError:
+ return "", nil
+ default:
+ return "", err
+ }
+ }
+ b.PrependByte('/')
+ return b.String(), nil
+}
+
// PathnameForGetcwd returns an absolute pathname to vd, consistent with
// Linux's sys_getcwd().
func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index 8e250998a..f9647f90e 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -15,8 +15,12 @@
package vfs
import (
+ "math"
+
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -25,9 +29,9 @@ type AccessTypes uint16
// Bits in AccessTypes.
const (
+ MayExec AccessTypes = 1
+ MayWrite AccessTypes = 2
MayRead AccessTypes = 4
- MayWrite = 2
- MayExec = 1
)
// OnlyRead returns true if access _only_ allows read.
@@ -52,16 +56,17 @@ func (a AccessTypes) MayExec() bool {
// GenericCheckPermissions checks that creds has the given access rights on a
// file with the given permissions, UID, and GID, subject to the rules of
-// fs/namei.c:generic_permission(). isDir is true if the file is a directory.
-func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir bool, mode uint16, kuid auth.KUID, kgid auth.KGID) error {
+// fs/namei.c:generic_permission().
+func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
// Check permission bits.
- perms := mode
+ perms := uint16(mode.Permissions())
if creds.EffectiveKUID == kuid {
perms >>= 6
} else if creds.InGroup(kgid) {
perms >>= 3
}
if uint16(ats)&perms == uint16(ats) {
+ // All permission bits match, access granted.
return nil
}
@@ -73,7 +78,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
}
// CAP_DAC_READ_SEARCH allows the caller to read and search arbitrary
// directories, and read arbitrary non-directory files.
- if (isDir && !ats.MayWrite()) || ats.OnlyRead() {
+ if (mode.IsDir() && !ats.MayWrite()) || ats.OnlyRead() {
if creds.HasCapability(linux.CAP_DAC_READ_SEARCH) {
return nil
}
@@ -81,7 +86,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
// CAP_DAC_OVERRIDE allows arbitrary access to directories, read/write
// access to non-directory files, and execute access to non-directory files
// for which at least one execute bit is set.
- if isDir || !ats.MayExec() || (mode&0111 != 0) {
+ if mode.IsDir() || !ats.MayExec() || (mode.Permissions()&0111 != 0) {
if creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
return nil
}
@@ -147,7 +152,16 @@ func MayWriteFileWithOpenFlags(flags uint32) bool {
// CheckSetStat checks that creds has permission to change the metadata of a
// file with the given permissions, UID, and GID as specified by stat, subject
// to the rules of Linux's fs/attr.c:setattr_prepare().
-func CheckSetStat(creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error {
+func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ limit, err := CheckLimit(ctx, 0, int64(stat.Size))
+ if err != nil {
+ return err
+ }
+ if limit < int64(stat.Size) {
+ return syserror.ErrExceedsFileSizeLimit
+ }
+ }
if stat.Mask&linux.STATX_MODE != 0 {
if !CanActAsOwner(creds, kuid) {
return syserror.EPERM
@@ -177,11 +191,7 @@ func CheckSetStat(creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid
(stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) {
return syserror.EPERM
}
- // isDir is irrelevant in the following call to
- // GenericCheckPermissions since ats == MayWrite means that
- // CAP_DAC_READ_SEARCH does not apply, and CAP_DAC_OVERRIDE
- // applies, regardless of isDir.
- if err := GenericCheckPermissions(creds, MayWrite, false /* isDir */, mode, kuid, kgid); err != nil {
+ if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil {
return err
}
}
@@ -205,3 +215,21 @@ func CanActAsOwner(creds *auth.Credentials, kuid auth.KUID) bool {
func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth.KUID, kgid auth.KGID) bool {
return creds.HasCapability(cp) && creds.UserNamespace.MapFromKUID(kuid).Ok() && creds.UserNamespace.MapFromKGID(kgid).Ok()
}
+
+// CheckLimit enforces file size rlimits. It returns error if the write
+// operation must not proceed. Otherwise it returns the max length allowed to
+// without violating the limit.
+func CheckLimit(ctx context.Context, offset, size int64) (int64, error) {
+ fileSizeLimit := limits.FromContext(ctx).Get(limits.FileSize).Cur
+ if fileSizeLimit > math.MaxInt64 {
+ return size, nil
+ }
+ if offset >= int64(fileSizeLimit) {
+ return 0, syserror.ErrExceedsFileSizeLimit
+ }
+ remaining := int64(fileSizeLimit) - offset
+ if remaining < size {
+ return remaining, nil
+ }
+ return size, nil
+}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index eb4ebb511..8f31495da 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -329,10 +329,22 @@ func (rp *ResolvingPath) ResolveComponent(d *Dentry) (*Dentry, error) {
// component in pcs represents a symbolic link, the symbolic link should be
// followed.
//
+// If path is terminated with '/', the '/' is considered the last element and
+// any symlink before that is followed:
+// - For most non-creating walks, the last path component is handled by
+// fs/namei.c:lookup_last(), which sets LOOKUP_FOLLOW if the first byte
+// after the path component is non-NULL (which is only possible if it's '/')
+// and the path component is of type LAST_NORM.
+//
+// - For open/openat/openat2 without O_CREAT, the last path component is
+// handled by fs/namei.c:do_last(), which does the same, though without the
+// LAST_NORM check.
+//
// Preconditions: !rp.Done().
func (rp *ResolvingPath) ShouldFollowSymlink() bool {
- // Non-final symlinks are always followed.
- return rp.flags&rpflagsFollowFinalSymlink != 0 || !rp.Final()
+ // Non-final symlinks are always followed. Paths terminated with '/' are also
+ // always followed.
+ return rp.flags&rpflagsFollowFinalSymlink != 0 || !rp.Final() || rp.MustBeDir()
}
// HandleSymlink is called when the current path component is a symbolic link
diff --git a/pkg/sentry/vfs/timerfd.go b/pkg/sentry/vfs/timerfd.go
new file mode 100644
index 000000000..42b880656
--- /dev/null
+++ b/pkg/sentry/vfs/timerfd.go
@@ -0,0 +1,142 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TimerFileDescription implements FileDescriptionImpl for timer fds. It also
+// implements ktime.TimerListener.
+type TimerFileDescription struct {
+ vfsfd FileDescription
+ FileDescriptionDefaultImpl
+ DentryMetadataFileDescriptionImpl
+
+ events waiter.Queue
+ timer *ktime.Timer
+
+ // val is the number of timer expirations since the last successful
+ // call to PRead, or SetTime. val must be accessed using atomic memory
+ // operations.
+ val uint64
+}
+
+var _ FileDescriptionImpl = (*TimerFileDescription)(nil)
+var _ ktime.TimerListener = (*TimerFileDescription)(nil)
+
+// NewTimerFD returns a new timer fd.
+func (vfs *VirtualFilesystem) NewTimerFD(clock ktime.Clock, flags uint32) (*FileDescription, error) {
+ vd := vfs.NewAnonVirtualDentry("[timerfd]")
+ defer vd.DecRef()
+ tfd := &TimerFileDescription{}
+ tfd.timer = ktime.NewTimer(clock, tfd)
+ if err := tfd.vfsfd.Init(tfd, flags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{
+ UseDentryMetadata: true,
+ DenyPRead: true,
+ DenyPWrite: true,
+ InvalidWrite: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &tfd.vfsfd, nil
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ const sizeofUint64 = 8
+ if dst.NumBytes() < sizeofUint64 {
+ return 0, syserror.EINVAL
+ }
+ if val := atomic.SwapUint64(&tfd.val, 0); val != 0 {
+ var buf [sizeofUint64]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ if _, err := dst.CopyOut(ctx, buf[:]); err != nil {
+ // Linux does not undo consuming the number of
+ // expirations even if writing to userspace fails.
+ return 0, err
+ }
+ return sizeofUint64, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+// Clock returns the timer fd's Clock.
+func (tfd *TimerFileDescription) Clock() ktime.Clock {
+ return tfd.timer.Clock()
+}
+
+// GetTime returns the associated Timer's setting and the time at which it was
+// observed.
+func (tfd *TimerFileDescription) GetTime() (ktime.Time, ktime.Setting) {
+ return tfd.timer.Get()
+}
+
+// SetTime atomically changes the associated Timer's setting, resets the number
+// of expirations to 0, and returns the previous setting and the time at which
+// it was observed.
+func (tfd *TimerFileDescription) SetTime(s ktime.Setting) (ktime.Time, ktime.Setting) {
+ return tfd.timer.SwapAnd(s, func() { atomic.StoreUint64(&tfd.val, 0) })
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (tfd *TimerFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ var ready waiter.EventMask
+ if atomic.LoadUint64(&tfd.val) != 0 {
+ ready |= waiter.EventIn
+ }
+ return ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (tfd *TimerFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ tfd.events.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (tfd *TimerFileDescription) EventUnregister(e *waiter.Entry) {
+ tfd.events.EventUnregister(e)
+}
+
+// PauseTimer pauses the associated Timer.
+func (tfd *TimerFileDescription) PauseTimer() {
+ tfd.timer.Pause()
+}
+
+// ResumeTimer resumes the associated Timer.
+func (tfd *TimerFileDescription) ResumeTimer() {
+ tfd.timer.Resume()
+}
+
+// Release implements FileDescriptionImpl.Release()
+func (tfd *TimerFileDescription) Release() {
+ tfd.timer.Destroy()
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (tfd *TimerFileDescription) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
+ atomic.AddUint64(&tfd.val, exp)
+ tfd.events.Notify(waiter.EventIn)
+ return ktime.Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (tfd *TimerFileDescription) Destroy() {}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index bde81e1ef..cb5bbd781 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -38,6 +38,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -133,7 +134,7 @@ func (vfs *VirtualFilesystem) Init() error {
anonfs := anonFilesystem{
devMinor: anonfsDevMinor,
}
- anonfs.vfsfs.Init(vfs, &anonfs)
+ anonfs.vfsfs.Init(vfs, &anonFilesystemType{}, &anonfs)
defer anonfs.vfsfs.DecRef()
anonMount, err := vfs.NewDisconnectedMount(&anonfs.vfsfs, nil, &MountOptions{})
if err != nil {
@@ -174,6 +175,23 @@ type PathOperation struct {
FollowFinalSymlink bool
}
+// AccessAt checks whether a user with creds has access to the file at
+// the given path.
+func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credentials, ats AccessTypes, pop *PathOperation) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
// GetDentryAt returns a VirtualDentry representing the given path, at which a
// file must exist. A reference is taken on the returned VirtualDentry.
func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) {
@@ -213,7 +231,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.GetParentDentryAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.GetParentDentryAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -254,7 +272,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.LinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.LinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -290,7 +308,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.MkdirAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.MkdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -317,13 +335,13 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
rp := vfs.getResolvingPath(creds, pop)
for {
err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts)
- if err != nil {
+ if err == nil {
vfs.putResolvingPath(rp)
return nil
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.MknodAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.MknodAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -333,19 +351,43 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
}
}
+// BoundEndpointAt gets the bound endpoint at the given path, if one exists.
+func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (transport.BoundEndpoint, error) {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return nil, syserror.ECONNREFUSED
+ }
+ return nil, syserror.ENOENT
+ }
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return bep, nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.BoundEndpointAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return nil, err
+ }
+ }
+}
+
// OpenAt returns a FileDescription providing access to the file at the given
// path. A reference is taken on the returned FileDescription.
func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
// Remove:
//
- // - O_LARGEFILE, which we always report in FileDescription status flags
- // since only 64-bit architectures are supported at this time.
- //
// - O_CLOEXEC, which affects file descriptors and therefore must be
// handled outside of VFS.
//
// - Unknown flags.
- opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_LARGEFILE | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE
// Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC.
if opts.Flags&linux.O_SYNC != 0 {
opts.Flags |= linux.O_DSYNC
@@ -385,9 +427,12 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
if err == nil {
vfs.putResolvingPath(rp)
- // TODO(gvisor.dev/issue/1193): Move inside fsimpl to avoid another call
- // to FileDescription.Stat().
if opts.FileExec {
+ if fd.Mount().flags.NoExec {
+ fd.DecRef()
+ return nil, syserror.EACCES
+ }
+
// Only a regular file can be executed.
stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE})
if err != nil {
@@ -474,7 +519,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.RenameAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.RenameAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -507,7 +552,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.RmdirAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.RmdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -588,7 +633,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.SymlinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.SymlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -620,7 +665,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
}
if checkInvariants {
if rp.canHandleError(err) && rp.Done() {
- panic(fmt.Sprintf("%T.UnlinkAt() consumed all path components and returned %T", rp.mount.fs.impl, err))
+ panic(fmt.Sprintf("%T.UnlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
}
}
if !rp.handleError(err) {
@@ -632,10 +677,10 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
// ListxattrAt returns all extended attribute names for the file at the given
// path.
-func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) ([]string, error) {
+func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) {
rp := vfs.getResolvingPath(creds, pop)
for {
- names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp)
+ names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size)
if err == nil {
vfs.putResolvingPath(rp)
return names, nil
@@ -657,10 +702,10 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede
// GetxattrAt returns the value associated with the given extended attribute
// for the file at the given path.
-func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) (string, error) {
+func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetxattrOptions) (string, error) {
rp := vfs.getResolvingPath(creds, pop)
for {
- val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, name)
+ val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
if err == nil {
vfs.putResolvingPath(rp)
return val, nil
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index bfb2fac26..f7d6009a0 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -221,7 +221,7 @@ func (w *Watchdog) waitForStart() {
return
}
var buf bytes.Buffer
- buf.WriteString("Watchdog.Start() not called within %s:\n")
+ buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout))
w.doAction(w.StartupTimeoutAction, false, &buf)
}
@@ -325,7 +325,7 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
func (w *Watchdog) reportStuckWatchdog() {
var buf bytes.Buffer
- buf.WriteString("Watchdog goroutine is stuck:\n")
+ buf.WriteString("Watchdog goroutine is stuck:")
w.doAction(w.TaskTimeoutAction, false, &buf)
}
@@ -359,7 +359,7 @@ func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) {
case <-metricsEmitted:
case <-time.After(1 * time.Second):
}
- panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String()))
+ panic(fmt.Sprintf("%s\nStack for running G's are skipped while panicking.", msg.String()))
default:
panic(fmt.Sprintf("Unknown watchdog action %v", action))
diff --git a/pkg/state/state.go b/pkg/state/state.go
index dbe507ab4..03ae2dbb0 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -241,10 +241,7 @@ func Register(name string, instance interface{}, fns Fns) {
//
// This function is used by the stateify tool.
func IsZeroValue(val interface{}) bool {
- if val == nil {
- return true
- }
- return reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface())
+ return val == nil || reflect.ValueOf(val).Elem().IsZero()
}
// step captures one encoding / decoding step. On each step, there is up to one
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 5340cf0d6..0e35d7d17 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -31,13 +31,13 @@ go_library(
name = "sync",
srcs = [
"aliases.go",
- "downgradable_rwmutex_unsafe.go",
"memmove_unsafe.go",
+ "mutex_unsafe.go",
"norace_unsafe.go",
"race_unsafe.go",
+ "rwmutex_unsafe.go",
"seqcount.go",
- "syncutil.go",
- "tmutex_unsafe.go",
+ "sync.go",
],
)
@@ -45,9 +45,9 @@ go_test(
name = "sync_test",
size = "small",
srcs = [
- "downgradable_rwmutex_test.go",
+ "mutex_test.go",
+ "rwmutex_test.go",
"seqcount_test.go",
- "tmutex_test.go",
],
library = ":sync",
)
diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go
index d2d7132fa..0d4316254 100644
--- a/pkg/sync/aliases.go
+++ b/pkg/sync/aliases.go
@@ -29,3 +29,8 @@ type (
// Map is an alias of sync.Map.
Map = sync.Map
)
+
+// NewCond is a wrapper around sync.NewCond.
+func NewCond(l Locker) *Cond {
+ return sync.NewCond(l)
+}
diff --git a/pkg/sync/tmutex_test.go b/pkg/sync/mutex_test.go
index 0838248b4..0838248b4 100644
--- a/pkg/sync/tmutex_test.go
+++ b/pkg/sync/mutex_test.go
diff --git a/pkg/sync/tmutex_unsafe.go b/pkg/sync/mutex_unsafe.go
index 3dd15578b..3dd15578b 100644
--- a/pkg/sync/tmutex_unsafe.go
+++ b/pkg/sync/mutex_unsafe.go
diff --git a/pkg/sync/downgradable_rwmutex_test.go b/pkg/sync/rwmutex_test.go
index ce667e825..ce667e825 100644
--- a/pkg/sync/downgradable_rwmutex_test.go
+++ b/pkg/sync/rwmutex_test.go
diff --git a/pkg/sync/downgradable_rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go
index ea6cdc447..ea6cdc447 100644
--- a/pkg/sync/downgradable_rwmutex_unsafe.go
+++ b/pkg/sync/rwmutex_unsafe.go
diff --git a/pkg/sync/syncutil.go b/pkg/sync/sync.go
index b16cf5333..b16cf5333 100644
--- a/pkg/sync/syncutil.go
+++ b/pkg/sync/sync.go
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 4b5a0fca6..f86db0999 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -27,6 +27,7 @@ import (
var (
E2BIG = error(syscall.E2BIG)
EACCES = error(syscall.EACCES)
+ EADDRINUSE = error(syscall.EADDRINUSE)
EAGAIN = error(syscall.EAGAIN)
EBADF = error(syscall.EBADF)
EBADFD = error(syscall.EBADFD)
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 26f7ba86b..454e07662 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -5,8 +5,6 @@ package(licenses = ["notice"])
go_library(
name = "tcpip",
srcs = [
- "packet_buffer.go",
- "packet_buffer_state.go",
"tcpip.go",
"time_unsafe.go",
"timer.go",
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index a984f1712..e57d45f2a 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -22,6 +22,7 @@ go_test(
size = "small",
srcs = ["gonet_test.go"],
library = ":gonet",
+ tags = ["flaky"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 17e94c562..8ec5d5d5c 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -15,6 +15,11 @@
// Package buffer provides the implementation of a buffer view.
package buffer
+import (
+ "bytes"
+ "io"
+)
+
// View is a slice of a buffer, with convenience methods.
type View []byte
@@ -45,6 +50,13 @@ func (v *View) CapLength(length int) {
*v = (*v)[:length:length]
}
+// Reader returns a bytes.Reader for v.
+func (v *View) Reader() bytes.Reader {
+ var r bytes.Reader
+ r.Reset(*v)
+ return r
+}
+
// ToVectorisedView returns a VectorisedView containing the receiver.
func (v View) ToVectorisedView() VectorisedView {
return NewVectorisedView(len(v), []View{v})
@@ -78,6 +90,47 @@ func (vv *VectorisedView) TrimFront(count int) {
}
}
+// Read implements io.Reader.
+func (vv *VectorisedView) Read(v View) (copied int, err error) {
+ count := len(v)
+ for count > 0 && len(vv.views) > 0 {
+ if count < len(vv.views[0]) {
+ vv.size -= count
+ copy(v[copied:], vv.views[0][:count])
+ vv.views[0].TrimFront(count)
+ copied += count
+ return copied, nil
+ }
+ count -= len(vv.views[0])
+ copy(v[copied:], vv.views[0])
+ copied += len(vv.views[0])
+ vv.RemoveFirst()
+ }
+ if copied == 0 {
+ return 0, io.EOF
+ }
+ return copied, nil
+}
+
+// ReadToVV reads up to n bytes from vv to dstVV and removes them from vv. It
+// returns the number of bytes copied.
+func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int) {
+ for count > 0 && len(vv.views) > 0 {
+ if count < len(vv.views[0]) {
+ vv.size -= count
+ dstVV.AppendView(vv.views[0][:count])
+ vv.views[0].TrimFront(count)
+ copied += count
+ return
+ }
+ count -= len(vv.views[0])
+ dstVV.AppendView(vv.views[0])
+ copied += len(vv.views[0])
+ vv.RemoveFirst()
+ }
+ return copied
+}
+
// CapLength irreversibly reduces the length of the vectorised view.
func (vv *VectorisedView) CapLength(length int) {
if length < 0 {
@@ -105,12 +158,12 @@ func (vv *VectorisedView) CapLength(length int) {
// Clone returns a clone of this VectorisedView.
// If the buffer argument is large enough to contain all the Views of this VectorisedView,
// the method will avoid allocations and use the buffer to store the Views of the clone.
-func (vv VectorisedView) Clone(buffer []View) VectorisedView {
+func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
// First returns the first view of the vectorised view.
-func (vv VectorisedView) First() View {
+func (vv *VectorisedView) First() View {
if len(vv.views) == 0 {
return nil
}
@@ -123,11 +176,12 @@ func (vv *VectorisedView) RemoveFirst() {
return
}
vv.size -= len(vv.views[0])
+ vv.views[0] = nil
vv.views = vv.views[1:]
}
// Size returns the size in bytes of the entire content stored in the vectorised view.
-func (vv VectorisedView) Size() int {
+func (vv *VectorisedView) Size() int {
return vv.size
}
@@ -135,7 +189,7 @@ func (vv VectorisedView) Size() int {
//
// If the vectorised view contains a single view, that view will be returned
// directly.
-func (vv VectorisedView) ToView() View {
+func (vv *VectorisedView) ToView() View {
if len(vv.views) == 1 {
return vv.views[0]
}
@@ -147,7 +201,7 @@ func (vv VectorisedView) ToView() View {
}
// Views returns the slice containing the all views.
-func (vv VectorisedView) Views() []View {
+func (vv *VectorisedView) Views() []View {
return vv.views
}
@@ -162,3 +216,12 @@ func (vv *VectorisedView) AppendView(v View) {
vv.views = append(vv.views, v)
vv.size += len(v)
}
+
+// Readers returns a bytes.Reader for each of vv's views.
+func (vv *VectorisedView) Readers() []bytes.Reader {
+ readers := make([]bytes.Reader, 0, len(vv.views))
+ for _, v := range vv.views {
+ readers = append(readers, v.Reader())
+ }
+ return readers
+}
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
index ebc3a17b7..106e1994c 100644
--- a/pkg/tcpip/buffer/view_test.go
+++ b/pkg/tcpip/buffer/view_test.go
@@ -233,3 +233,140 @@ func TestToClone(t *testing.T) {
})
}
}
+
+func TestVVReadToVV(t *testing.T) {
+ testCases := []struct {
+ comment string
+ vv VectorisedView
+ bytesToRead int
+ wantBytes string
+ leftVV VectorisedView
+ }{
+ {
+ comment: "large VV, short read",
+ vv: vv(30, "012345678901234567890123456789"),
+ bytesToRead: 10,
+ wantBytes: "0123456789",
+ leftVV: vv(20, "01234567890123456789"),
+ },
+ {
+ comment: "largeVV, multiple views, short read",
+ vv: vv(13, "123", "345", "567", "8910"),
+ bytesToRead: 6,
+ wantBytes: "123345",
+ leftVV: vv(7, "567", "8910"),
+ },
+ {
+ comment: "smallVV (multiple views), large read",
+ vv: vv(3, "1", "2", "3"),
+ bytesToRead: 10,
+ wantBytes: "123",
+ leftVV: vv(0, ""),
+ },
+ {
+ comment: "smallVV (single view), large read",
+ vv: vv(1, "1"),
+ bytesToRead: 10,
+ wantBytes: "1",
+ leftVV: vv(0, ""),
+ },
+ {
+ comment: "emptyVV, large read",
+ vv: vv(0, ""),
+ bytesToRead: 10,
+ wantBytes: "",
+ leftVV: vv(0, ""),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.comment, func(t *testing.T) {
+ var readTo VectorisedView
+ inSize := tc.vv.Size()
+ copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead)
+ if got, want := copied, len(tc.wantBytes); got != want {
+ t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc)
+ }
+ if got, want := string(readTo.ToView()), tc.wantBytes; got != want {
+ t.Errorf("unexpected content in readTo got: %s, want: %s", got, want)
+ }
+ if got, want := tc.vv.Size(), inSize-copied; got != want {
+ t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want {
+ t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want)
+ }
+ })
+ }
+}
+
+func TestVVRead(t *testing.T) {
+ testCases := []struct {
+ comment string
+ vv VectorisedView
+ bytesToRead int
+ readBytes string
+ leftBytes string
+ wantError bool
+ }{
+ {
+ comment: "large VV, short read",
+ vv: vv(30, "012345678901234567890123456789"),
+ bytesToRead: 10,
+ readBytes: "0123456789",
+ leftBytes: "01234567890123456789",
+ },
+ {
+ comment: "largeVV, multiple buffers, short read",
+ vv: vv(13, "123", "345", "567", "8910"),
+ bytesToRead: 6,
+ readBytes: "123345",
+ leftBytes: "5678910",
+ },
+ {
+ comment: "smallVV, large read",
+ vv: vv(3, "1", "2", "3"),
+ bytesToRead: 10,
+ readBytes: "123",
+ leftBytes: "",
+ },
+ {
+ comment: "smallVV, large read",
+ vv: vv(1, "1"),
+ bytesToRead: 10,
+ readBytes: "1",
+ leftBytes: "",
+ },
+ {
+ comment: "emptyVV, large read",
+ vv: vv(0, ""),
+ bytesToRead: 10,
+ readBytes: "",
+ wantError: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.comment, func(t *testing.T) {
+ readTo := NewView(tc.bytesToRead)
+ inSize := tc.vv.Size()
+ copied, err := tc.vv.Read(readTo)
+ if !tc.wantError && err != nil {
+ t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err)
+ }
+ readTo = readTo[:copied]
+ if got, want := copied, len(tc.readBytes); got != want {
+ t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(readTo), tc.readBytes; got != want {
+ t.Errorf("unexpected data in readTo got: %s, want: %s", got, want)
+ }
+ if got, want := tc.vv.Size(), inSize-copied; got != want {
+ t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want {
+ t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index c6c160dfc..c1745ba6a 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -107,6 +107,8 @@ func DstAddr(addr tcpip.Address) NetworkChecker {
// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
func TTL(ttl uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
var v uint8
switch ip := h[0].(type) {
case header.IPv4:
@@ -310,6 +312,8 @@ func SrcPort(port uint16) TransportChecker {
// DstPort creates a checker that checks the destination port.
func DstPort(port uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
if p := h.DestinationPort(); p != port {
t.Errorf("Bad destination port, got %v, want %v", p, port)
}
@@ -336,6 +340,7 @@ func SeqNum(seq uint32) TransportChecker {
func AckNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -350,6 +355,8 @@ func AckNum(seq uint32) TransportChecker {
// Window creates a checker that checks the tcp window.
func Window(window uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -381,6 +388,8 @@ func TCPFlags(flags uint8) TransportChecker {
// given mask, match the supplied flags.
func TCPFlagsMatch(flags, mask uint8) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -398,6 +407,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker {
// If wndscale is negative, the window scale option must not be present.
func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -494,6 +505,8 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
// skipped.
func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -612,6 +625,8 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
// Payload creates a checker that checks the payload.
func Payload(want []byte) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
if got := h.Payload(); !reflect.DeepEqual(got, want) {
t.Errorf("Wrong payload, got %v, want %v", got, want)
}
@@ -644,6 +659,7 @@ func ICMPv4(checkers ...TransportChecker) NetworkChecker {
func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv4, ok := h.(header.ICMPv4)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
@@ -658,6 +674,7 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
func ICMPv4Code(want byte) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv4, ok := h.(header.ICMPv4)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
@@ -700,6 +717,7 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker {
func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv6, ok := h.(header.ICMPv6)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
@@ -714,6 +732,7 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
func ICMPv6Code(want byte) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv6, ok := h.(header.ICMPv6)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
@@ -728,7 +747,7 @@ func ICMPv6Code(want byte) TransportChecker {
// message for type of ty, with potentially additional checks specified by
// checkers.
//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
// NDP message as far as the size of the message (minSize) is concerned. The
// values within the message are up to checkers to validate.
func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
@@ -760,9 +779,9 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N
// Neighbor Solicitation message (as per the raw wire format), with potentially
// additional checks specified by checkers.
//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPNS message as far as the size of the messages concerned. The values within
-// the message are up to checkers to validate.
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNS message as far as the size of the message is concerned. The values
+// within the message are up to checkers to validate.
func NDPNS(checkers ...TransportChecker) NetworkChecker {
return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
}
@@ -780,63 +799,162 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
ns := header.NDPNeighborSolicit(icmp.NDPPayload())
if got := ns.TargetAddress(); got != want {
- t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want)
+ t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
}
}
}
-// NDPNSOptions creates a checker that checks that the packet contains the
-// provided NDP options within an NDP Neighbor Solicitation message.
+// NDPNA creates a checker that checks that the packet contains a valid NDP
+// Neighbor Advertisement message (as per the raw wire format), with potentially
+// additional checks specified by checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNA message as far as the size of the message is concerned. The values
+// within the message are up to checkers to validate.
+func NDPNA(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...)
+}
+
+// NDPNATargetAddress creates a checker that checks the Target Address field of
+// a header.NDPNeighborAdvert.
//
// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNS message as far as the size is concerned.
-func NDPNSOptions(opts []header.NDPOption) TransportChecker {
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNATargetAddress(want tcpip.Address) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
- it, err := ns.Options().Iter(true)
- if err != nil {
- t.Errorf("opts.Iter(true): %s", err)
- return
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+
+ if got := na.TargetAddress(); got != want {
+ t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
}
+ }
+}
- i := 0
- for {
- opt, done, _ := it.Next()
- if done {
- break
- }
+// NDPNASolicitedFlag creates a checker that checks the Solicited field of
+// a header.NDPNeighborAdvert.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNASolicitedFlag(want bool) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
- if i >= len(opts) {
- t.Errorf("got unexpected option: %s", opt)
- continue
- }
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
- switch wantOpt := opts[i].(type) {
- case header.NDPSourceLinkLayerAddressOption:
- gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
- if !ok {
- t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
- } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
- t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
- }
- default:
- panic("not implemented")
- }
+ if got := na.SolicitedFlag(); got != want {
+ t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
+ }
+ }
+}
- i++
+// ndpOptions checks that optsBuf only contains opts.
+func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) {
+ t.Helper()
+
+ it, err := optsBuf.Iter(true)
+ if err != nil {
+ t.Errorf("optsBuf.Iter(true): %s", err)
+ return
+ }
+
+ i := 0
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ t.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ }
+ if done {
+ break
+ }
+
+ if i >= len(opts) {
+ t.Errorf("got unexpected option: %s", opt)
+ continue
}
- if missing := opts[i:]; len(missing) > 0 {
- t.Errorf("missing options: %s", missing)
+ switch wantOpt := opts[i].(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
+ case header.NDPTargetLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
+ default:
+ t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
}
+
+ i++
+ }
+
+ if missing := opts[i:]; len(missing) > 0 {
+ t.Errorf("missing options: %s", missing)
+ }
+}
+
+// NDPNAOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Neighbor Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNAOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ ndpOptions(t, na.Options(), opts)
+ }
+}
+
+// NDPNSOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Neighbor Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNS message as far as the size is concerned.
+func NDPNSOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ndpOptions(t, ns.Options(), opts)
}
}
// NDPRS creates a checker that checks that the packet contains a valid NDP
// Router Solicitation message (as per the raw wire format).
-func NDPRS() NetworkChecker {
- return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize)
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPRS as far as the size of the message is concerned. The values within the
+// message are up to checkers to validate.
+func NDPRS(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...)
+}
+
+// NDPRSOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Router Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPRS message as far as the size is concerned.
+func NDPRSOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ rs := header.NDPRouterSolicit(icmp.NDPPayload())
+ ndpOptions(t, rs.Options(), opts)
+ }
}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 9da0d71f8..0cde694dc 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -14,12 +14,14 @@ go_library(
"interfaces.go",
"ipv4.go",
"ipv6.go",
+ "ipv6_extension_headers.go",
"ipv6_fragment.go",
"ndp_neighbor_advert.go",
"ndp_neighbor_solicit.go",
"ndp_options.go",
"ndp_router_advert.go",
"ndp_router_solicit.go",
+ "ndpoptionidentifier_string.go",
"tcp.go",
"udp.go",
],
@@ -55,11 +57,13 @@ go_test(
size = "small",
srcs = [
"eth_test.go",
+ "ipv6_extension_headers_test.go",
"ndp_test.go",
],
library = ":header",
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"@com_github_google_go-cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index 7a0014ad9..14413f2ce 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -88,7 +88,7 @@ func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr {
- t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", got, test.expectedLinkAddr)
+ t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", test.addr, got, test.expectedLinkAddr)
}
})
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index e5360e7c1..76839eb92 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -38,7 +38,8 @@ const (
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
// fields of a packet that needs to be encoded.
type IPv4Fields struct {
- // IHL is the "internet header length" field of an IPv4 packet.
+ // IHL is the "internet header length" field of an IPv4 packet. The value
+ // is in bytes.
IHL uint8
// TOS is the "type of service" field of an IPv4 packet.
@@ -138,7 +139,7 @@ func IPVersion(b []byte) int {
}
// HeaderLength returns the value of the "header length" field of the ipv4
-// header.
+// header. The length returned is in bytes.
func (b IPv4) HeaderLength() uint8 {
return (b[versIHL] & 0xf) * 4
}
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
new file mode 100644
index 000000000..2c4591409
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -0,0 +1,544 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// IPv6ExtensionHeaderIdentifier is an IPv6 extension header identifier.
+type IPv6ExtensionHeaderIdentifier uint8
+
+const (
+ // IPv6HopByHopOptionsExtHdrIdentifier is the header identifier of a Hop by
+ // Hop Options extension header, as per RFC 8200 section 4.3.
+ IPv6HopByHopOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 0
+
+ // IPv6RoutingExtHdrIdentifier is the header identifier of a Routing extension
+ // header, as per RFC 8200 section 4.4.
+ IPv6RoutingExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 43
+
+ // IPv6FragmentExtHdrIdentifier is the header identifier of a Fragment
+ // extension header, as per RFC 8200 section 4.5.
+ IPv6FragmentExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 44
+
+ // IPv6DestinationOptionsExtHdrIdentifier is the header identifier of a
+ // Destination Options extension header, as per RFC 8200 section 4.6.
+ IPv6DestinationOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 60
+
+ // IPv6NoNextHeaderIdentifier is the header identifier used to signify the end
+ // of an IPv6 payload, as per RFC 8200 section 4.7.
+ IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59
+)
+
+const (
+ // ipv6UnknownExtHdrOptionActionMask is the mask of the action to take when
+ // a node encounters an unrecognized option.
+ ipv6UnknownExtHdrOptionActionMask = 192
+
+ // ipv6UnknownExtHdrOptionActionShift is the least significant bits to discard
+ // from the action value for an unrecognized option identifier.
+ ipv6UnknownExtHdrOptionActionShift = 6
+
+ // ipv6RoutingExtHdrSegmentsLeftIdx is the index to the Segments Left field
+ // within an IPv6RoutingExtHdr.
+ ipv6RoutingExtHdrSegmentsLeftIdx = 1
+
+ // IPv6FragmentExtHdrLength is the length of an IPv6 extension header, in
+ // bytes.
+ IPv6FragmentExtHdrLength = 8
+
+ // ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the
+ // Fragment Offset field within an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrFragmentOffsetOffset = 0
+
+ // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to
+ // discard from the Fragment Offset.
+ ipv6FragmentExtHdrFragmentOffsetShift = 3
+
+ // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an
+ // IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrFlagsIdx = 1
+
+ // ipv6FragmentExtHdrMFlagMask is the mask of the More (M) flag within the
+ // flags field of an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrMFlagMask = 1
+
+ // ipv6FragmentExtHdrIdentificationOffset is the offset to the Identification
+ // field within an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrIdentificationOffset = 2
+
+ // ipv6ExtHdrLenBytesPerUnit is the unit size of an extension header's length
+ // field. That is, given a Length field of 2, the extension header expects
+ // 16 bytes following the first 8 bytes (see ipv6ExtHdrLenBytesExcluded for
+ // details about the first 8 bytes' exclusion from the Length field).
+ ipv6ExtHdrLenBytesPerUnit = 8
+
+ // ipv6ExtHdrLenBytesExcluded is the number of bytes excluded from an
+ // extension header's Length field following the Length field.
+ //
+ // The Length field excludes the first 8 bytes, but the Next Header and Length
+ // field take up the first 2 of the 8 bytes so we expect (at minimum) 6 bytes
+ // after the Length field.
+ //
+ // This ensures that every extension header is at least 8 bytes.
+ ipv6ExtHdrLenBytesExcluded = 6
+
+ // IPv6FragmentExtHdrFragmentOffsetBytesPerUnit is the unit size of a Fragment
+ // extension header's Fragment Offset field. That is, given a Fragment Offset
+ // of 2, the extension header is indiciating that the fragment's payload
+ // starts at the 16th byte in the reassembled packet.
+ IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
+)
+
+// IPv6PayloadHeader is implemented by the various headers that can be found
+// in an IPv6 payload.
+//
+// These headers include IPv6 extension headers or upper layer data.
+type IPv6PayloadHeader interface {
+ isIPv6PayloadHeader()
+}
+
+// IPv6RawPayloadHeader the remainder of an IPv6 payload after an iterator
+// encounters a Next Header field it does not recognize as an IPv6 extension
+// header.
+type IPv6RawPayloadHeader struct {
+ Identifier IPv6ExtensionHeaderIdentifier
+ Buf buffer.VectorisedView
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {}
+
+// ipv6OptionsExtHdr is an IPv6 extension header that holds options.
+type ipv6OptionsExtHdr []byte
+
+// Iter returns an iterator over the IPv6 extension header options held in b.
+func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
+ it := IPv6OptionsExtHdrOptionsIterator{}
+ it.reader.Reset(b)
+ return it
+}
+
+// IPv6OptionsExtHdrOptionsIterator is an iterator over IPv6 extension header
+// options.
+//
+// Note, between when an IPv6OptionsExtHdrOptionsIterator is obtained and last
+// used, no changes to the underlying buffer may happen. Doing so may cause
+// undefined and unexpected behaviour. It is fine to obtain an
+// IPv6OptionsExtHdrOptionsIterator, iterate over the first few options then
+// modify the backing payload so long as the IPv6OptionsExtHdrOptionsIterator
+// obtained before modification is no longer used.
+type IPv6OptionsExtHdrOptionsIterator struct {
+ reader bytes.Reader
+}
+
+// IPv6OptionUnknownAction is the action that must be taken if the processing
+// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2.
+type IPv6OptionUnknownAction int
+
+const (
+ // IPv6OptionUnknownActionSkip indicates that the unrecognized option must
+ // be skipped and the node should continue processing the header.
+ IPv6OptionUnknownActionSkip IPv6OptionUnknownAction = 0
+
+ // IPv6OptionUnknownActionDiscard indicates that the packet must be silently
+ // discarded.
+ IPv6OptionUnknownActionDiscard IPv6OptionUnknownAction = 1
+
+ // IPv6OptionUnknownActionDiscardSendICMP indicates that the packet must be
+ // discarded and the node must send an ICMP Parameter Problem, Code 2, message
+ // to the packet's source, regardless of whether or not the packet's
+ // Destination was a multicast address.
+ IPv6OptionUnknownActionDiscardSendICMP IPv6OptionUnknownAction = 2
+
+ // IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest indicates that the
+ // packet must be discarded and the node must send an ICMP Parameter Problem,
+ // Code 2, message to the packet's source only if the packet's Destination was
+ // not a multicast address.
+ IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest IPv6OptionUnknownAction = 3
+)
+
+// IPv6ExtHdrOption is implemented by the various IPv6 extension header options.
+type IPv6ExtHdrOption interface {
+ // UnknownAction returns the action to take in response to an unrecognized
+ // option.
+ UnknownAction() IPv6OptionUnknownAction
+
+ // isIPv6ExtHdrOption is used to "lock" this interface so it is not
+ // implemented by other packages.
+ isIPv6ExtHdrOption()
+}
+
+// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier.
+type IPv6ExtHdrOptionIndentifier uint8
+
+const (
+ // ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that
+ // provides 1 byte padding, as outlined in RFC 8200 section 4.2.
+ ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0
+
+ // ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that
+ // provides variable length byte padding, as outlined in RFC 8200 section 4.2.
+ ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1
+)
+
+// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
+// header option that is unknown by the parsing utilities.
+type IPv6UnknownExtHdrOption struct {
+ Identifier IPv6ExtHdrOptionIndentifier
+ Data []byte
+}
+
+// UnknownAction implements IPv6OptionUnknownAction.UnknownAction.
+func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
+ return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+}
+
+// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
+func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
+
+// Next returns the next option in the options data.
+//
+// If the next item is not a known extension header option,
+// IPv6UnknownExtHdrOption will be returned with the option identifier and data.
+//
+// The return is of the format (option, done, error). done will be true when
+// Next is unable to return anything because the iterator has reached the end of
+// the options data, or an error occured.
+func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
+ for {
+ temp, err := i.reader.ReadByte()
+ if err != nil {
+ // If we can't read the first byte of a new option, then we know the
+ // options buffer has been exhausted and we are done iterating.
+ return nil, true, nil
+ }
+ id := IPv6ExtHdrOptionIndentifier(temp)
+
+ // If the option identifier indicates the option is a Pad1 option, then we
+ // know the option does not have Length and Data fields. End processing of
+ // the Pad1 option and continue processing the buffer as a new option.
+ if id == ipv6Pad1ExtHdrOptionIdentifier {
+ continue
+ }
+
+ length, err := i.reader.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading the option's Length field for option with id = %d: %s", id, err))
+ }
+
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
+ // we start parsing an option; we expect the reader to contain enough
+ // bytes for the whole option.
+ return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
+ }
+
+ // Special-case the variable length padding option to avoid a copy.
+ if id == ipv6PadNExtHdrOptionIdentifier {
+ // Do we have enough bytes in the reader for the PadN option?
+ if n := i.reader.Len(); n < int(length) {
+ // Reset the reader to effectively consume the remaining buffer.
+ i.reader.Reset(nil)
+
+ // We return the same error as if we failed to read a non-padding option
+ // so consumers of this iterator don't need to differentiate between
+ // padding and non-padding options.
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
+ }
+
+ if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil {
+ panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
+ }
+
+ // End processing of the PadN option and continue processing the buffer as
+ // a new option.
+ continue
+ }
+
+ bytes := make([]byte, length)
+ if n, err := io.ReadFull(&i.reader, bytes); err != nil {
+ // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
+ // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
+ // Length field found in the option.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
+ }
+
+ return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
+ }
+}
+
+// IPv6HopByHopOptionsExtHdr is a buffer holding the Hop By Hop Options
+// extension header.
+type IPv6HopByHopOptionsExtHdr struct {
+ ipv6OptionsExtHdr
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6HopByHopOptionsExtHdr) isIPv6PayloadHeader() {}
+
+// IPv6DestinationOptionsExtHdr is a buffer holding the Destination Options
+// extension header.
+type IPv6DestinationOptionsExtHdr struct {
+ ipv6OptionsExtHdr
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {}
+
+// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific
+// data as outlined in RFC 8200 section 4.4.
+type IPv6RoutingExtHdr []byte
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {}
+
+// SegmentsLeft returns the Segments Left field.
+func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 {
+ return b[ipv6RoutingExtHdrSegmentsLeftIdx]
+}
+
+// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific
+// data as outlined in RFC 8200 section 4.5.
+//
+// Note, the buffer does not include the Next Header and Reserved fields.
+type IPv6FragmentExtHdr [6]byte
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6FragmentExtHdr) isIPv6PayloadHeader() {}
+
+// FragmentOffset returns the Fragment Offset field.
+//
+// This value indicates where the buffer following the Fragment extension header
+// starts in the target (reassembled) packet.
+func (b IPv6FragmentExtHdr) FragmentOffset() uint16 {
+ return binary.BigEndian.Uint16(b[ipv6FragmentExtHdrFragmentOffsetOffset:]) >> ipv6FragmentExtHdrFragmentOffsetShift
+}
+
+// More returns the More (M) flag.
+//
+// This indicates whether any fragments are expected to succeed b.
+func (b IPv6FragmentExtHdr) More() bool {
+ return b[ipv6FragmentExtHdrFlagsIdx]&ipv6FragmentExtHdrMFlagMask != 0
+}
+
+// ID returns the Identification field.
+//
+// This value is used to uniquely identify the packet, between a
+// souce and destination.
+func (b IPv6FragmentExtHdr) ID() uint32 {
+ return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:])
+}
+
+// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload.
+//
+// The IPv6 payload may contain IPv6 extension headers before any upper layer
+// data.
+//
+// Note, between when an IPv6PayloadIterator is obtained and last used, no
+// changes to the payload may happen. Doing so may cause undefined and
+// unexpected behaviour. It is fine to obtain an IPv6PayloadIterator, iterate
+// over the first few headers then modify the backing payload so long as the
+// IPv6PayloadIterator obtained before modification is no longer used.
+type IPv6PayloadIterator struct {
+ // The identifier of the next header to parse.
+ nextHdrIdentifier IPv6ExtensionHeaderIdentifier
+
+ // reader is an io.Reader over payload.
+ reader bufio.Reader
+ payload buffer.VectorisedView
+
+ // Indicates to the iterator that it should return the remaining payload as a
+ // raw payload on the next call to Next.
+ forceRaw bool
+}
+
+// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
+// extension headers, or a raw payload if the payload cannot be parsed.
+func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.VectorisedView) IPv6PayloadIterator {
+ readers := payload.Readers()
+ readerPs := make([]io.Reader, 0, len(readers))
+ for i := range readers {
+ readerPs = append(readerPs, &readers[i])
+ }
+
+ return IPv6PayloadIterator{
+ nextHdrIdentifier: nextHdrIdentifier,
+ payload: payload.Clone(nil),
+ // We need a buffer of size 1 for calls to bufio.Reader.ReadByte.
+ reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ }
+}
+
+// AsRawHeader returns the remaining payload of i as a raw header and
+// optionally consumes the iterator.
+//
+// If consume is true, calls to Next after calling AsRawHeader on i will
+// indicate that the iterator is done.
+func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
+ identifier := i.nextHdrIdentifier
+
+ var buf buffer.VectorisedView
+ if consume {
+ // Since we consume the iterator, we return the payload as is.
+ buf = i.payload
+
+ // Mark i as done.
+ *i = IPv6PayloadIterator{
+ nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
+ }
+ } else {
+ buf = i.payload.Clone(nil)
+ }
+
+ return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf}
+}
+
+// Next returns the next item in the payload.
+//
+// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader
+// will be returned with the remaining bytes and next header identifier.
+//
+// The return is of the format (header, done, error). done will be true when
+// Next is unable to return anything because the iterator has reached the end of
+// the payload, or an error occured.
+func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
+ // We could be forced to return i as a raw header when the previous header was
+ // a fragment extension header as the data following the fragment extension
+ // header may not be complete.
+ if i.forceRaw {
+ return i.AsRawHeader(true /* consume */), false, nil
+ }
+
+ // Is the header we are parsing a known extension header?
+ switch i.nextHdrIdentifier {
+ case IPv6HopByHopOptionsExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: bytes}, false, nil
+ case IPv6RoutingExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6RoutingExtHdr(bytes), false, nil
+ case IPv6FragmentExtHdrIdentifier:
+ var data [6]byte
+ // We ignore the returned bytes becauase we know the fragment extension
+ // header specific data will fit in data.
+ nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
+ if err != nil {
+ return nil, true, err
+ }
+
+ fragmentExtHdr := IPv6FragmentExtHdr(data)
+
+ // If the packet is not the first fragment, do not attempt to parse anything
+ // after the fragment extension header as the payload following the fragment
+ // extension header should not contain any headers; the first fragment must
+ // hold all the headers up to and including any upper layer headers, as per
+ // RFC 8200 section 4.5.
+ if fragmentExtHdr.FragmentOffset() != 0 {
+ i.forceRaw = true
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return fragmentExtHdr, false, nil
+ case IPv6DestinationOptionsExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: bytes}, false, nil
+ case IPv6NoNextHeaderIdentifier:
+ // This indicates the end of the IPv6 payload.
+ return nil, true, nil
+
+ default:
+ // The header we are parsing is not a known extension header. Return the
+ // raw payload.
+ return i.AsRawHeader(true /* consume */), false, nil
+ }
+}
+
+// nextHeaderData returns the extension header's Next Header field and raw data.
+//
+// fragmentHdr indicates that the extension header being parsed is the Fragment
+// extension header so the Length field should be ignored as it is Reserved
+// for the Fragment extension header.
+//
+// If bytes is not nil, extension header specific data will be read into bytes
+// if it has enough capacity. If bytes is provided but does not have enough
+// capacity for the data, nextHeaderData will panic.
+func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, []byte, error) {
+ // We ignore the number of bytes read because we know we will only ever read
+ // at max 1 bytes since rune has a length of 1. If we read 0 bytes, the Read
+ // would return io.EOF to indicate that io.Reader has reached the end of the
+ // payload.
+ nextHdrIdentifier, err := i.reader.ReadByte()
+ i.payload.TrimFront(1)
+ if err != nil {
+ return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+
+ var length uint8
+ length, err = i.reader.ReadByte()
+ i.payload.TrimFront(1)
+ if err != nil {
+ if fragmentHdr {
+ return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+
+ return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+ if fragmentHdr {
+ length = 0
+ }
+
+ bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
+ if bytes == nil {
+ bytes = make([]byte, bytesLen)
+ } else if n := len(bytes); n < bytesLen {
+ panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier))
+ }
+
+ n, err := io.ReadFull(&i.reader, bytes)
+ i.payload.TrimFront(n)
+ if err != nil {
+ return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
+ }
+
+ return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil
+}
diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go
new file mode 100644
index 000000000..ab20c5f37
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_extension_headers_test.go
@@ -0,0 +1,992 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// Equal returns true of a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold the same Identifier value and
+// contain the same bytes in Buf, even if the bytes are split across views
+// differently.
+//
+// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
+// fields.
+func (a IPv6RawPayloadHeader) Equal(b IPv6RawPayloadHeader) bool {
+ return a.Identifier == b.Identifier && bytes.Equal(a.Buf.ToView(), b.Buf.ToView())
+}
+
+// Equal returns true of a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs.
+//
+// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
+// fields.
+func (a IPv6HopByHopOptionsExtHdr) Equal(b IPv6HopByHopOptionsExtHdr) bool {
+ return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr)
+}
+
+// Equal returns true of a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs.
+//
+// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
+// fields.
+func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool {
+ return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr)
+}
+
+func TestIPv6UnknownExtHdrOption(t *testing.T) {
+ tests := []struct {
+ name string
+ identifier IPv6ExtHdrOptionIndentifier
+ expectedUnknownAction IPv6OptionUnknownAction
+ }{
+ {
+ name: "Skip with zero LSBs",
+ identifier: 0,
+ expectedUnknownAction: IPv6OptionUnknownActionSkip,
+ },
+ {
+ name: "Discard with zero LSBs",
+ identifier: 64,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscard,
+ },
+ {
+ name: "Discard and ICMP with zero LSBs",
+ identifier: 128,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP,
+ },
+ {
+ name: "Discard and ICMP for non multicast destination with zero LSBs",
+ identifier: 192,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ },
+ {
+ name: "Skip with non-zero LSBs",
+ identifier: 63,
+ expectedUnknownAction: IPv6OptionUnknownActionSkip,
+ },
+ {
+ name: "Discard with non-zero LSBs",
+ identifier: 127,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscard,
+ },
+ {
+ name: "Discard and ICMP with non-zero LSBs",
+ identifier: 191,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP,
+ },
+ {
+ name: "Discard and ICMP for non multicast destination with non-zero LSBs",
+ identifier: 255,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opt := &IPv6UnknownExtHdrOption{Identifier: test.identifier, Data: []byte{1, 2, 3, 4}}
+ if a := opt.UnknownAction(); a != test.expectedUnknownAction {
+ t.Fatalf("got UnknownAction() = %d, want = %d", a, test.expectedUnknownAction)
+ }
+ })
+ }
+
+}
+
+func TestIPv6OptionsExtHdrIterErr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ err error
+ }{
+ {
+ name: "Single unknown with zero length",
+ bytes: []byte{255, 0},
+ },
+ {
+ name: "Single unknown with non-zero length",
+ bytes: []byte{255, 3, 1, 2, 3},
+ },
+ {
+ name: "Two options",
+ bytes: []byte{
+ 255, 0,
+ 254, 1, 1,
+ },
+ },
+ {
+ name: "Three options",
+ bytes: []byte{
+ 255, 0,
+ 254, 1, 1,
+ 253, 4, 2, 3, 4, 5,
+ },
+ },
+ {
+ name: "Single unknown only identifier",
+ bytes: []byte{255},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Single unknown too small with length = 1",
+ bytes: []byte{255, 1},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Single unknown too small with length = 2",
+ bytes: []byte{255, 2, 1},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown only identifier",
+ bytes: []byte{
+ 255, 0,
+ 254,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown missing data",
+ bytes: []byte{
+ 255, 0,
+ 254, 1,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown too small",
+ bytes: []byte{
+ 255, 0,
+ 254, 2, 1,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "One Pad1",
+ bytes: []byte{0},
+ },
+ {
+ name: "Multiple Pad1",
+ bytes: []byte{0, 0, 0},
+ },
+ {
+ name: "Multiple PadN",
+ bytes: []byte{
+ // Pad3
+ 1, 1, 1,
+
+ // Pad5
+ 1, 3, 1, 2, 3,
+ },
+ },
+ {
+ name: "Pad5 too small middle of data buffer",
+ bytes: []byte{1, 3, 1, 2},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Pad5 no data",
+ bytes: []byte{1, 3},
+ err: io.ErrUnexpectedEOF,
+ },
+ }
+
+ check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) {
+ for i := 0; ; i++ {
+ _, done, err := it.Next()
+ if err != nil {
+ // If we encountered a non-nil error while iterating, make sure it is
+ // is the same error as expectedErr.
+ if !errors.Is(err, expectedErr) {
+ t.Fatalf("got %d-th Next() = %v, want = %v", i, err, expectedErr)
+ }
+
+ return
+ }
+ if done {
+ // If we are done (without an error), make sure that we did not expect
+ // an error.
+ if expectedErr != nil {
+ t.Fatalf("expected error when iterating; want = %s", expectedErr)
+ }
+
+ return
+ }
+ }
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Run("Hop By Hop", func(t *testing.T) {
+ extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ check(t, extHdr.Iter(), test.err)
+ })
+
+ t.Run("Destination", func(t *testing.T) {
+ extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ check(t, extHdr.Iter(), test.err)
+ })
+ })
+ }
+}
+
+func TestIPv6OptionsExtHdrIter(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ expected []IPv6ExtHdrOption
+ }{
+ {
+ name: "Single unknown with zero length",
+ bytes: []byte{255, 0},
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}},
+ },
+ },
+ {
+ name: "Single unknown with non-zero length",
+ bytes: []byte{255, 3, 1, 2, 3},
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{1, 2, 3}},
+ },
+ },
+ {
+ name: "Single Pad1",
+ bytes: []byte{0},
+ },
+ {
+ name: "Two Pad1",
+ bytes: []byte{0, 0},
+ },
+ {
+ name: "Single Pad3",
+ bytes: []byte{1, 1, 1},
+ },
+ {
+ name: "Single Pad5",
+ bytes: []byte{1, 3, 1, 2, 3},
+ },
+ {
+ name: "Multiple Pad",
+ bytes: []byte{
+ // Pad1
+ 0,
+
+ // Pad2
+ 1, 0,
+
+ // Pad3
+ 1, 1, 1,
+
+ // Pad4
+ 1, 2, 1, 2,
+
+ // Pad5
+ 1, 3, 1, 2, 3,
+ },
+ },
+ {
+ name: "Multiple options",
+ bytes: []byte{
+ // Pad1
+ 0,
+
+ // Unknown
+ 255, 0,
+
+ // Pad2
+ 1, 0,
+
+ // Unknown
+ 254, 1, 1,
+
+ // Pad3
+ 1, 1, 1,
+
+ // Unknown
+ 253, 4, 2, 3, 4, 5,
+
+ // Pad4
+ 1, 2, 1, 2,
+ },
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}},
+ &IPv6UnknownExtHdrOption{Identifier: 254, Data: []byte{1}},
+ &IPv6UnknownExtHdrOption{Identifier: 253, Data: []byte{2, 3, 4, 5}},
+ },
+ },
+ }
+
+ checkIter := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expected []IPv6ExtHdrOption) {
+ for i, e := range expected {
+ opt, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(i=%d) Next(): %s", i, err)
+ }
+ if done {
+ t.Errorf("(i=%d) unexpectedly done iterating", i)
+ }
+ if diff := cmp.Diff(e, opt); diff != "" {
+ t.Errorf("(i=%d) got option mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ opt, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(last) Next(): %s", err)
+ }
+ if !done {
+ t.Errorf("(last) iterator unexpectedly not done")
+ }
+ if opt != nil {
+ t.Errorf("(last) got Next() = %T, want = nil", opt)
+ }
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Run("Hop By Hop", func(t *testing.T) {
+ extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ checkIter(t, extHdr.Iter(), test.expected)
+ })
+
+ t.Run("Destination", func(t *testing.T) {
+ extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ checkIter(t, extHdr.Iter(), test.expected)
+ })
+ })
+ }
+}
+
+func TestIPv6RoutingExtHdr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ segmentsLeft uint8
+ }{
+ {
+ name: "Zeroes",
+ bytes: []byte{0, 0, 0, 0, 0, 0},
+ segmentsLeft: 0,
+ },
+ {
+ name: "Ones",
+ bytes: []byte{1, 1, 1, 1, 1, 1},
+ segmentsLeft: 1,
+ },
+ {
+ name: "Mixed",
+ bytes: []byte{1, 2, 3, 4, 5, 6},
+ segmentsLeft: 2,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ extHdr := IPv6RoutingExtHdr(test.bytes)
+ if got := extHdr.SegmentsLeft(); got != test.segmentsLeft {
+ t.Errorf("got SegmentsLeft() = %d, want = %d", got, test.segmentsLeft)
+ }
+ })
+ }
+}
+
+func TestIPv6FragmentExtHdr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes [6]byte
+ fragmentOffset uint16
+ more bool
+ id uint32
+ }{
+ {
+ name: "Zeroes",
+ bytes: [6]byte{0, 0, 0, 0, 0, 0},
+ fragmentOffset: 0,
+ more: false,
+ id: 0,
+ },
+ {
+ name: "Ones",
+ bytes: [6]byte{0, 9, 0, 0, 0, 1},
+ fragmentOffset: 1,
+ more: true,
+ id: 1,
+ },
+ {
+ name: "Mixed",
+ bytes: [6]byte{68, 9, 128, 4, 2, 1},
+ fragmentOffset: 2177,
+ more: true,
+ id: 2147746305,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ extHdr := IPv6FragmentExtHdr(test.bytes)
+ if got := extHdr.FragmentOffset(); got != test.fragmentOffset {
+ t.Errorf("got FragmentOffset() = %d, want = %d", got, test.fragmentOffset)
+ }
+ if got := extHdr.More(); got != test.more {
+ t.Errorf("got More() = %t, want = %t", got, test.more)
+ }
+ if got := extHdr.ID(); got != test.id {
+ t.Errorf("got ID() = %d, want = %d", got, test.id)
+ }
+ })
+ }
+}
+
+func makeVectorisedViewFromByteBuffers(bs ...[]byte) buffer.VectorisedView {
+ size := 0
+ var vs []buffer.View
+
+ for _, b := range bs {
+ vs = append(vs, buffer.View(b))
+ size += len(b)
+ }
+
+ return buffer.NewVectorisedView(size, vs)
+}
+
+func TestIPv6ExtHdrIterErr(t *testing.T) {
+ tests := []struct {
+ name string
+ firstNextHdr IPv6ExtensionHeaderIdentifier
+ payload buffer.VectorisedView
+ err error
+ }{
+ {
+ name: "Upper layer only without data",
+ firstNextHdr: 255,
+ },
+ {
+ name: "Upper layer only with data",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}),
+ },
+ {
+ name: "No next header",
+ firstNextHdr: IPv6NoNextHeaderIdentifier,
+ },
+ {
+ name: "No next header with data",
+ firstNextHdr: IPv6NoNextHeaderIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}),
+ },
+ {
+ name: "Valid single hop by hop",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}),
+ },
+ {
+ name: "Hop by hop too small",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single fragment",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2, 1}),
+ },
+ {
+ name: "Fragment too small",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single destination",
+ firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}),
+ },
+ {
+ name: "Destination too small",
+ firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single routing",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5, 6}),
+ },
+ {
+ name: "Valid single routing across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2}, []byte{3, 4, 5, 6}),
+ },
+ {
+ name: "Routing too small with zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid routing with non-zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8}),
+ },
+ {
+ name: "Valid routing with non-zero length field across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7, 8}),
+ },
+ {
+ name: "Routing too small with non-zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Routing too small with non-zero length field across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Mixed",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3, 4,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ },
+ {
+ name: "Mixed without upper layer data",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3, 4,
+ }),
+ },
+ {
+ name: "Mixed without upper layer data but last ext hdr too small",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3,
+ }),
+ err: io.ErrUnexpectedEOF,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload)
+
+ for i := 0; ; i++ {
+ _, done, err := it.Next()
+ if err != nil {
+ // If we encountered a non-nil error while iterating, make sure it is
+ // is the same error as test.err.
+ if !errors.Is(err, test.err) {
+ t.Fatalf("got %d-th Next() = %v, want = %v", i, err, test.err)
+ }
+
+ return
+ }
+ if done {
+ // If we are done (without an error), make sure that we did not expect
+ // an error.
+ if test.err != nil {
+ t.Fatalf("expected error when iterating; want = %s", test.err)
+ }
+
+ return
+ }
+ }
+ })
+ }
+}
+
+func TestIPv6ExtHdrIter(t *testing.T) {
+ routingExtHdrWithUpperLayerData := buffer.View([]byte{255, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4})
+ upperLayerData := buffer.View([]byte{1, 2, 3, 4})
+ tests := []struct {
+ name string
+ firstNextHdr IPv6ExtensionHeaderIdentifier
+ payload buffer.VectorisedView
+ expected []IPv6PayloadHeader
+ }{
+ // With a non-atomic fragment that is not the first fragment, the payload
+ // after the fragment will not be parsed because the payload is expected to
+ // only hold upper layer data.
+ {
+ name: "hopbyhop - fragment (not first) - routing - upper",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Fragment extension header.
+ //
+ // More = 1, Fragment Offset = 2117, ID = 2147746305
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1,
+
+ // Routing extension header.
+ //
+ // Even though we have a routing ext header here, it should be
+ // be interpretted as raw bytes as only the first fragment is expected
+ // to hold headers.
+ 255, 0, 1, 2, 3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6RoutingExtHdrIdentifier,
+ Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
+ name: "hopbyhop - fragment (first) - routing - upper",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Fragment extension header.
+ //
+ // More = 1, Fragment Offset = 0, ID = 2147746305
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 0, 1, 128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2, 3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6FragmentExtHdr([6]byte{0, 1, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
+ name: "fragment - routing - upper (across views)",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2}, []byte{3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6RoutingExtHdrIdentifier,
+ Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+
+ // If we have an atomic fragment, the payload following the fragment
+ // extension header should be parsed normally.
+ {
+ name: "atomic fragment - routing - destination - upper",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 1, 4, 1, 2, 3, 4,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
+ name: "atomic fragment - routing - upper (across views)",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2}, []byte{3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2}, []byte{3, 4}),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ },
+ },
+ },
+ {
+ name: "atomic fragment - destination - no next header",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Res (Reserved) bits are 1 which should not affect anything.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 0, 6, 128, 4, 2, 1,
+
+ // Destination Options extension header.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ },
+ },
+ {
+ name: "routing - atomic fragment - no next header",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 0, 6, 128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ },
+ },
+ {
+ name: "routing - atomic fragment - no next header (across views)",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6NoNextHeaderIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ },
+ },
+ {
+ name: "hopbyhop - routing - fragment - no next header",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // Fragment Offset = 32; Res = 6.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 1, 6, 128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{1, 6, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6NoNextHeaderIdentifier,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+
+ // Test the raw payload for common transport layer protocol numbers.
+ {
+ name: "TCP raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "UDP raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "ICMPv4 raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "ICMPv6 raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "Unknwon next header raw payload",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "Unknwon next header raw payload (across views)",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ }},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload)
+
+ for i, e := range test.expected {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(i=%d) Next(): %s", i, err)
+ }
+ if done {
+ t.Errorf("(i=%d) unexpectedly done iterating", i)
+ }
+ if diff := cmp.Diff(e, extHdr); diff != "" {
+ t.Errorf("(i=%d) got ext hdr mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ extHdr, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(last) Next(): %s", err)
+ }
+ if !done {
+ t.Errorf("(last) iterator unexpectedly not done")
+ }
+ if extHdr != nil {
+ t.Errorf("(last) got Next() = %T, want = nil", extHdr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index e6a6ad39b..444e90820 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -15,32 +15,43 @@
package header
import (
+ "bytes"
"encoding/binary"
"errors"
"fmt"
+ "io"
"math"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// NDPOptionIdentifier is an NDP option type identifier.
+type NDPOptionIdentifier uint8
+
const (
// NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer
// Address option, as per RFC 4861 section 4.6.1.
- NDPSourceLinkLayerAddressOptionType = 1
+ NDPSourceLinkLayerAddressOptionType NDPOptionIdentifier = 1
// NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer
// Address option, as per RFC 4861 section 4.6.1.
- NDPTargetLinkLayerAddressOptionType = 2
+ NDPTargetLinkLayerAddressOptionType NDPOptionIdentifier = 2
+
+ // NDPPrefixInformationType is the type of the Prefix Information
+ // option, as per RFC 4861 section 4.6.2.
+ NDPPrefixInformationType NDPOptionIdentifier = 3
+
+ // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS
+ // Server option, as per RFC 8106 section 5.1.
+ NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25
+)
+const (
// NDPLinkLayerAddressSize is the size of a Source or Target Link Layer
// Address option for an Ethernet address.
NDPLinkLayerAddressSize = 8
- // NDPPrefixInformationType is the type of the Prefix Information
- // option, as per RFC 4861 section 4.6.2.
- NDPPrefixInformationType = 3
-
// ndpPrefixInformationLength is the expected length, in bytes, of the
// body of an NDP Prefix Information option, as per RFC 4861 section
// 4.6.2 which specifies that the Length field is 4. Given this, the
@@ -91,10 +102,6 @@ const (
// within an NDPPrefixInformation.
ndpPrefixInformationPrefixOffset = 14
- // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS
- // Server option, as per RFC 8106 section 5.1.
- NDPRecursiveDNSServerOptionType = 25
-
// ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte
// Lifetime field within an NDPRecursiveDNSServer.
ndpRecursiveDNSServerLifetimeOffset = 2
@@ -103,10 +110,10 @@ const (
// for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer.
ndpRecursiveDNSServerAddressesOffset = 6
- // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS
- // Server option's length field value when it contains at least one
- // IPv6 address.
- minNDPRecursiveDNSServerLength = 3
+ // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS Server
+ // option's body size when it contains at least one IPv6 address, as per
+ // RFC 8106 section 5.3.1.
+ minNDPRecursiveDNSServerBodySize = 22
// lengthByteUnits is the multiplier factor for the Length field of an
// NDP option. That is, the length field for NDP options is in units of
@@ -132,16 +139,13 @@ var (
// few NDPOption then modify the backing NDPOptions so long as the
// NDPOptionIterator obtained before modification is no longer used.
type NDPOptionIterator struct {
- // The NDPOptions this NDPOptionIterator is iterating over.
- opts NDPOptions
+ opts *bytes.Buffer
}
// Potential errors when iterating over an NDPOptions.
var (
- ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted")
- ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field")
- ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body")
- ErrNDPInvalidLength = errors.New("NDP option's Length value is invalid as per relevant RFC")
+ ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body")
+ ErrNDPOptMalformedHeader = errors.New("NDP option has a malformed header")
)
// Next returns the next element in the backing NDPOptions, or true if we are
@@ -152,48 +156,50 @@ var (
func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
for {
// Do we still have elements to look at?
- if len(i.opts) == 0 {
+ if i.opts.Len() == 0 {
return nil, true, nil
}
- // Do we have enough bytes for an NDP option that has a Length
- // field of at least 1? Note, 0 in the Length field is invalid.
- if len(i.opts) < lengthByteUnits {
- return nil, true, ErrNDPOptBufExhausted
- }
-
// Get the Type field.
- t := i.opts[0]
-
- // Get the Length field.
- l := i.opts[1]
+ temp, err := i.opts.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading the option's Type field: %s", err))
+ }
- // This would indicate an erroneous NDP option as the Length
- // field should never be 0.
- if l == 0 {
- return nil, true, ErrNDPOptZeroLength
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
+ // we start parsing an option; we expect the buffer to contain enough
+ // bytes for the whole option.
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF)
}
+ kind := NDPOptionIdentifier(temp)
- // How many bytes are in the option body?
- numBytes := int(l) * lengthByteUnits
- numBodyBytes := numBytes - 2
-
- potentialBody := i.opts[2:]
+ // Get the Length field.
+ length, err := i.opts.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ panic(fmt.Sprintf("unexpected error when reading the option's Length field for %s: %s", kind, err))
+ }
- // This would indicate an erroenous NDPOptions buffer as we ran
- // out of the buffer in the middle of an NDP option.
- if left := len(potentialBody); left < numBodyBytes {
- return nil, true, ErrNDPOptBufExhausted
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Length field for %s: %w", kind, io.ErrUnexpectedEOF)
}
- // Get only the options body, leaving the rest of the options
- // buffer alone.
- body := potentialBody[:numBodyBytes]
+ // This would indicate an erroneous NDP option as the Length field should
+ // never be 0.
+ if length == 0 {
+ return nil, true, fmt.Errorf("zero valued Length field for %s: %w", kind, ErrNDPOptMalformedHeader)
+ }
- // Update opts with the remaining options body.
- i.opts = i.opts[numBytes:]
+ // Get the body.
+ numBytes := int(length) * lengthByteUnits
+ numBodyBytes := numBytes - 2
+ body := i.opts.Next(numBodyBytes)
+ if len(body) < numBodyBytes {
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Body for %s: %w", kind, io.ErrUnexpectedEOF)
+ }
- switch t {
+ switch kind {
case NDPSourceLinkLayerAddressOptionType:
return NDPSourceLinkLayerAddressOption(body), false, nil
@@ -205,22 +211,15 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
// body is ndpPrefixInformationLength, as per RFC 4861
// section 4.6.2.
if numBodyBytes != ndpPrefixInformationLength {
- return nil, true, ErrNDPOptMalformedBody
+ return nil, true, fmt.Errorf("got %d bytes for NDP Prefix Information option's body, expected %d bytes: %w", numBodyBytes, ndpPrefixInformationLength, ErrNDPOptMalformedBody)
}
return NDPPrefixInformation(body), false, nil
case NDPRecursiveDNSServerOptionType:
- // RFC 8106 section 5.3.1 outlines that the RDNSS option
- // must have a minimum length of 3 so it contains at
- // least one IPv6 address.
- if l < minNDPRecursiveDNSServerLength {
- return nil, true, ErrNDPInvalidLength
- }
-
opt := NDPRecursiveDNSServer(body)
- if len(opt.Addresses()) == 0 {
- return nil, true, ErrNDPOptMalformedBody
+ if err := opt.checkAddresses(); err != nil {
+ return nil, true, err
}
return opt, false, nil
@@ -247,10 +246,16 @@ type NDPOptions []byte
//
// See NDPOptionIterator for more information.
func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) {
- it := NDPOptionIterator{opts: b}
+ it := NDPOptionIterator{
+ opts: bytes.NewBuffer(b),
+ }
if check {
- for it2 := it; true; {
+ it2 := NDPOptionIterator{
+ opts: bytes.NewBuffer(b),
+ }
+
+ for {
if _, done, err := it2.Next(); err != nil || done {
return it, err
}
@@ -278,7 +283,7 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
continue
}
- b[0] = o.Type()
+ b[0] = byte(o.Type())
// We know this safe because paddedLength would have returned
// 0 if o had an invalid length (> 255 * lengthByteUnits).
@@ -304,7 +309,7 @@ type NDPOption interface {
fmt.Stringer
// Type returns the type of the receiver.
- Type() uint8
+ Type() NDPOptionIdentifier
// Length returns the length of the body of the receiver, in bytes.
Length() int
@@ -386,7 +391,7 @@ func (b NDPOptionsSerializer) Length() int {
type NDPSourceLinkLayerAddressOption tcpip.LinkAddress
// Type implements NDPOption.Type.
-func (o NDPSourceLinkLayerAddressOption) Type() uint8 {
+func (o NDPSourceLinkLayerAddressOption) Type() NDPOptionIdentifier {
return NDPSourceLinkLayerAddressOptionType
}
@@ -426,7 +431,7 @@ func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
type NDPTargetLinkLayerAddressOption tcpip.LinkAddress
// Type implements NDPOption.Type.
-func (o NDPTargetLinkLayerAddressOption) Type() uint8 {
+func (o NDPTargetLinkLayerAddressOption) Type() NDPOptionIdentifier {
return NDPTargetLinkLayerAddressOptionType
}
@@ -466,7 +471,7 @@ func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
type NDPPrefixInformation []byte
// Type implements NDPOption.Type.
-func (o NDPPrefixInformation) Type() uint8 {
+func (o NDPPrefixInformation) Type() NDPOptionIdentifier {
return NDPPrefixInformationType
}
@@ -590,7 +595,7 @@ type NDPRecursiveDNSServer []byte
// Type returns the type of an NDP Recursive DNS Server option.
//
// Type implements NDPOption.Type.
-func (NDPRecursiveDNSServer) Type() uint8 {
+func (NDPRecursiveDNSServer) Type() NDPOptionIdentifier {
return NDPRecursiveDNSServerOptionType
}
@@ -613,7 +618,12 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
// String implements fmt.Stringer.String.
func (o NDPRecursiveDNSServer) String() string {
- return fmt.Sprintf("%T(%s valid for %s)", o, o.Addresses(), o.Lifetime())
+ lt := o.Lifetime()
+ addrs, err := o.Addresses()
+ if err != nil {
+ return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err)
+ }
+ return fmt.Sprintf("%T(%s valid for %s)", o, addrs, lt)
}
// Lifetime returns the length of time that the DNS server addresses
@@ -632,29 +642,45 @@ func (o NDPRecursiveDNSServer) Lifetime() time.Duration {
// Addresses returns the recursive DNS server IPv6 addresses that may be
// used for name resolution.
//
-// Note, some of the addresses returned MAY be link-local addresses.
+// Note, the addresses MAY be link-local addresses.
+func (o NDPRecursiveDNSServer) Addresses() ([]tcpip.Address, error) {
+ var addrs []tcpip.Address
+ return addrs, o.iterAddresses(func(addr tcpip.Address) { addrs = append(addrs, addr) })
+}
+
+// checkAddresses iterates over the addresses in an NDP Recursive DNS Server
+// option and returns any error it encounters.
+func (o NDPRecursiveDNSServer) checkAddresses() error {
+ return o.iterAddresses(nil)
+}
+
+// iterAddresses iterates over the addresses in an NDP Recursive DNS Server
+// option and calls a function with each valid unicast IPv6 address.
//
-// Addresses may panic if o does not hold valid IPv6 addresses.
-func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address {
- l := len(o)
- if l < ndpRecursiveDNSServerAddressesOffset {
- return nil
+// Note, the addresses MAY be link-local addresses.
+func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error {
+ if l := len(o); l < minNDPRecursiveDNSServerBodySize {
+ return fmt.Errorf("got %d bytes for NDP Recursive DNS Server option's body, expected at least %d bytes: %w", l, minNDPRecursiveDNSServerBodySize, io.ErrUnexpectedEOF)
}
- l -= ndpRecursiveDNSServerAddressesOffset
+ o = o[ndpRecursiveDNSServerAddressesOffset:]
+ l := len(o)
if l%IPv6AddressSize != 0 {
- return nil
+ return fmt.Errorf("NDP Recursive DNS Server option's body ends in the middle of an IPv6 address (addresses body size = %d bytes): %w", l, ErrNDPOptMalformedBody)
}
- buf := o[ndpRecursiveDNSServerAddressesOffset:]
- var addrs []tcpip.Address
- for len(buf) > 0 {
- addr := tcpip.Address(buf[:IPv6AddressSize])
+ for i := 0; len(o) != 0; i++ {
+ addr := tcpip.Address(o[:IPv6AddressSize])
if !IsV6UnicastAddress(addr) {
- return nil
+ return fmt.Errorf("%d-th address (%s) in NDP Recursive DNS Server option is not a valid unicast IPv6 address: %w", i, addr, ErrNDPOptMalformedBody)
+ }
+
+ if fn != nil {
+ fn(addr)
}
- addrs = append(addrs, addr)
- buf = buf[IPv6AddressSize:]
+
+ o = o[IPv6AddressSize:]
}
- return addrs
+
+ return nil
}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 1cb9f5dc8..2341329c4 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -16,6 +16,8 @@ package header
import (
"bytes"
+ "errors"
+ "io"
"testing"
"time"
@@ -115,7 +117,7 @@ func TestNDPNeighborAdvert(t *testing.T) {
// Make sure flags got updated in the backing buffer.
if got := b[ndpNAFlagsOffset]; got != 64 {
- t.Errorf("got flags byte = %d, want = 64")
+ t.Errorf("got flags byte = %d, want = 64", got)
}
}
@@ -543,8 +545,12 @@ func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) {
want := []tcpip.Address{
"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
}
- if got := opt.Addresses(); !cmp.Equal(got, want) {
- t.Errorf("got Addresses = %v, want = %v", got, want)
+ addrs, err := opt.Addresses()
+ if err != nil {
+ t.Errorf("opt.Addresses() = %s", err)
+ }
+ if diff := cmp.Diff(addrs, want); diff != "" {
+ t.Errorf("mismatched addresses (-want +got):\n%s", diff)
}
// Iterator should not return anything else.
@@ -638,8 +644,12 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) {
if got := opt.Lifetime(); got != test.lifetime {
t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
}
- if got := opt.Addresses(); !cmp.Equal(got, test.addrs) {
- t.Errorf("got Addresses = %v, want = %v", got, test.addrs)
+ addrs, err := opt.Addresses()
+ if err != nil {
+ t.Errorf("opt.Addresses() = %s", err)
+ }
+ if diff := cmp.Diff(addrs, test.addrs); diff != "" {
+ t.Errorf("mismatched addresses (-want +got):\n%s", diff)
}
// Iterator should not return anything else.
@@ -661,38 +671,38 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) {
// the iterator was returned for is malformed.
func TestNDPOptionsIterCheck(t *testing.T) {
tests := []struct {
- name string
- buf []byte
- expected error
+ name string
+ buf []byte
+ expectedErr error
}{
{
- "ZeroLengthField",
- []byte{0, 0, 0, 0, 0, 0, 0, 0},
- ErrNDPOptZeroLength,
+ name: "ZeroLengthField",
+ buf: []byte{0, 0, 0, 0, 0, 0, 0, 0},
+ expectedErr: ErrNDPOptMalformedHeader,
},
{
- "ValidSourceLinkLayerAddressOption",
- []byte{1, 1, 1, 2, 3, 4, 5, 6},
- nil,
+ name: "ValidSourceLinkLayerAddressOption",
+ buf: []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ expectedErr: nil,
},
{
- "TooSmallSourceLinkLayerAddressOption",
- []byte{1, 1, 1, 2, 3, 4, 5},
- ErrNDPOptBufExhausted,
+ name: "TooSmallSourceLinkLayerAddressOption",
+ buf: []byte{1, 1, 1, 2, 3, 4, 5},
+ expectedErr: io.ErrUnexpectedEOF,
},
{
- "ValidTargetLinkLayerAddressOption",
- []byte{2, 1, 1, 2, 3, 4, 5, 6},
- nil,
+ name: "ValidTargetLinkLayerAddressOption",
+ buf: []byte{2, 1, 1, 2, 3, 4, 5, 6},
+ expectedErr: nil,
},
{
- "TooSmallTargetLinkLayerAddressOption",
- []byte{2, 1, 1, 2, 3, 4, 5},
- ErrNDPOptBufExhausted,
+ name: "TooSmallTargetLinkLayerAddressOption",
+ buf: []byte{2, 1, 1, 2, 3, 4, 5},
+ expectedErr: io.ErrUnexpectedEOF,
},
{
- "ValidPrefixInformation",
- []byte{
+ name: "ValidPrefixInformation",
+ buf: []byte{
3, 4, 43, 64,
1, 2, 3, 4,
5, 6, 7, 8,
@@ -702,11 +712,11 @@ func TestNDPOptionsIterCheck(t *testing.T) {
17, 18, 19, 20,
21, 22, 23, 24,
},
- nil,
+ expectedErr: nil,
},
{
- "TooSmallPrefixInformation",
- []byte{
+ name: "TooSmallPrefixInformation",
+ buf: []byte{
3, 4, 43, 64,
1, 2, 3, 4,
5, 6, 7, 8,
@@ -716,11 +726,11 @@ func TestNDPOptionsIterCheck(t *testing.T) {
17, 18, 19, 20,
21, 22, 23,
},
- ErrNDPOptBufExhausted,
+ expectedErr: io.ErrUnexpectedEOF,
},
{
- "InvalidPrefixInformationLength",
- []byte{
+ name: "InvalidPrefixInformationLength",
+ buf: []byte{
3, 3, 43, 64,
1, 2, 3, 4,
5, 6, 7, 8,
@@ -728,11 +738,11 @@ func TestNDPOptionsIterCheck(t *testing.T) {
9, 10, 11, 12,
13, 14, 15, 16,
},
- ErrNDPOptMalformedBody,
+ expectedErr: ErrNDPOptMalformedBody,
},
{
- "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation",
- []byte{
+ name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation",
+ buf: []byte{
// Source Link-Layer Address.
1, 1, 1, 2, 3, 4, 5, 6,
@@ -749,11 +759,11 @@ func TestNDPOptionsIterCheck(t *testing.T) {
17, 18, 19, 20,
21, 22, 23, 24,
},
- nil,
+ expectedErr: nil,
},
{
- "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
- []byte{
+ name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
+ buf: []byte{
// Source Link-Layer Address.
1, 1, 1, 2, 3, 4, 5, 6,
@@ -775,52 +785,52 @@ func TestNDPOptionsIterCheck(t *testing.T) {
17, 18, 19, 20,
21, 22, 23, 24,
},
- nil,
+ expectedErr: nil,
},
{
- "InvalidRecursiveDNSServerCutsOffAddress",
- []byte{
+ name: "InvalidRecursiveDNSServerCutsOffAddress",
+ buf: []byte{
25, 4, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 1, 2, 3, 4, 5, 6, 7,
},
- ErrNDPOptMalformedBody,
+ expectedErr: ErrNDPOptMalformedBody,
},
{
- "InvalidRecursiveDNSServerInvalidLengthField",
- []byte{
+ name: "InvalidRecursiveDNSServerInvalidLengthField",
+ buf: []byte{
25, 2, 0, 0,
0, 0, 0, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8,
},
- ErrNDPInvalidLength,
+ expectedErr: io.ErrUnexpectedEOF,
},
{
- "RecursiveDNSServerTooSmall",
- []byte{
+ name: "RecursiveDNSServerTooSmall",
+ buf: []byte{
25, 1, 0, 0,
0, 0, 0,
},
- ErrNDPOptBufExhausted,
+ expectedErr: io.ErrUnexpectedEOF,
},
{
- "RecursiveDNSServerMulticast",
- []byte{
+ name: "RecursiveDNSServerMulticast",
+ buf: []byte{
25, 3, 0, 0,
0, 0, 0, 0,
255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
},
- ErrNDPOptMalformedBody,
+ expectedErr: ErrNDPOptMalformedBody,
},
{
- "RecursiveDNSServerUnspecified",
- []byte{
+ name: "RecursiveDNSServerUnspecified",
+ buf: []byte{
25, 3, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
},
- ErrNDPOptMalformedBody,
+ expectedErr: ErrNDPOptMalformedBody,
},
}
@@ -828,8 +838,8 @@ func TestNDPOptionsIterCheck(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := NDPOptions(test.buf)
- if _, err := opts.Iter(true); err != test.expected {
- t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected)
+ if _, err := opts.Iter(true); !errors.Is(err, test.expectedErr) {
+ t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expectedErr)
}
// test.buf may be malformed but we chose not to check
diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go
new file mode 100644
index 000000000..6fe9a336b
--- /dev/null
+++ b/pkg/tcpip/header/ndpoptionidentifier_string.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT.
+
+package header
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[NDPSourceLinkLayerAddressOptionType-1]
+ _ = x[NDPTargetLinkLayerAddressOptionType-2]
+ _ = x[NDPPrefixInformationType-3]
+ _ = x[NDPRecursiveDNSServerOptionType-25]
+}
+
+const (
+ _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType"
+ _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType"
+)
+
+var (
+ _NDPOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94}
+)
+
+func (i NDPOptionIdentifier) String() string {
+ switch {
+ case 1 <= i && i <= 3:
+ i -= 1
+ return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]]
+ case i == 25:
+ return _NDPOptionIdentifier_name_1
+ default:
+ return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 82cfe785c..13480687d 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -81,7 +81,8 @@ type TCPFields struct {
// AckNum is the "acknowledgement number" field of a TCP packet.
AckNum uint32
- // DataOffset is the "data offset" field of a TCP packet.
+ // DataOffset is the "data offset" field of a TCP packet. It is the length of
+ // the TCP header in bytes.
DataOffset uint8
// Flags is the "flags" field of a TCP packet.
@@ -213,7 +214,8 @@ func (b TCP) AckNumber() uint32 {
return binary.BigEndian.Uint32(b[TCPAckNumOffset:])
}
-// DataOffset returns the "data offset" field of the tcp header.
+// DataOffset returns the "data offset" field of the tcp header. The return
+// value is the length of the TCP header in bytes.
func (b TCP) DataOffset() uint8 {
return (b[TCPDataOffset] >> 4) * 4
}
@@ -238,6 +240,11 @@ func (b TCP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[TCPChecksumOffset:])
}
+// UrgentPointer returns the "urgent pointer" field of the tcp header.
+func (b TCP) UrgentPointer() uint16 {
+ return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:])
+}
+
// SetSourcePort sets the "source port" field of the tcp header.
func (b TCP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port)
@@ -253,6 +260,37 @@ func (b TCP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum)
}
+// SetDataOffset sets the data offset field of the tcp header. headerLen should
+// be the length of the TCP header in bytes.
+func (b TCP) SetDataOffset(headerLen uint8) {
+ b[TCPDataOffset] = (headerLen / 4) << 4
+}
+
+// SetSequenceNumber sets the sequence number field of the tcp header.
+func (b TCP) SetSequenceNumber(seqNum uint32) {
+ binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum)
+}
+
+// SetAckNumber sets the ack number field of the tcp header.
+func (b TCP) SetAckNumber(ackNum uint32) {
+ binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum)
+}
+
+// SetFlags sets the flags field of the tcp header.
+func (b TCP) SetFlags(flags uint8) {
+ b[TCPFlagsOffset] = flags
+}
+
+// SetWindowSize sets the window size field of the tcp header.
+func (b TCP) SetWindowSize(rcvwnd uint16) {
+ binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
+}
+
+// SetUrgentPoiner sets the window size field of the tcp header.
+func (b TCP) SetUrgentPoiner(urgentPointer uint16) {
+ binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer)
+}
+
// CalculateChecksum calculates the checksum of the tcp segment.
// partialChecksum is the checksum of the network-layer pseudo-header
// and the checksum of the segment data.
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 74412c894..9339d637f 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -99,6 +99,11 @@ func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
+// SetLength sets the "length" field of the udp header.
+func (b UDP) SetLength(length uint16) {
+ binary.BigEndian.PutUint16(b[udpLength:], length)
+}
+
// CalculateChecksum calculates the checksum of the udp packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
deleted file mode 100644
index d1b73cfdf..000000000
--- a/pkg/tcpip/iptables/BUILD
+++ /dev/null
@@ -1,18 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "iptables",
- srcs = [
- "iptables.go",
- "targets.go",
- "types.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- ],
-)
diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go
deleted file mode 100644
index 81a2e39a2..000000000
--- a/pkg/tcpip/iptables/targets.go
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package iptables
-
-import (
- "gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-// AcceptTarget accepts packets.
-type AcceptTarget struct{}
-
-// Action implements Target.Action.
-func (AcceptTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
- return RuleAccept, 0
-}
-
-// DropTarget drops packets.
-type DropTarget struct{}
-
-// Action implements Target.Action.
-func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
- return RuleDrop, 0
-}
-
-// ErrorTarget logs an error and drops the packet. It represents a target that
-// should be unreachable.
-type ErrorTarget struct{}
-
-// Action implements Target.Action.
-func (ErrorTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
- log.Debugf("ErrorTarget triggered.")
- return RuleDrop, 0
-}
-
-// UserChainTarget marks a rule as the beginning of a user chain.
-type UserChainTarget struct {
- Name string
-}
-
-// Action implements Target.Action.
-func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) {
- panic("UserChainTarget should never be called.")
-}
-
-// ReturnTarget returns from the current chain. If the chain is a built-in, the
-// hook's underflow should be called.
-type ReturnTarget struct{}
-
-// Action implements Target.Action.
-func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) {
- return RuleReturn, 0
-}
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 5944ba190..9bf67686d 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -28,7 +28,7 @@ import (
// PacketInfo holds all the information about an outbound packet.
type PacketInfo struct {
- Pkt tcpip.PacketBuffer
+ Pkt *stack.PacketBuffer
Proto tcpip.NetworkProtocolNumber
GSO *stack.GSO
Route stack.Route
@@ -50,13 +50,11 @@ type NotificationHandle struct {
}
type queue struct {
+ // c is the outbound packet channel.
+ c chan PacketInfo
// mu protects fields below.
- mu sync.RWMutex
- // c is the outbound packet channel. Sending to c should hold mu.
- c chan PacketInfo
- numWrite int
- numRead int
- notify []*NotificationHandle
+ mu sync.RWMutex
+ notify []*NotificationHandle
}
func (q *queue) Close() {
@@ -64,11 +62,8 @@ func (q *queue) Close() {
}
func (q *queue) Read() (PacketInfo, bool) {
- q.mu.Lock()
- defer q.mu.Unlock()
select {
case p := <-q.c:
- q.numRead++
return p, true
default:
return PacketInfo{}, false
@@ -76,15 +71,8 @@ func (q *queue) Read() (PacketInfo, bool) {
}
func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
- // We have to receive from channel without holding the lock, since it can
- // block indefinitely. This will cause a window that numWrite - numRead
- // produces a larger number, but won't go to negative. numWrite >= numRead
- // still holds.
select {
case pkt := <-q.c:
- q.mu.Lock()
- defer q.mu.Unlock()
- q.numRead++
return pkt, true
case <-ctx.Done():
return PacketInfo{}, false
@@ -93,16 +81,12 @@ func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
func (q *queue) Write(p PacketInfo) bool {
wrote := false
-
- // It's important to make sure nobody can see numWrite until we increment it,
- // so numWrite >= numRead holds.
- q.mu.Lock()
select {
case q.c <- p:
wrote = true
- q.numWrite++
default:
}
+ q.mu.Lock()
notify := q.notify
q.mu.Unlock()
@@ -116,13 +100,7 @@ func (q *queue) Write(p PacketInfo) bool {
}
func (q *queue) Num() int {
- q.mu.RLock()
- defer q.mu.RUnlock()
- n := q.numWrite - q.numRead
- if n < 0 {
- panic("numWrite < numRead")
- }
- return n
+ return len(q.c)
}
func (q *queue) AddNotify(notify Notification) *NotificationHandle {
@@ -203,12 +181,12 @@ func (e *Endpoint) NumQueued() int {
}
// InjectInbound injects an inbound packet.
-func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) {
+func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) {
e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt)
}
@@ -251,13 +229,13 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
// Clone r then release its resource so we only get the relevant fields from
// stack.Route without holding a reference to a NIC's endpoint.
route := r.Clone()
route.Release()
p := PacketInfo{
- Pkt: pkt,
+ Pkt: &pkt,
Proto: protocol,
GSO: gso,
Route: route,
@@ -269,21 +247,15 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
// WritePackets stores outbound packets into the channel.
-func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
// Clone r then release its resource so we only get the relevant fields from
// stack.Route without holding a reference to a NIC's endpoint.
route := r.Clone()
route.Release()
- payloadView := pkts[0].Data.ToView()
n := 0
- for _, pkt := range pkts {
- off := pkt.DataOffset
- size := pkt.DataSize
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
p := PacketInfo{
- Pkt: tcpip.PacketBuffer{
- Header: pkt.Header,
- Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(),
- },
+ Pkt: pkt,
Proto: protocol,
GSO: gso,
Route: route,
@@ -301,7 +273,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := PacketInfo{
- Pkt: tcpip.PacketBuffer{Data: vv},
+ Pkt: &stack.PacketBuffer{Data: vv},
Proto: 0,
GSO: nil,
}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index b7f60178e..b857ce9d0 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -91,7 +91,7 @@ func (p PacketDispatchMode) String() string {
case PacketMMap:
return "PacketMMap"
default:
- return fmt.Sprintf("unknown packet dispatch mode %v", p)
+ return fmt.Sprintf("unknown packet dispatch mode '%d'", p)
}
}
@@ -386,7 +386,7 @@ const (
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
if e.hdrSize > 0 {
// Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
@@ -405,9 +405,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
eth.Encode(ethHdr)
}
+ fd := e.fds[pkt.Hash%uint32(len(e.fds))]
if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
vnetHdr := virtioNetHdr{}
- vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr)
if gso != nil {
vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
if gso.NeedsCsum {
@@ -428,130 +428,120 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
}
- return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView())
+ vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr)
+ return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView())
}
if pkt.Data.Size() == 0 {
- return rawfile.NonBlockingWrite(e.fds[0], pkt.Header.View())
+ return rawfile.NonBlockingWrite(fd, pkt.Header.View())
}
- return rawfile.NonBlockingWrite3(e.fds[0], pkt.Header.View(), pkt.Data.ToView(), nil)
+ return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil)
}
// WritePackets writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- var ethHdrBuf []byte
- // hdr + data
- iovLen := 2
- if e.hdrSize > 0 {
- // Add ethernet header if needed.
- ethHdrBuf = make([]byte, header.EthernetMinimumSize)
- eth := header.Ethernet(ethHdrBuf)
- ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
- Type: protocol,
- }
-
- // Preserve the src address if it's set in the route.
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
- } else {
- ethHdr.SrcAddr = e.addr
- }
- eth.Encode(ethHdr)
- iovLen++
- }
+//
+// NOTE: This API uses sendmmsg to batch packets. As a result the underlying FD
+// picked to write the packet out has to be the same for all packets in the
+// list. In other words all packets in the batch should belong to the same
+// flow.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ n := pkts.Len()
- n := len(pkts)
-
- views := pkts[0].Data.Views()
- /*
- * Each bondary in views can add one more iovec.
- *
- * payload | | | |
- * -----------------------------
- * packets | | | | | | |
- * -----------------------------
- * iovecs | | | | | | | | |
- */
- iovec := make([]syscall.Iovec, n*iovLen+len(views)-1)
mmsgHdrs := make([]rawfile.MMsgHdr, n)
+ i := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ var ethHdrBuf []byte
+ iovLen := 0
+ if e.hdrSize > 0 {
+ // Add ethernet header if needed.
+ ethHdrBuf = make([]byte, header.EthernetMinimumSize)
+ eth := header.Ethernet(ethHdrBuf)
+ ethHdr := &header.EthernetFields{
+ DstAddr: r.RemoteLinkAddress,
+ Type: protocol,
+ }
- iovecIdx := 0
- viewIdx := 0
- viewOff := 0
- off := 0
- nextOff := 0
- for i := range pkts {
- // TODO(b/134618279): Different packets may have different data
- // in the future. We should handle this.
- if !viewsEqual(pkts[i].Data.Views(), views) {
- panic("All packets in pkts should have the same Data.")
+ // Preserve the src address if it's set in the route.
+ if r.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = r.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+ iovLen++
}
- prevIovecIdx := iovecIdx
- mmsgHdr := &mmsgHdrs[i]
- mmsgHdr.Msg.Iov = &iovec[iovecIdx]
- packetSize := pkts[i].DataSize
- hdr := &pkts[i].Header
-
- off = pkts[i].DataOffset
- if off != nextOff {
- // We stop in a different point last time.
- size := packetSize
- viewIdx = 0
- viewOff = 0
- for size > 0 {
- if size >= len(views[viewIdx]) {
- viewIdx++
- viewOff = 0
- size -= len(views[viewIdx])
- } else {
- viewOff = size
- size = 0
+ var vnetHdrBuf []byte
+ vnetHdr := virtioNetHdr{}
+ if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if gso != nil {
+ vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
+ if gso.NeedsCsum {
+ vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
+ vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
+ vnetHdr.csumOffset = gso.CsumOffset
+ }
+ if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS {
+ switch gso.Type {
+ case stack.GSOTCPv4:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
+ case stack.GSOTCPv6:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ default:
+ panic(fmt.Sprintf("Unknown gso type: %v", gso.Type))
+ }
+ vnetHdr.gsoSize = gso.MSS
}
}
+ vnetHdrBuf = vnetHdrToByteSlice(&vnetHdr)
+ iovLen++
}
- nextOff = off + packetSize
+ iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views()))
+ mmsgHdr := &mmsgHdrs[i]
+ mmsgHdr.Msg.Iov = &iovecs[0]
+ iovecIdx := 0
+ if vnetHdrBuf != nil {
+ v := &iovecs[iovecIdx]
+ v.Base = &vnetHdrBuf[0]
+ v.Len = uint64(len(vnetHdrBuf))
+ iovecIdx++
+ }
if ethHdrBuf != nil {
- v := &iovec[iovecIdx]
+ v := &iovecs[iovecIdx]
v.Base = &ethHdrBuf[0]
v.Len = uint64(len(ethHdrBuf))
iovecIdx++
}
-
- v := &iovec[iovecIdx]
+ pktSize := uint64(0)
+ // Encode L3 Header
+ v := &iovecs[iovecIdx]
+ hdr := &pkt.Header
hdrView := hdr.View()
v.Base = &hdrView[0]
v.Len = uint64(len(hdrView))
+ pktSize += v.Len
iovecIdx++
- for packetSize > 0 {
- vec := &iovec[iovecIdx]
+ // Now encode the Transport Payload.
+ pktViews := pkt.Data.Views()
+ for i := range pktViews {
+ vec := &iovecs[iovecIdx]
iovecIdx++
-
- v := views[viewIdx]
- vec.Base = &v[viewOff]
- s := len(v) - viewOff
- if s <= packetSize {
- viewIdx++
- viewOff = 0
- } else {
- s = packetSize
- viewOff += s
- }
- vec.Len = uint64(s)
- packetSize -= s
+ vec.Base = &pktViews[i][0]
+ vec.Len = uint64(len(pktViews[i]))
+ pktSize += vec.Len
}
-
- mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx)
+ mmsgHdr.Msg.Iovlen = uint64(iovecIdx)
+ i++
}
packets := 0
for packets < n {
- sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs)
+ fd := e.fds[pkts.Front().Hash%uint32(len(e.fds))]
+ sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs)
if err != nil {
return packets, err
}
@@ -610,7 +600,7 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
}
// InjectInbound injects an inbound packet.
-func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt)
}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 2066987eb..3bfb15a8e 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -45,40 +45,46 @@ const (
type packetInfo struct {
raddr tcpip.LinkAddress
proto tcpip.NetworkProtocolNumber
- contents tcpip.PacketBuffer
+ contents stack.PacketBuffer
}
type context struct {
- t *testing.T
- fds [2]int
- ep stack.LinkEndpoint
- ch chan packetInfo
- done chan struct{}
+ t *testing.T
+ readFDs []int
+ writeFDs []int
+ ep stack.LinkEndpoint
+ ch chan packetInfo
+ done chan struct{}
}
func newContext(t *testing.T, opt *Options) *context {
- fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
+ firstFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
+ if err != nil {
+ t.Fatalf("Socketpair failed: %v", err)
+ }
+ secondFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
if err != nil {
t.Fatalf("Socketpair failed: %v", err)
}
- done := make(chan struct{}, 1)
+ done := make(chan struct{}, 2)
opt.ClosedFunc = func(*tcpip.Error) {
done <- struct{}{}
}
- opt.FDs = []int{fds[1]}
+ opt.FDs = []int{firstFDPair[1], secondFDPair[1]}
ep, err := New(opt)
if err != nil {
t.Fatalf("Failed to create FD endpoint: %v", err)
}
c := &context{
- t: t,
- fds: fds,
- ep: ep,
- ch: make(chan packetInfo, 100),
- done: done,
+ t: t,
+ readFDs: []int{firstFDPair[0], secondFDPair[0]},
+ writeFDs: opt.FDs,
+ ep: ep,
+ ch: make(chan packetInfo, 100),
+ done: done,
}
ep.Attach(c)
@@ -87,12 +93,17 @@ func newContext(t *testing.T, opt *Options) *context {
}
func (c *context) cleanup() {
- syscall.Close(c.fds[0])
+ for _, fd := range c.readFDs {
+ syscall.Close(fd)
+ }
+ <-c.done
<-c.done
- syscall.Close(c.fds[1])
+ for _, fd := range c.writeFDs {
+ syscall.Close(fd)
+ }
}
-func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
c.ch <- packetInfo{remote, protocol, pkt}
}
@@ -136,7 +147,7 @@ func TestAddress(t *testing.T) {
}
}
-func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) {
+func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash uint32) {
c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize})
defer c.cleanup()
@@ -168,16 +179,18 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) {
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
+ Hash: hash,
}); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
- // Read from fd, then compare with what we wrote.
+ // Read from the corresponding FD, then compare with what we wrote.
b = make([]byte, mtu)
- n, err := syscall.Read(c.fds[0], b)
+ fd := c.readFDs[hash%uint32(len(c.readFDs))]
+ n, err := syscall.Read(fd, b)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
@@ -238,7 +251,7 @@ func TestWritePacket(t *testing.T) {
t.Run(
fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso),
func(t *testing.T) {
- testWritePacket(t, plen, eth, gso)
+ testWritePacket(t, plen, eth, gso, 0)
},
)
}
@@ -246,6 +259,27 @@ func TestWritePacket(t *testing.T) {
}
}
+func TestHashedWritePacket(t *testing.T) {
+ lengths := []int{0, 100, 1000}
+ eths := []bool{true, false}
+ gsos := []uint32{0, 32768}
+ hashes := []uint32{0, 1}
+ for _, eth := range eths {
+ for _, plen := range lengths {
+ for _, gso := range gsos {
+ for _, hash := range hashes {
+ t.Run(
+ fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v,Hash=%d", eth, plen, gso, hash),
+ func(t *testing.T) {
+ testWritePacket(t, plen, eth, gso, hash)
+ },
+ )
+ }
+ }
+ }
+ }
+}
+
func TestPreserveSrcAddress(t *testing.T) {
baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99")
@@ -261,7 +295,7 @@ func TestPreserveSrcAddress(t *testing.T) {
// WritePacket panics given a prependable with anything less than
// the minimum size of the ethernet header.
hdr := buffer.NewPrependable(header.EthernetMinimumSize)
- if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(r, nil /* gso */, proto, stack.PacketBuffer{
Header: hdr,
Data: buffer.VectorisedView{},
}); err != nil {
@@ -270,7 +304,7 @@ func TestPreserveSrcAddress(t *testing.T) {
// Read from the FD, then compare with what we wrote.
b := make([]byte, mtu)
- n, err := syscall.Read(c.fds[0], b)
+ n, err := syscall.Read(c.readFDs[0], b)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
@@ -314,7 +348,7 @@ func TestDeliverPacket(t *testing.T) {
}
// Write packet via the file descriptor.
- if _, err := syscall.Write(c.fds[0], all); err != nil {
+ if _, err := syscall.Write(c.readFDs[0], all); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -324,7 +358,7 @@ func TestDeliverPacket(t *testing.T) {
want := packetInfo{
raddr: raddr,
proto: proto,
- contents: tcpip.PacketBuffer{
+ contents: stack.PacketBuffer{
Data: buffer.View(b).ToVectorisedView(),
LinkHeader: buffer.View(hdr),
},
diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
index 97a477b61..d81858353 100644
--- a/pkg/tcpip/link/fdbased/endpoint_unsafe.go
+++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
@@ -24,9 +24,10 @@ import (
const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{}))
func vnetHdrToByteSlice(hdr *virtioNetHdr) (slice []byte) {
- sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
- sh.Data = uintptr(unsafe.Pointer(hdr))
- sh.Len = virtioNetHdrSize
- sh.Cap = virtioNetHdrSize
+ *(*reflect.SliceHeader)(unsafe.Pointer(&slice)) = reflect.SliceHeader{
+ Data: uintptr((unsafe.Pointer(hdr))),
+ Len: virtioNetHdrSize,
+ Cap: virtioNetHdrSize,
+ }
return
}
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 62ed1e569..fe2bf3b0b 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
@@ -190,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
}
pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, tcpip.PacketBuffer{
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, stack.PacketBuffer{
Data: buffer.View(pkt).ToVectorisedView(),
LinkHeader: buffer.View(eth),
})
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index c67d684ce..cb4cbea69 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(n, BufConfig)
- pkt := tcpip.PacketBuffer{
+ pkt := stack.PacketBuffer{
Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
LinkHeader: buffer.View(eth),
}
@@ -296,7 +296,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(k, int(n), BufConfig)
- pkt := tcpip.PacketBuffer{
+ pkt := stack.PacketBuffer{
Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
LinkHeader: buffer.View(eth),
}
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 499cc608f..1e2255bfa 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,7 +76,7 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
@@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw
// Because we're immediately turning around and writing the packet back
// to the rx path, we intentionally don't preserve the remote and local
// link addresses from the stack.Route we're passed.
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
@@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
@@ -106,7 +106,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
// There should be an ethernet header at the beginning of vv.
linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize])
vv.TrimFront(len(linkHeader))
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{
Data: vv,
LinkHeader: buffer.View(linkHeader),
})
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 445b22c17..a5478ce17 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -80,14 +80,14 @@ func (m *InjectableEndpoint) IsAttached() bool {
}
// InjectInbound implements stack.InjectableLinkEndpoint.
-func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt)
}
// WritePackets writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if
// r.RemoteAddress has a route registered in this endpoint.
-func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
endpoint, ok := m.routes[r.RemoteAddress]
if !ok {
return 0, tcpip.ErrNoRoute
@@ -98,7 +98,7 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts [
// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint
// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
// route registered in this endpoint.
-func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
if endpoint, ok := m.routes[r.RemoteAddress]; ok {
return endpoint.WritePacket(r, gso, protocol, pkt)
}
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 63b249837..87c734c1f 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) {
hdr.Prepend(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
})
@@ -70,7 +70,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
hdr := buffer.NewPrependable(1)
hdr.Prepend(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buffer.NewView(0).ToVectorisedView(),
})
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 655e537c4..0796d717e 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -185,7 +185,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
// Add the ethernet header here.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
@@ -214,7 +214,7 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.Netw
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
@@ -275,7 +275,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
// Send packet up the stack.
eth := header.Ethernet(b[:header.EthernetMinimumSize])
- d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{
+ d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{
Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(),
LinkHeader: buffer.View(eth),
})
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 5c729a439..27ea3f531 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
return c
}
-func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
c.mu.Lock()
c.packets = append(c.packets, packetInfo{
addr: remoteLinkAddr,
@@ -273,7 +273,7 @@ func TestSimpleSend(t *testing.T) {
randomFill(buf)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -345,7 +345,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
hdr := buffer.NewPrependable(header.EthernetMinimumSize)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{
Header: hdr,
}); err != nil {
t.Fatalf("WritePacket failed: %v", err)
@@ -401,7 +401,7 @@ func TestFillTxQueue(t *testing.T) {
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -419,7 +419,7 @@ func TestFillTxQueue(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != want {
@@ -447,7 +447,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Send two packets so that the id slice has at least two slots.
for i := 2; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -470,7 +470,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
ids := make(map[uint64]struct{})
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -488,7 +488,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != want {
@@ -514,7 +514,7 @@ func TestFillTxMemory(t *testing.T) {
ids := make(map[uint64]struct{})
for i := queueDataSize / bufferSize; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -533,7 +533,7 @@ func TestFillTxMemory(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
})
@@ -561,7 +561,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// until there is only one buffer left.
for i := queueDataSize/bufferSize - 1; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
@@ -577,7 +577,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
{
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
uu := buffer.NewView(bufferSize).ToVectorisedView()
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: uu,
}); err != want {
@@ -588,7 +588,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// Attempt to write the one-buffer packet again. It must succeed.
{
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{
Header: hdr,
Data: buf.ToVectorisedView(),
}); err != nil {
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 3392b7edd..be2537a82 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -21,11 +21,9 @@
package sniffer
import (
- "bytes"
"encoding/binary"
"fmt"
"io"
- "os"
"sync/atomic"
"time"
@@ -42,12 +40,12 @@ import (
// LogPackets must be accessed atomically.
var LogPackets uint32 = 1
-// LogPacketsToFile is a flag used to enable or disable logging packets to a
-// pcap file. Valid values are 0 or 1. A file must have been specified when the
+// LogPacketsToPCAP is a flag used to enable or disable logging packets to a
+// pcap writer. Valid values are 0 or 1. A writer must have been specified when the
// sniffer was created for this flag to have effect.
//
-// LogPacketsToFile must be accessed atomically.
-var LogPacketsToFile uint32 = 1
+// LogPacketsToPCAP must be accessed atomically.
+var LogPacketsToPCAP uint32 = 1
var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.TransportProtocolNumber]int{
header.ICMPv4ProtocolNumber: header.IPv4MinimumSize,
@@ -59,7 +57,7 @@ var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.
type endpoint struct {
dispatcher stack.NetworkDispatcher
lower stack.LinkEndpoint
- file *os.File
+ writer io.Writer
maxPCAPLen uint32
}
@@ -99,23 +97,22 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error {
})
}
-// NewWithFile creates a new sniffer link-layer endpoint. It wraps around
-// another endpoint and logs packets and they traverse the endpoint.
+// NewWithWriter creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets as they traverse the endpoint.
//
-// Packets can be logged to file in the pcap format. A sniffer created
-// with this function will not emit packets using the standard log
-// package.
+// Packets are logged to writer in the pcap format. A sniffer created with this
+// function will not emit packets using the standard log package.
//
// snapLen is the maximum amount of a packet to be saved. Packets with a length
-// less than or equal too snapLen will be saved in their entirety. Longer
+// less than or equal to snapLen will be saved in their entirety. Longer
// packets will be truncated to snapLen.
-func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) {
- if err := writePCAPHeader(file, snapLen); err != nil {
+func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (stack.LinkEndpoint, error) {
+ if err := writePCAPHeader(writer, snapLen); err != nil {
return nil, err
}
return &endpoint{
lower: lower,
- file: file,
+ writer: writer,
maxPCAPLen: snapLen,
}, nil
}
@@ -123,37 +120,8 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
-func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
- if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("recv", protocol, pkt.Data.First(), nil)
- }
- if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
- vs := pkt.Data.Views()
- length := pkt.Data.Size()
- if length > int(e.maxPCAPLen) {
- length = int(e.maxPCAPLen)
- }
-
- buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
- if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(pkt.Data.Size()))); err != nil {
- panic(err)
- }
- for _, v := range vs {
- if length == 0 {
- break
- }
- if len(v) > length {
- v = v[:length]
- }
- if _, err := buf.Write([]byte(v)); err != nil {
- panic(err)
- }
- length -= len(v)
- }
- if _, err := e.file.Write(buf.Bytes()); err != nil {
- panic(err)
- }
- }
+func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
+ e.dumpPacket("recv", nil, protocol, &pkt)
e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt)
}
@@ -200,31 +168,43 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
- if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("send", protocol, pkt.Header.View(), gso)
+func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ writer := e.writer
+ if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
+ first := pkt.Header.View()
+ if len(first) == 0 {
+ first = pkt.Data.First()
+ }
+ logPacket(prefix, protocol, first, gso)
}
- if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
- hdrBuf := pkt.Header.View()
- length := len(hdrBuf) + pkt.Data.Size()
- if length > int(e.maxPCAPLen) {
- length = int(e.maxPCAPLen)
+ if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
+ totalLength := pkt.Header.UsedLength() + pkt.Data.Size()
+ length := totalLength
+ if max := int(e.maxPCAPLen); length > max {
+ length = max
}
-
- buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
- if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+pkt.Data.Size()))); err != nil {
+ if err := binary.Write(writer, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(totalLength))); err != nil {
panic(err)
}
- if len(hdrBuf) > length {
- hdrBuf = hdrBuf[:length]
- }
- if _, err := buf.Write(hdrBuf); err != nil {
- panic(err)
+ write := func(b []byte) {
+ if len(b) > length {
+ b = b[:length]
+ }
+ for len(b) != 0 {
+ n, err := writer.Write(b)
+ if err != nil {
+ panic(err)
+ }
+ b = b[n:]
+ length -= n
+ }
}
- length -= len(hdrBuf)
- logVectorisedView(pkt.Data, length, buf)
- if _, err := e.file.Write(buf.Bytes()); err != nil {
- panic(err)
+ write(pkt.Header.View())
+ for _, view := range pkt.Data.Views() {
+ if length == 0 {
+ break
+ }
+ write(view)
}
}
}
@@ -232,69 +212,31 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumb
// WritePacket implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
- e.dumpPacket(gso, protocol, pkt)
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+ e.dumpPacket("send", gso, protocol, &pkt)
return e.lower.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- view := pkts[0].Data.ToView()
- for _, pkt := range pkts {
- e.dumpPacket(gso, protocol, tcpip.PacketBuffer{
- Header: pkt.Header,
- Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(),
- })
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.dumpPacket("send", gso, protocol, pkt)
}
return e.lower.WritePackets(r, gso, pkts, protocol)
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */)
- }
- if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
- length := vv.Size()
- if length > int(e.maxPCAPLen) {
- length = int(e.maxPCAPLen)
- }
-
- buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
- if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil {
- panic(err)
- }
- logVectorisedView(vv, length, buf)
- if _, err := e.file.Write(buf.Bytes()); err != nil {
- panic(err)
- }
- }
+ e.dumpPacket("send", nil, 0, &stack.PacketBuffer{
+ Data: vv,
+ })
return e.lower.WriteRawPacket(vv)
}
-func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) {
- if length <= 0 {
- return
- }
- for _, v := range vv.Views() {
- if len(v) > length {
- v = v[:length]
- }
- n, err := buf.Write(v)
- if err != nil {
- panic(err)
- }
- length -= n
- if length == 0 {
- return
- }
- }
-}
-
// Wait implements stack.LinkEndpoint.Wait.
-func (*endpoint) Wait() {}
+func (e *endpoint) Wait() { e.lower.Wait() }
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
// Figure out the network layer info.
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 6ff47a742..617446ea2 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -98,7 +98,12 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
prefix = "tap"
}
- endpoint, err := attachOrCreateNIC(s, name, prefix)
+ linkCaps := stack.CapabilityNone
+ if isTap {
+ linkCaps |= stack.CapabilityResolutionRequired
+ }
+
+ endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
return syserror.EINVAL
}
@@ -109,7 +114,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
return nil
}
-func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) {
+func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
for {
// 1. Try to attach to an existing NIC.
if name != "" {
@@ -135,6 +140,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error
nicID: id,
name: name,
}
+ endpoint.Endpoint.LinkEPCapabilities = linkCaps
if endpoint.name == "" {
endpoint.name = fmt.Sprintf("%s%d", prefix, id)
}
@@ -207,7 +213,7 @@ func (d *Device) Write(data []byte) (int64, error) {
remote = tcpip.LinkAddress(zeroMAC[:])
}
- pkt := tcpip.PacketBuffer{
+ pkt := stack.PacketBuffer{
Data: buffer.View(data).ToVectorisedView(),
}
if ethHdr != nil {
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index a8de38979..2b3741276 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -50,7 +50,7 @@ func New(lower stack.LinkEndpoint) *Endpoint {
// It is called by the link-layer endpoint being wrapped when a packet arrives,
// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
// been called.
-func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
if !e.dispatchGate.Enter() {
return
}
@@ -99,7 +99,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
@@ -112,9 +112,9 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
if !e.writeGate.Enter() {
- return len(pkts), nil
+ return pkts.Len(), nil
}
n, err := e.lower.WritePackets(r, gso, pkts, protocol)
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 31b11a27a..54eb5322b 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -35,7 +35,7 @@ type countedEndpoint struct {
dispatcher stack.NetworkDispatcher
}
-func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
e.dispatchCount++
}
@@ -65,15 +65,15 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
e.writeCount++
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- e.writeCount += len(pkts)
- return len(pkts), nil
+func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ e.writeCount += pkts.Len()
+ return pkts.Len(), nil
}
func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
@@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) {
wep := New(ep)
// Write and check that it goes through.
- wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
if want := 1; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on dispatches, then try to write. It must go through.
wep.WaitDispatch()
- wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on writes, then try to write. It must not go through.
wep.WaitWrite()
- wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{})
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
@@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) {
}
// Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
if want := 1; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on writes, then try to dispatch. It must go through.
wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on dispatches, then try to dispatch. It must not go through.
wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index e9fcc89a8..7acbfa0a8 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -79,20 +79,20 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
v := pkt.Data.First()
h := header.ARP(v)
if !h.IsValid() {
@@ -113,7 +113,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
Header: hdr,
})
fallthrough // also fill the cache from requests
@@ -167,7 +167,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
copy(h.ProtocolAddressSender(), localAddr)
copy(h.ProtocolAddressTarget(), addr)
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
Header: hdr,
})
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 03cf03b6d..1646d9cde 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -103,7 +103,7 @@ func TestDirectRequest(t *testing.T) {
inject := func(addr tcpip.Address) {
copy(h.ProtocolAddressTarget(), addr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{
Data: v.ToVectorisedView(),
})
}
@@ -138,7 +138,8 @@ func TestDirectRequest(t *testing.T) {
// Sleep tests are gross, but this will only potentially flake
// if there's a bug. If there is no bug this will reliably
// succeed.
- ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
if pkt, ok := c.linkEP.ReadContext(ctx); ok {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 92f2aa13a..f42abc4bb 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -115,10 +115,12 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
// Evict reassemblers if we are consuming more memory than highLimit until
// we reach lowLimit.
if f.size > f.highLimit {
- tail := f.rList.Back()
- for f.size > f.lowLimit && tail != nil {
+ for f.size > f.lowLimit {
+ tail := f.rList.Back()
+ if tail == nil {
+ break
+ }
f.release(tail)
- tail = tail.Prev()
}
}
f.mu.Unlock()
diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go
index 6a215938b..8f65713c5 100644
--- a/pkg/tcpip/network/hash/hash.go
+++ b/pkg/tcpip/network/hash/hash.go
@@ -80,12 +80,12 @@ func IPv4FragmentHash(h header.IPv4) uint32 {
// RFC 2640 (sec 4.5) is not very sharp on this aspect.
// As a reference, also Linux ignores the protocol to compute
// the hash (inet6_hash_frag).
-func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 {
+func IPv6FragmentHash(h header.IPv6, id uint32) uint32 {
t := h.SourceAddress()
y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = h.DestinationAddress()
z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
- return Hash3Words(f.ID(), y, z, hashIV)
+ return Hash3Words(id, y, z, hashIV)
}
func rol32(v, shift uint32) uint32 {
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index f4d78f8c6..4c20301c6 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -96,7 +96,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) {
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) {
t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
}
@@ -104,7 +104,7 @@ func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.Trans
// DeliverTransportControlPacket is called by network endpoints after parsing
// incoming control (ICMP) packets. This is used by the test object to verify
// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
t.checkValues(trans, pkt.Data, remote, local)
if typ != t.typ {
t.t.Errorf("typ = %v, want %v", typ, t.typ)
@@ -150,7 +150,7 @@ func (*testObject) Wait() {}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
@@ -172,7 +172,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
@@ -246,7 +246,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
@@ -289,7 +289,7 @@ func TestIPv4Receive(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: view.ToVectorisedView(),
})
if o.dataCalls != 1 {
@@ -379,7 +379,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
o.extra = c.expectedExtra
vv := view[:len(view)-c.trunc].ToVectorisedView()
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: vv,
})
if want := c.expectedCount; o.controlCalls != want {
@@ -444,7 +444,7 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Send first segment.
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: frag1.ToVectorisedView(),
})
if o.dataCalls != 0 {
@@ -452,7 +452,7 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Send second segment.
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: frag2.ToVectorisedView(),
})
if o.dataCalls != 1 {
@@ -487,7 +487,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
}); err != nil {
@@ -530,7 +530,7 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: view.ToVectorisedView(),
})
if o.dataCalls != 1 {
@@ -644,7 +644,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: view[:len(view)-c.trunc].ToVectorisedView(),
})
if want := c.expectedCount; o.controlCalls != want {
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 0fef2b1f1..880ea7de2 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -13,7 +13,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 32bf39e43..c4bf1ba5c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,7 +15,6 @@
package ipv4
import (
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -25,7 +24,7 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
h := header.IPv4(pkt.Data.First())
// We don't use IsValid() here because ICMP only requires that the IP
@@ -53,7 +52,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) {
+func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
v := pkt.Data.First()
@@ -85,7 +84,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) {
// It's possible that a raw socket expects to receive this.
h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{
Data: pkt.Data.Clone(nil),
NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...),
})
@@ -99,7 +98,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) {
pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: vv,
TransportHeader: buffer.View(pkt),
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 4f1742938..104aafbed 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -125,7 +124,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
// packet's stated length matches the length of the header+payload. mtu
// includes the IP header and options. This does not support the DontFragment
// IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error {
// This packet is too big, it needs to be fragmented.
ip := header.IPv4(pkt.Header.View())
flags := ip.Flags()
@@ -165,7 +164,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
if i > 0 {
newPayload := pkt.Data.Clone(nil)
newPayload.CapLength(innerMTU)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
Header: pkt.Header,
Data: newPayload,
NetworkHeader: buffer.View(h),
@@ -184,7 +183,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
newPayload := pkt.Data.Clone(nil)
newPayloadLength := outerMTU - pkt.Header.UsedLength()
newPayload.CapLength(newPayloadLength)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
Header: pkt.Header,
Data: newPayload,
NetworkHeader: buffer.View(h),
@@ -198,7 +197,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
startOfHdr := pkt.Header
startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
Header: startOfHdr,
Data: emptyVV,
NetworkHeader: buffer.View(h),
@@ -241,10 +240,18 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Output, pkt); !ok {
+ // iptables is telling us to drop the packet.
+ return nil
+ }
+
if r.Loop&stack.PacketLoop != 0 {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
@@ -253,7 +260,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, tcpip.PacketBuffer{
+ e.HandlePacket(&loopedR, stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
@@ -273,26 +280,52 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
}
if r.Loop&stack.PacketOut == 0 {
- return len(pkts), nil
+ return pkts.Len(), nil
+ }
+
+ for pkt := pkts.Front(); pkt != nil; {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
+ pkt = pkt.Next()
}
- for i := range pkts {
- ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params)
- pkts[i].NetworkHeader = buffer.View(ip)
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ ipt := e.stack.IPTables()
+ dropped := ipt.CheckPackets(stack.Output, pkts)
+ if len(dropped) == 0 {
+ // Fast path: If no packets are to be dropped then we can just invoke the
+ // faster WritePackets API directly.
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+
+ // Slow Path as we are dropping some packets in the batch degrade to
+ // emitting one packet at a time.
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if _, ok := dropped[pkt]; ok {
+ continue
+ }
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil {
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+ n++
}
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ return n, nil
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
ip := header.IPv4(pkt.Data.First())
@@ -344,7 +377,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuf
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
headerView := pkt.Data.First()
h := header.IPv4(headerView)
if !h.IsValid(pkt.Data.Size()) {
@@ -361,7 +394,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
ipt := e.stack.IPTables()
- if ok := ipt.Check(iptables.Input, pkt); !ok {
+ if ok := ipt.Check(stack.Input, pkt); !ok {
// iptables is telling us to drop the packet.
return
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index e900f1b45..5a864d832 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -113,7 +113,7 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) {
+func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) {
t.Helper()
// Make a complete array of the sourcePacketInfo packet.
source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
@@ -173,7 +173,7 @@ func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketIn
type errorChannel struct {
*channel.Endpoint
- Ch chan tcpip.PacketBuffer
+ Ch chan stack.PacketBuffer
packetCollectorErrors []*tcpip.Error
}
@@ -183,7 +183,7 @@ type errorChannel struct {
func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
return &errorChannel{
Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan tcpip.PacketBuffer, size),
+ Ch: make(chan stack.PacketBuffer, size),
packetCollectorErrors: packetCollectorErrors,
}
}
@@ -202,7 +202,7 @@ func (e *errorChannel) Drain() int {
}
// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
select {
case e.Ch <- pkt:
default:
@@ -281,13 +281,13 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
- source := tcpip.PacketBuffer{
+ source := stack.PacketBuffer{
Header: hdr,
// Save the source payload because WritePacket will modify it.
Data: payload.Clone(nil),
}
c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload,
})
@@ -295,7 +295,7 @@ func TestFragmentation(t *testing.T) {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []tcpip.PacketBuffer
+ var results []stack.PacketBuffer
L:
for {
select {
@@ -337,7 +337,7 @@ func TestFragmentationErrors(t *testing.T) {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload,
})
@@ -459,7 +459,7 @@ func TestInvalidFragments(t *testing.T) {
s.CreateNIC(nicID, sniffer.New(ep))
for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{
+ ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
})
}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index fb11874c6..3f71fc520 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -13,6 +13,8 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
],
)
@@ -29,6 +31,7 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
@@ -36,5 +39,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
+ "@com_github_google_go-cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 45dc757c7..b68983d10 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,7 +15,7 @@
package ipv6
import (
- "log"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -27,7 +27,7 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
h := header.IPv6(pkt.Data.First())
// We don't use IsValid() here because ICMP only requires that up to
@@ -62,7 +62,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) {
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
@@ -79,32 +79,22 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
// Only the first view in vv is accounted for by h. To account for the
// rest of vv, a shallow copy is made and the first view is removed.
// This copy is used as extra payload during the checksum calculation.
- payload := pkt.Data
+ payload := pkt.Data.Clone(nil)
payload.RemoveFirst()
if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
received.Invalid.Increment()
return
}
- // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
- // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
- // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not
- // set to 0.
- switch h.Type() {
- case header.ICMPv6NeighborSolicit,
- header.ICMPv6NeighborAdvert,
- header.ICMPv6RouterSolicit,
- header.ICMPv6RouterAdvert,
- header.ICMPv6RedirectMsg:
- if iph.HopLimit() != header.NDPHopLimit {
- received.Invalid.Increment()
- return
- }
-
- if h.Code() != 0 {
- received.Invalid.Increment()
- return
- }
+ isNDPValid := func() bool {
+ // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
+ // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
+ // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not
+ // set to 0.
+ //
+ // As per RFC 6980 section 5, nodes MUST silently drop NDP messages if the
+ // packet includes a fragmentation header.
+ return !hasFragmentHeader && iph.HopLimit() == header.NDPHopLimit && h.Code() == 0
}
// TODO(b/112892170): Meaningfully handle all ICMP types.
@@ -133,7 +123,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
+ if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
received.Invalid.Increment()
return
}
@@ -148,58 +138,53 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
targetAddr := ns.TargetAddress()
s := r.Stack()
- rxNICID := r.NICID()
- if isTentative, err := s.IsAddrTentative(rxNICID, targetAddr); err != nil {
- // We will only get an error if rxNICID is unrecognized,
- // which should not happen. For now short-circuit this
- // packet.
+ if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ // We will only get an error if the NIC is unrecognized, which should not
+ // happen. For now, drop this packet.
//
// TODO(b/141002840): Handle this better?
return
} else if isTentative {
- // If the target address is tentative and the source
- // of the packet is a unicast (specified) address, then
- // the source of the packet is attempting to perform
- // address resolution on the target. In this case, the
- // solicitation is silently ignored, as per RFC 4862
- // section 5.4.3.
+ // If the target address is tentative and the source of the packet is a
+ // unicast (specified) address, then the source of the packet is
+ // attempting to perform address resolution on the target. In this case,
+ // the solicitation is silently ignored, as per RFC 4862 section 5.4.3.
//
- // If the target address is tentative and the source of
- // the packet is the unspecified address (::), then we
- // know another node is also performing DAD for the
- // same address (since targetAddr is tentative for us,
- // we know we are also performing DAD on it). In this
- // case we let the stack know so it can handle such a
- // scenario and do nothing further with the NDP NS.
- if iph.SourceAddress() == header.IPv6Any {
- s.DupTentativeAddrDetected(rxNICID, targetAddr)
+ // If the target address is tentative and the source of the packet is the
+ // unspecified address (::), then we know another node is also performing
+ // DAD for the same address (since the target address is tentative for us,
+ // we know we are also performing DAD on it). In this case we let the
+ // stack know so it can handle such a scenario and do nothing further with
+ // the NS.
+ if r.RemoteAddress == header.IPv6Any {
+ s.DupTentativeAddrDetected(e.nicID, targetAddr)
}
- // Do not handle neighbor solicitations targeted
- // to an address that is tentative on the received
- // NIC any further.
+ // Do not handle neighbor solicitations targeted to an address that is
+ // tentative on the NIC any further.
return
}
- // At this point we know that targetAddr is not tentative on
- // rxNICID so the packet is processed as defined in RFC 4861,
- // as per RFC 4862 section 5.4.3.
+ // At this point we know that the target address is not tentative on the NIC
+ // so the packet is processed as defined in RFC 4861, as per RFC 4862
+ // section 5.4.3.
+ // Is the NS targetting us?
if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
- // We don't have a useful answer; the best we can do is ignore the request.
return
}
- // If the NS message has the source link layer option, update the link
- // address cache with the link address for the sender of the message.
+ // If the NS message contains the Source Link-Layer Address option, update
+ // the link address cache with the value of the option.
//
// TODO(b/148429853): Properly process the NS message and do Neighbor
// Unreachability Detection.
+ var sourceLinkAddr tcpip.LinkAddress
for {
opt, done, err := it.Next()
if err != nil {
// This should never happen as Iter(true) above did not return an error.
- log.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
}
if done {
break
@@ -207,22 +192,36 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
switch opt := opt.(type) {
case header.NDPSourceLinkLayerAddressOption:
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, opt.EthernetAddress())
+ // No RFCs define what to do when an NS message has multiple Source
+ // Link-Layer Address options. Since no interface can have multiple
+ // link-layer addresses, we consider such messages invalid.
+ if len(sourceLinkAddr) != 0 {
+ received.Invalid.Increment()
+ return
+ }
+
+ sourceLinkAddr = opt.EthernetAddress()
}
}
- optsSerializer := header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]),
+ unspecifiedSource := r.RemoteAddress == header.IPv6Any
+
+ // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST
+ // NOT be included when the source IP address is the unspecified address.
+ // Otherwise, on link layers that have addresses this option MUST be
+ // included in multicast solicitations and SHOULD be included in unicast
+ // solicitations.
+ if len(sourceLinkAddr) == 0 {
+ if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource {
+ received.Invalid.Increment()
+ return
+ }
+ } else if unspecifiedSource {
+ received.Invalid.Increment()
+ return
+ } else {
+ e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
}
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
- packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
- packet.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(packet.NDPPayload())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(true)
- na.SetTargetAddress(targetAddr)
- opts := na.Options()
- opts.Serialize(optsSerializer)
// ICMPv6 Neighbor Solicit messages are always sent to
// specially crafted IPv6 multicast addresses. As a result, the
@@ -235,6 +234,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
r := r.Clone()
defer r.Release()
r.LocalAddress = targetAddr
+
+ // As per RFC 4861 section 7.2.4, if the the source of the solicitation is
+ // the unspecified address, the node MUST set the Solicited flag to zero and
+ // multicast the advertisement to the all-nodes address.
+ solicited := true
+ if unspecifiedSource {
+ solicited = false
+ r.RemoteAddress = header.IPv6AllNodesMulticastAddress
+ }
+
+ // If the NS has a source link-layer option, use the link address it
+ // specifies as the remote link address for the response instead of the
+ // source link address of the packet.
+ //
+ // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link
+ // address cache for the right destination link address instead of manually
+ // patching the route with the remote link address if one is specified in a
+ // Source Link-Layer Address option.
+ if len(sourceLinkAddr) != 0 {
+ r.RemoteLinkAddress = sourceLinkAddr
+ }
+
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress),
+ }
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
+ packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ packet.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na.SetSolicitedFlag(solicited)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(targetAddr)
+ opts := na.Options()
+ opts.Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
@@ -243,7 +276,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
//
// The IP Hop Limit field has a value of 255, i.e., the packet
// could not possibly have been forwarded by a router.
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
}); err != nil {
sent.Dropped.Increment()
@@ -253,7 +286,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if len(v) < header.ICMPv6NeighborAdvertSize {
+ if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
received.Invalid.Increment()
return
}
@@ -268,45 +301,43 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
targetAddr := na.TargetAddress()
stack := r.Stack()
- rxNICID := r.NICID()
- if isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr); err != nil {
- // We will only get an error if rxNICID is unrecognized,
- // which should not happen. For now short-circuit this
- // packet.
+ if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ // We will only get an error if the NIC is unrecognized, which should not
+ // happen. For now short-circuit this packet.
//
// TODO(b/141002840): Handle this better?
return
} else if isTentative {
- // We just got an NA from a node that owns an address we
- // are performing DAD on, implying the address is not
- // unique. In this case we let the stack know so it can
- // handle such a scenario and do nothing furthur with
+ // We just got an NA from a node that owns an address we are performing
+ // DAD on, implying the address is not unique. In this case we let the
+ // stack know so it can handle such a scenario and do nothing furthur with
// the NDP NA.
- stack.DupTentativeAddrDetected(rxNICID, targetAddr)
+ stack.DupTentativeAddrDetected(e.nicID, targetAddr)
return
}
- // At this point we know that the targetAddress is not tentative
- // on rxNICID. However, targetAddr may still be assigned to
- // rxNICID but not tentative (it could be permanent). Such a
- // scenario is beyond the scope of RFC 4862. As such, we simply
- // ignore such a scenario for now and proceed as normal.
+ // At this point we know that the target address is not tentative on the
+ // NIC. However, the target address may still be assigned to the NIC but not
+ // tentative (it could be permanent). Such a scenario is beyond the scope of
+ // RFC 4862. As such, we simply ignore such a scenario for now and proceed
+ // as normal.
//
+ // TODO(b/143147598): Handle the scenario described above. Also inform the
+ // netstack integration that a duplicate address was detected outside of
+ // DAD.
+
// If the NA message has the target link layer option, update the link
// address cache with the link address for the target of the message.
//
- // TODO(b/143147598): Handle the scenario described above. Also
- // inform the netstack integration that a duplicate address was
- // detected outside of DAD.
- //
// TODO(b/148429853): Properly process the NA message and do Neighbor
// Unreachability Detection.
+ var targetLinkAddr tcpip.LinkAddress
for {
opt, done, err := it.Next()
if err != nil {
// This should never happen as Iter(true) above did not return an error.
- log.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
}
if done {
break
@@ -314,10 +345,22 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
switch opt := opt.(type) {
case header.NDPTargetLinkLayerAddressOption:
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, opt.EthernetAddress())
+ // No RFCs define what to do when an NA message has multiple Target
+ // Link-Layer Address options. Since no interface can have multiple
+ // link-layer addresses, we consider such messages invalid.
+ if len(targetLinkAddr) != 0 {
+ received.Invalid.Increment()
+ return
+ }
+
+ targetLinkAddr = opt.EthernetAddress()
}
}
+ if len(targetLinkAddr) != 0 {
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ }
+
case header.ICMPv6EchoRequest:
received.EchoRequest.Increment()
if len(v) < header.ICMPv6EchoMinimumSize {
@@ -330,7 +373,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
copy(packet, h)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: pkt.Data,
}); err != nil {
@@ -355,8 +398,20 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
case header.ICMPv6RouterSolicit:
received.RouterSolicit.Increment()
+ if !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
case header.ICMPv6RouterAdvert:
+ received.RouterAdvert.Increment()
+
+ p := h.NDPPayload()
+ if len(p) < header.NDPRAMinimumSize || !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
routerAddr := iph.SourceAddress()
//
@@ -370,16 +425,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
return
}
- p := h.NDPPayload()
-
- // Is the NDP payload of sufficient size to hold a Router
- // Advertisement?
- if len(p) < header.NDPRAMinimumSize {
- // ...No, silently drop the packet.
- received.Invalid.Increment()
- return
- }
-
ra := header.NDPRouterAdvert(p)
opts := ra.Options()
@@ -395,8 +440,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
// as RFC 4861 section 6.1.2 is concerned.
//
- received.RouterAdvert.Increment()
-
// Tell the NIC to handle the RA.
stack := r.Stack()
rxNICID := r.NICID()
@@ -404,6 +447,10 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
case header.ICMPv6RedirectMsg:
received.RedirectMsg.Increment()
+ if !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
default:
received.Invalid.Increment()
@@ -463,7 +510,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
Header: hdr,
})
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 50c4b6474..bd099a7f8 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -32,7 +32,8 @@ import (
const (
linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
)
var (
@@ -56,7 +57,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error {
return nil
}
@@ -66,7 +67,7 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) {
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) {
}
type stubLinkAddressCache struct {
@@ -187,7 +188,7 @@ func TestICMPCounts(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, tcpip.PacketBuffer{
+ ep.HandlePacket(&r, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
}
@@ -326,7 +327,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
vv := buffer.NewVectorisedView(size, views)
- args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{
+ args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{
Data: vv,
})
}
@@ -561,7 +562,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
}
@@ -738,7 +739,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
}
@@ -916,7 +917,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
})
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 9aef5234b..331b0817b 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -21,11 +21,14 @@
package ipv6
import (
+ "fmt"
"sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -49,6 +52,7 @@ type endpoint struct {
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
+ fragmentation *fragmentation.Fragmentation
protocol *protocol
}
@@ -112,7 +116,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
@@ -124,7 +128,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, tcpip.PacketBuffer{
+ e.HandlePacket(&loopedR, stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
@@ -139,19 +143,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
}
if r.Loop&stack.PacketOut == 0 {
- return len(pkts), nil
+ return pkts.Len(), nil
}
- for i := range pkts {
- hdr := &pkts[i].Header
- size := pkts[i].DataSize
- ip := e.addIPHeader(r, hdr, size, params)
- pkts[i].NetworkHeader = buffer.View(ip)
+ for pb := pkts.Front(); pb != nil; pb = pb.Next() {
+ ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params)
+ pb.NetworkHeader = buffer.View(ip)
}
n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
@@ -161,17 +163,18 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
// TODO(b/146666412): Support IPv6 header-included packets.
return tcpip.ErrNotSupported
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
headerView := pkt.Data.First()
h := header.IPv6(headerView)
if !h.IsValid(pkt.Data.Size()) {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
@@ -179,14 +182,235 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
pkt.Data.TrimFront(header.IPv6MinimumSize)
pkt.Data.CapLength(int(h.PayloadLength()))
- p := h.TransportProtocol()
- if p == header.ICMPv6ProtocolNumber {
- e.handleICMP(r, headerView, pkt)
- return
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data)
+ hasFragmentHeader := false
+
+ for firstHeader := true; ; firstHeader = false {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6HopByHopOptionsExtHdr:
+ // As per RFC 8200 section 4.1, the Hop By Hop extension header is
+ // restricted to appear immediately after an IPv6 fixed header.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
+ // (unrecognized next header) error in response to an extension header's
+ // Next Header field with the Hop By Hop extension header identifier.
+ if !firstHeader {
+ return
+ }
+
+ optsIt := extHdr.Iter()
+
+ for {
+ opt, done, err := optsIt.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ // We currently do not support any IPv6 Hop By Hop extension header
+ // options.
+ switch opt.UnknownAction() {
+ case header.IPv6OptionUnknownActionSkip:
+ case header.IPv6OptionUnknownActionDiscard:
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ default:
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
+ }
+ }
+
+ case header.IPv6RoutingExtHdr:
+ // As per RFC 8200 section 4.4, if a node encounters a routing header with
+ // an unrecognized routing type value, with a non-zero Segments Left
+ // value, the node must discard the packet and send an ICMP Parameter
+ // Problem, Code 0. If the Segments Left is 0, the node must ignore the
+ // Routing extension header and process the next header in the packet.
+ //
+ // Note, the stack does not yet handle any type of routing extension
+ // header, so we just make sure Segments Left is zero before processing
+ // the next extension header.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
+ // unrecognized routing types with a non-zero Segments Left value.
+ if extHdr.SegmentsLeft() != 0 {
+ return
+ }
+
+ case header.IPv6FragmentExtHdr:
+ hasFragmentHeader = true
+
+ fragmentOffset := extHdr.FragmentOffset()
+ more := extHdr.More()
+ if !more && fragmentOffset == 0 {
+ // This fragment extension header indicates that this packet is an
+ // atomic fragment. An atomic fragment is a fragment that contains
+ // all the data required to reassemble a full packet. As per RFC 6946,
+ // atomic fragments must not interfere with "normal" fragmented traffic
+ // so we skip processing the fragment instead of feeding it through the
+ // reassembly process below.
+ continue
+ }
+
+ // Don't consume the iterator if we have the first fragment because we
+ // will use it to validate that the first fragment holds the upper layer
+ // header.
+ rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */)
+
+ if fragmentOffset == 0 {
+ // Check that the iterator ends with a raw payload as the first fragment
+ // should include all headers up to and including any upper layer
+ // headers, as per RFC 8200 section 4.5; only upper layer data
+ // (non-headers) should follow the fragment extension header.
+ var lastHdr header.IPv6PayloadHeader
+
+ for {
+ it, done, err := it.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ lastHdr = it
+ }
+
+ // If the last header is a raw header, then the last portion of the IPv6
+ // payload is not a known IPv6 extension header. Note, this does not
+ // mean that the last portion is an upper layer header or not an
+ // extension header because:
+ // 1) we do not yet support all extension headers
+ // 2) we do not validate the upper layer header before reassembling.
+ //
+ // This check makes sure that a known IPv6 extension header is not
+ // present after the Fragment extension header in a non-initial
+ // fragment.
+ //
+ // TODO(#2196): Support IPv6 Authentication and Encapsulated
+ // Security Payload extension headers.
+ // TODO(#2333): Validate that the upper layer header is valid.
+ switch lastHdr.(type) {
+ case header.IPv6RawPayloadHeader:
+ default:
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ }
+
+ fragmentPayloadLen := rawPayload.Buf.Size()
+ if fragmentPayloadLen == 0 {
+ // Drop the packet as it's marked as a fragment but has no payload.
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ // The packet is a fragment, let's try to reassemble it.
+ start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
+ last := start + uint16(fragmentPayloadLen) - 1
+
+ // Drop the packet if the fragmentOffset is incorrect. i.e the
+ // combination of fragmentOffset and pkt.Data.size() causes a
+ // wrap around resulting in last being less than the offset.
+ if last < start {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ var ready bool
+ pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf)
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ if ready {
+ // We create a new iterator with the reassembled packet because we could
+ // have more extension headers in the reassembled payload, as per RFC
+ // 8200 section 4.5.
+ it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data)
+ }
+
+ case header.IPv6DestinationOptionsExtHdr:
+ optsIt := extHdr.Iter()
+
+ for {
+ opt, done, err := optsIt.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ // We currently do not support any IPv6 Destination extension header
+ // options.
+ switch opt.UnknownAction() {
+ case header.IPv6OptionUnknownActionSkip:
+ case header.IPv6OptionUnknownActionDiscard:
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ default:
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
+ }
+ }
+
+ case header.IPv6RawPayloadHeader:
+ // If the last header in the payload isn't a known IPv6 extension header,
+ // handle it as if it is transport layer data.
+ pkt.Data = extHdr.Buf
+
+ if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
+ e.handleICMP(r, headerView, pkt, hasFragmentHeader)
+ } else {
+ r.Stats().IP.PacketsDelivered.Increment()
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
+ // in response to unrecognized next header values.
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
+ }
+
+ default:
+ // If we receive a packet for an extension header we do not yet handle,
+ // drop the packet for now.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
+ // in response to unrecognized next header values.
+ r.Stats().UnknownProtocolRcvdPackets.Increment()
+ return
+ }
}
-
- r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, pkt)
}
// Close cleans up resources associated with the endpoint.
@@ -229,6 +453,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
+ fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
protocol: p,
}, nil
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 1cbfa7278..841a0cb7a 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -17,6 +17,7 @@ package ipv6
import (
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -33,6 +34,15 @@ const (
// The least significant 3 bytes are the same as addr2 so both addr2 and
// addr3 will have the same solicited-node address.
addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
+ addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
+
+ // Tests use the extension header identifier values as uint8 instead of
+ // header.IPv6ExtensionHeaderIdentifier.
+ hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier)
+ routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier)
+ fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
+ destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
+ noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
)
// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
@@ -55,7 +65,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -113,7 +123,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -158,6 +168,8 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
// packets destined to the IPv6 solicited-node address of an assigned IPv6
// address.
func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
+ const nicID = 1
+
tests := []struct {
name string
protocolFactory stack.TransportProtocol
@@ -175,50 +187,61 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
})
- e := channel.New(10, 1280, linkAddr1)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ e := channel.New(1, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- // Should not receive a packet destined to the solicited
- // node address of addr2/addr3 yet as we haven't added
- // those addresses.
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // Should not receive a packet destined to the solicited node address of
+ // addr2/addr3 yet as we haven't added those addresses.
test.rxf(t, s, e, addr1, snmc, 0)
- if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
- // Should receive a packet destined to the solicited
- // node address of addr2/addr3 now that we have added
- // added addr2.
+ // Should receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have added added addr2.
test.rxf(t, s, e, addr1, snmc, 1)
- if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
}
- // Should still receive a packet destined to the
- // solicited node address of addr2/addr3 now that we
- // have added addr3.
+ // Should still receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have added addr3.
test.rxf(t, s, e, addr1, snmc, 2)
- if err := s.RemoveAddress(1, addr2); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err)
+ if err := s.RemoveAddress(nicID, addr2); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err)
}
- // Should still receive a packet destined to the
- // solicited node address of addr2/addr3 now that we
- // have removed addr2.
+ // Should still receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have removed addr2.
test.rxf(t, s, e, addr1, snmc, 3)
- if err := s.RemoveAddress(1, addr3); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err)
+ // Make sure addr3's endpoint does not get removed from the NIC by
+ // incrementing its reference count with a route.
+ r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ if err := s.RemoveAddress(nicID, addr3); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err)
}
- // Should not receive a packet destined to the solicited
- // node address of addr2/addr3 yet as both of them got
- // removed.
+ // Should not receive a packet destined to the solicited node address of
+ // addr2/addr3 yet as both of them got removed, even though a route using
+ // addr3 exists.
test.rxf(t, s, e, addr1, snmc, 3)
})
}
@@ -268,3 +291,975 @@ func TestAddIpv6Address(t *testing.T) {
})
}
}
+
+func TestReceiveIPv6ExtHdrs(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ extHdr func(nextHdr uint8) ([]byte, uint8)
+ shouldAccept bool
+ }{
+ {
+ name: "None",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop with unknown option skippable action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop with unknown option discard action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option skippable action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "destination with unknown option discard action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action unless multicast dest",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "routing - atomic fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ nextHdr, 0, 0, 0, 1, 2, 3, 4,
+ }, routingExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "atomic fragment - routing",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Fragment extension header.
+ routingExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Routing extension header.
+ nextHdr, 0, 1, 0, 2, 3, 4, 5,
+ }, fragmentExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hop by hop (with skippable unknown) - routing",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ nextHdr, 0, 1, 0, 2, 3, 4, 5,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "routing - hop by hop (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Hop By Hop extension header with skippable unknown option.
+ nextHdr, 0, 62, 4, 1, 2, 3, 4,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with skippable unknown option.
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with discard action for unknown option.
+ routingExtHdrID, 0, 65, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with skippable unknown option.
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with discard action for unknown
+ // option.
+ nextHdr, 0, 65, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8}
+ udpLength := header.UDPMinimumSize + len(udpPayload)
+ extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber))
+ extHdrLen := len(extHdrBytes)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength)
+
+ // Serialize UDP message.
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), udpPayload)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum = header.Checksum(udpPayload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ // Copy extension header bytes between the UDP message and the IPv6
+ // fixed header.
+ copy(hdr.Prepend(extHdrLen), extHdrBytes)
+
+ // Serialize IPv6 fixed header.
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: ipv6NextHdr,
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ })
+
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stats := s.Stats().UDP.PacketsReceived
+
+ if !test.shouldAccept {
+ if got := stats.Value(); got != 0 {
+ t.Errorf("got UDP Rx Packets = %d, want = 0", got)
+ }
+
+ return
+ }
+
+ // Expect a UDP packet.
+ if got := stats.Value(); got != 1 {
+ t.Errorf("got UDP Rx Packets = %d, want = 1", got)
+ }
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("Read(nil): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have any more UDP packets.
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
+
+// fragmentData holds the IPv6 payload for a fragmented IPv6 packet.
+type fragmentData struct {
+ nextHdr uint8
+ data buffer.VectorisedView
+}
+
+func TestReceiveIPv6Fragments(t *testing.T) {
+ const nicID = 1
+ const udpPayload1Length = 256
+ const udpPayload2Length = 128
+ const fragmentExtHdrLen = 8
+ // Note, not all routing extension headers will be 8 bytes but this test
+ // uses 8 byte routing extension headers for most sub tests.
+ const routingExtHdrLen = 8
+
+ udpGen := func(payload []byte, multiplier uint8) buffer.View {
+ payloadLen := len(payload)
+ for i := 0; i < payloadLen; i++ {
+ payload[i] = uint8(i) * multiplier
+ }
+
+ udpLength := header.UDPMinimumSize + payloadLen
+
+ hdr := buffer.NewPrependable(udpLength)
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), payload)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum = header.Checksum(payload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ return hdr.View()
+ }
+
+ var udpPayload1Buf [udpPayload1Length]byte
+ udpPayload1 := udpPayload1Buf[:]
+ ipv6Payload1 := udpGen(udpPayload1, 1)
+
+ var udpPayload2Buf [udpPayload2Length]byte
+ udpPayload2 := udpPayload2Buf[:]
+ ipv6Payload2 := udpGen(udpPayload2, 2)
+
+ tests := []struct {
+ name string
+ expectedPayload []byte
+ fragments []fragmentData
+ expectedPayloads [][]byte
+ }{
+ {
+ name: "No fragmentation",
+ fragments: []fragmentData{
+ {
+ nextHdr: uint8(header.UDPProtocolNumber),
+ data: ipv6Payload1.ToVectorisedView(),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Atomic fragment",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with different IDs",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with per-fragment routing header with zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with per-fragment routing header with non-zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1},
+ },
+ {
+ name: "Two fragments with routing header with non-zero segments left",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with zero segments left across fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is fragmentExtHdrLen+8 because the
+ // first 8 bytes of the 16 byte routing extension header is in
+ // this fragment.
+ fragmentExtHdrLen+8,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header (part 1)
+ //
+ // Segments left = 0.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}),
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is
+ // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // the 16 byte routing extension header is in this fagment.
+ fragmentExtHdrLen+8+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 1, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+
+ // Routing extension header (part 2)
+ buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with non-zero segments left across fragments",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is fragmentExtHdrLen+8 because the
+ // first 8 bytes of the 16 byte routing extension header is in
+ // this fragment.
+ fragmentExtHdrLen+8,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header (part 1)
+ //
+ // Segments left = 1.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}),
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is
+ // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // the 16 byte routing extension header is in this fagment.
+ fragmentExtHdrLen+8+len(ipv6Payload1),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 1, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+
+ // Routing extension header (part 2)
+ buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+
+ ipv6Payload1,
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal"
+ // fragmented traffic.
+ {
+ name: "Two fragments with atomic",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ // This fragment has the same ID as the other fragments but is an atomic
+ // fragment. It should not interfere with the other fragments.
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload2),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}),
+
+ ipv6Payload2,
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload2, udpPayload1},
+ },
+ {
+ name: "Two interleaved fragmented packets",
+ fragments: []fragmentData{
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1[:64],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}),
+
+ ipv6Payload2[:32],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1[64:],
+ },
+ ),
+ },
+ {
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload2)-32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 4, More = false, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}),
+
+ ipv6Payload2[32:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1, udpPayload2},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+
+ // Serialize IPv6 fixed header.
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(f.data.Size()),
+ NextHeader: f.nextHdr,
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ })
+
+ vv := hdr.View().ToVectorisedView()
+ vv.Append(f.data)
+
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ Data: vv,
+ })
+ }
+
+ if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
+ t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
+ }
+
+ for i, p := range test.expectedPayloads {
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ }
+ if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" {
+ t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
+ }
+ }
+
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index c9395de52..12b70f7e9 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -135,7 +136,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -173,6 +174,257 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
}
}
+func TestNeighorSolicitationResponse(t *testing.T) {
+ const nicID = 1
+ nicAddr := lladdr0
+ remoteAddr := lladdr1
+ nicAddrSNMC := header.SolicitedNodeAddr(nicAddr)
+ nicLinkAddr := linkAddr0
+ remoteLinkAddr0 := linkAddr1
+ remoteLinkAddr1 := linkAddr2
+
+ tests := []struct {
+ name string
+ nsOpts header.NDPOptionsSerializer
+ nsSrcLinkAddr tcpip.LinkAddress
+ nsSrc tcpip.Address
+ nsDst tcpip.Address
+ nsInvalid bool
+ naDstLinkAddr tcpip.LinkAddress
+ naSolicited bool
+ naSrc tcpip.Address
+ naDst tcpip.Address
+ }{
+ {
+ name: "Unspecified source to multicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: false,
+ naSrc: nicAddr,
+ naDst: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "Unspecified source with source ll option to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+ {
+ name: "Unspecified source to unicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: false,
+ naSrc: nicAddr,
+ naDst: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "Unspecified source with source ll option to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddr,
+ nsInvalid: true,
+ },
+
+ {
+ name: "Specified source with 1 source ll to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll different from route to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr1,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source to multicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+ {
+ name: "Specified source with 2 source ll to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+
+ {
+ name: "Specified source to unicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll different from route to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr1,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 2 source ll to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(1, 1280, nicLinkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(nicAddr)
+ opts := ns.Options()
+ opts.Serialize(test.nsOpts)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ if test.nsInvalid {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+
+ if p, got := e.Read(); got {
+ t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
+ }
+
+ // If we expected the NS to be invalid, we have nothing else to check.
+ return
+ }
+
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ p, got := e.Read()
+ if !got {
+ t.Fatal("expected an NDP NA response")
+ }
+
+ if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ }
+
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(test.naSrc),
+ checker.DstAddr(test.naDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNA(
+ checker.NDPNASolicitedFlag(test.naSolicited),
+ checker.NDPNATargetAddress(nicAddr),
+ checker.NDPNAOptions([]header.NDPOption{
+ header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
+ }),
+ ))
+ })
+ }
+}
+
// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a
// valid NDP NA message with the Target Link Layer Address option results in a
// new entry in the link address cache for the target of the message.
@@ -197,6 +449,13 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
name: "Invalid Length",
optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
},
+ {
+ name: "Multiple",
+ optsBuf: []byte{
+ 2, 1, 2, 3, 4, 5, 6, 7,
+ 2, 1, 2, 3, 4, 5, 6, 8,
+ },
+ },
}
for _, test := range tests {
@@ -238,7 +497,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -276,9 +535,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
}
-// TestHopLimitValidation is a test that makes sure that NDP packets are only
-// received if their IP header's hop limit is set to 255.
-func TestHopLimitValidation(t *testing.T) {
+func TestNDPValidation(t *testing.T) {
setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
t.Helper()
@@ -294,17 +551,24 @@ func TestHopLimitValidation(t *testing.T) {
return s, ep, r
}
- handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) {
+ handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ nextHdr := uint8(header.ICMPv6ProtocolNumber)
+ if atomicFragment {
+ bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength)
+ bytes[0] = nextHdr
+ nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ }
+
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ NextHeader: nextHdr,
HopLimit: hopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(r, tcpip.PacketBuffer{
+ ep.HandlePacket(r, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
}
@@ -364,61 +628,93 @@ func TestHopLimitValidation(t *testing.T) {
},
}
+ subTests := []struct {
+ name string
+ atomicFragment bool
+ hopLimit uint8
+ code uint8
+ valid bool
+ }{
+ {
+ name: "Valid",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: true,
+ },
+ {
+ name: "Fragmented",
+ atomicFragment: true,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid hop limit",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit - 1,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid ICMPv6 code",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 1,
+ valid: false,
+ },
+ }
+
for _, typ := range types {
t.Run(typ.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- // Should not have received any ICMPv6 packets with
- // type = typ.typ.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Receive the NDP packet with an invalid hop limit
- // value.
- handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r)
-
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
-
- // Rx count of NDP packet of type typ.typ should not
- // have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Receive the NDP packet with a valid hop limit value.
- handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r)
-
- // Rx count of NDP packet of type typ.typ should have
- // increased.
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
-
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
+ for _, test := range subTests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
+ pkt := header.ICMPv6(hdr.Prepend(typ.size))
+ pkt.SetType(typ.typ)
+ pkt.SetCode(test.code)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
+
+ // Rx count of the NDP message should initially be 0.
+ if got := typStat.Value(); got != 0 {
+ t.Errorf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r)
+
+ // Rx count of the NDP packet should have increased.
+ if got := typStat.Value(); got != 1 {
+ t.Errorf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ want := uint64(0)
+ if !test.valid {
+ // Invalid count should have increased.
+ want = 1
+ }
+ if got := invalid.Value(); got != want {
+ t.Errorf("got invalid = %d, want = %d", got, want)
+ }
+ })
}
})
}
@@ -588,25 +884,22 @@ func TestRouterAdvertValidation(t *testing.T) {
t.Fatalf("got rxRA = %d, want = 0", got)
}
- e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
+ }
+
if test.expectedSuccess {
if got := invalid.Value(); got != 0 {
t.Fatalf("got invalid = %d, want = 0", got)
}
- if got := rxRA.Value(); got != 1 {
- t.Fatalf("got rxRA = %d, want = 1", got)
- }
-
} else {
if got := invalid.Value(); got != 1 {
t.Fatalf("got invalid = %d, want = 1", got)
}
- if got := rxRA.Value(); got != 0 {
- t.Fatalf("got rxRA = %d, want = 0", got)
- }
}
})
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 705cf01ee..5e963a4af 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -15,14 +15,34 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "packet_buffer_list",
+ out = "packet_buffer_list.go",
+ package = "stack",
+ prefix = "PacketBuffer",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*PacketBuffer",
+ "Linker": "*PacketBuffer",
+ },
+)
+
go_library(
name = "stack",
srcs = [
+ "dhcpv6configurationfromndpra_string.go",
+ "forwarder.go",
"icmp_rate_limit.go",
+ "iptables.go",
+ "iptables_targets.go",
+ "iptables_types.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
"ndp.go",
"nic.go",
+ "packet_buffer.go",
+ "packet_buffer_list.go",
+ "rand.go",
"registration.go",
"route.go",
"stack.go",
@@ -32,6 +52,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/ilist",
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/sync",
@@ -39,7 +60,6 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/waiter",
@@ -63,7 +83,6 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
@@ -79,6 +98,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
+ "forwarder_test.go",
"linkaddrcache_test.go",
"nic_test.go",
],
diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go
new file mode 100644
index 000000000..8b4213eec
--- /dev/null
+++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type=DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[DHCPv6NoConfiguration-0]
+ _ = x[DHCPv6ManagedAddress-1]
+ _ = x[DHCPv6OtherConfigurations-2]
+}
+
+const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations"
+
+var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66}
+
+func (i DHCPv6ConfigurationFromNDPRA) String() string {
+ if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) {
+ return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go
new file mode 100644
index 000000000..6b64cd37f
--- /dev/null
+++ b/pkg/tcpip/stack/forwarder.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // maxPendingResolutions is the maximum number of pending link-address
+ // resolutions.
+ maxPendingResolutions = 64
+ maxPendingPacketsPerResolution = 256
+)
+
+type pendingPacket struct {
+ nic *NIC
+ route *Route
+ proto tcpip.NetworkProtocolNumber
+ pkt PacketBuffer
+}
+
+type forwardQueue struct {
+ sync.Mutex
+
+ // The packets to send once the resolver completes.
+ packets map[<-chan struct{}][]*pendingPacket
+
+ // FIFO of channels used to cancel the oldest goroutine waiting for
+ // link-address resolution.
+ cancelChans []chan struct{}
+}
+
+func newForwardQueue() *forwardQueue {
+ return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
+}
+
+func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+ shouldWait := false
+
+ f.Lock()
+ packets, ok := f.packets[ch]
+ if !ok {
+ shouldWait = true
+ }
+ for len(packets) == maxPendingPacketsPerResolution {
+ p := packets[0]
+ packets = packets[1:]
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Release()
+ }
+ if l := len(packets); l >= maxPendingPacketsPerResolution {
+ panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
+ }
+ f.packets[ch] = append(packets, &pendingPacket{
+ nic: n,
+ route: r,
+ proto: protocol,
+ pkt: pkt,
+ })
+ f.Unlock()
+
+ if !shouldWait {
+ return
+ }
+
+ // Wait for the link-address resolution to complete.
+ // Start a goroutine with a forwarding-cancel channel so that we can
+ // limit the maximum number of goroutines running concurrently.
+ cancel := f.newCancelChannel()
+ go func() {
+ cancelled := false
+ select {
+ case <-ch:
+ case <-cancel:
+ cancelled = true
+ }
+
+ f.Lock()
+ packets := f.packets[ch]
+ delete(f.packets, ch)
+ f.Unlock()
+
+ for _, p := range packets {
+ if cancelled {
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ } else if _, err := p.route.Resolve(nil); err != nil {
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ } else {
+ p.nic.forwardPacket(p.route, p.proto, p.pkt)
+ }
+ p.route.Release()
+ }
+ }()
+}
+
+// newCancelChannel creates a channel that can cancel a pending forwarding
+// activity. The oldest channel is closed if the number of open channels would
+// exceed maxPendingResolutions.
+func (f *forwardQueue) newCancelChannel() chan struct{} {
+ f.Lock()
+ defer f.Unlock()
+
+ if len(f.cancelChans) == maxPendingResolutions {
+ ch := f.cancelChans[0]
+ f.cancelChans = f.cancelChans[1:]
+ close(ch)
+ }
+ if l := len(f.cancelChans); l >= maxPendingResolutions {
+ panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
+ }
+
+ ch := make(chan struct{})
+ f.cancelChans = append(f.cancelChans, ch)
+ return ch
+}
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
new file mode 100644
index 000000000..e9c652042
--- /dev/null
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -0,0 +1,635 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "encoding/binary"
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+const (
+ fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+ fwdTestNetHeaderLen = 12
+ fwdTestNetDefaultPrefixLen = 8
+
+ // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
+ // except where another value is explicitly used. It is chosen to match
+ // the MTU of loopback interfaces on linux systems.
+ fwdTestNetDefaultMTU = 65536
+)
+
+// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
+// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only
+// use the first three: destination address, source address, and transport
+// protocol. They're all one byte fields to simplify parsing.
+type fwdTestNetworkEndpoint struct {
+ nicID tcpip.NICID
+ id NetworkEndpointID
+ prefixLen int
+ proto *fwdTestNetworkProtocol
+ dispatcher TransportDispatcher
+ ep LinkEndpoint
+}
+
+func (f *fwdTestNetworkEndpoint) MTU() uint32 {
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
+}
+
+func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID {
+ return f.nicID
+}
+
+func (f *fwdTestNetworkEndpoint) PrefixLen() int {
+ return f.prefixLen
+}
+
+func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
+ return 123
+}
+
+func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
+ return &f.id
+}
+
+func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) {
+ // Consume the network header.
+ b := pkt.Data.First()
+ pkt.Data.TrimFront(fwdTestNetHeaderLen)
+
+ // Dispatch the packet to the transport protocol.
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
+}
+
+func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
+ return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen
+}
+
+func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
+ return 0
+}
+
+func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities {
+ return f.ep.Capabilities()
+}
+
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
+ // Add the protocol's header to the packet and send it to the link
+ // endpoint.
+ b := pkt.Header.Prepend(fwdTestNetHeaderLen)
+ b[0] = r.RemoteAddress[0]
+ b[1] = f.id.LocalAddress[0]
+ b[2] = byte(params.Protocol)
+
+ return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (*fwdTestNetworkEndpoint) Close() {}
+
+// fwdTestNetworkProtocol is a network-layer protocol that implements Address
+// resolution.
+type fwdTestNetworkProtocol struct {
+ addrCache *linkAddrCache
+ addrResolveDelay time.Duration
+ onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address)
+ onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
+}
+
+func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
+ return fwdTestNetNumber
+}
+
+func (f *fwdTestNetworkProtocol) MinimumPacketSize() int {
+ return fwdTestNetHeaderLen
+}
+
+func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
+ return fwdTestNetDefaultPrefixLen
+}
+
+func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+}
+
+func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
+ return &fwdTestNetworkEndpoint{
+ nicID: nicID,
+ id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ proto: f,
+ dispatcher: dispatcher,
+ ep: ep,
+ }, nil
+}
+
+func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func (f *fwdTestNetworkProtocol) Close() {}
+
+func (f *fwdTestNetworkProtocol) Wait() {}
+
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error {
+ if f.addrCache != nil && f.onLinkAddressResolved != nil {
+ time.AfterFunc(f.addrResolveDelay, func() {
+ f.onLinkAddressResolved(f.addrCache, addr)
+ })
+ }
+ return nil
+}
+
+func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if f.onResolveStaticAddress != nil {
+ return f.onResolveStaticAddress(addr)
+ }
+ return "", false
+}
+
+func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return fwdTestNetNumber
+}
+
+// fwdTestPacketInfo holds all the information about an outbound packet.
+type fwdTestPacketInfo struct {
+ RemoteLinkAddress tcpip.LinkAddress
+ LocalLinkAddress tcpip.LinkAddress
+ Pkt PacketBuffer
+}
+
+type fwdTestLinkEndpoint struct {
+ dispatcher NetworkDispatcher
+ mtu uint32
+ linkAddr tcpip.LinkAddress
+
+ // C is where outbound packets are queued.
+ C chan fwdTestPacketInfo
+}
+
+// InjectInbound injects an inbound packet.
+func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+ e.InjectLinkAddr(protocol, "", pkt)
+}
+
+// InjectLinkAddr injects an inbound packet with a remote link address.
+func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt)
+}
+
+// Attach saves the stack network-layer dispatcher for use later when packets
+// are injected.
+func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *fwdTestLinkEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *fwdTestLinkEndpoint) MTU() uint32 {
+ return e.mtu
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
+ caps := LinkEndpointCapabilities(0)
+ return caps | CapabilityResolutionRequired
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 {
+ return 1 << 15
+}
+
+// MaxHeaderLength returns the maximum size of the link layer header. Given it
+// doesn't have a header, it just returns 0.
+func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error {
+ p := fwdTestPacketInfo{
+ RemoteLinkAddress: r.RemoteLinkAddress,
+ LocalLinkAddress: r.LocalLinkAddress,
+ Pkt: pkt,
+ }
+
+ select {
+ case e.C <- p:
+ default:
+ }
+
+ return nil
+}
+
+// WritePackets stores outbound packets into the channel.
+func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.WritePacket(r, gso, protocol, *pkt)
+ n++
+ }
+
+ return n, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ p := fwdTestPacketInfo{
+ Pkt: PacketBuffer{Data: vv},
+ }
+
+ select {
+ case e.C <- p:
+ default:
+ }
+
+ return nil
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*fwdTestLinkEndpoint) Wait() {}
+
+func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
+ // Create a stack with the network protocol and two NICs.
+ s := New(Options{
+ NetworkProtocols: []NetworkProtocol{proto},
+ })
+
+ proto.addrCache = s.linkAddrCache
+
+ // Enable forwarding.
+ s.SetForwarding(true)
+
+ // NIC 1 has the link address "a", and added the network address 1.
+ ep1 = &fwdTestLinkEndpoint{
+ C: make(chan fwdTestPacketInfo, 300),
+ mtu: fwdTestNetDefaultMTU,
+ linkAddr: "a",
+ }
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC #1 failed:", err)
+ }
+ if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
+ t.Fatal("AddAddress #1 failed:", err)
+ }
+
+ // NIC 2 has the link address "b", and added the network address 2.
+ ep2 = &fwdTestLinkEndpoint{
+ C: make(chan fwdTestPacketInfo, 300),
+ mtu: fwdTestNetDefaultMTU,
+ linkAddr: "b",
+ }
+ if err := s.CreateNIC(2, ep2); err != nil {
+ t.Fatal("CreateNIC #2 failed:", err)
+ }
+ if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
+ t.Fatal("AddAddress #2 failed:", err)
+ }
+
+ // Route all packets to NIC 2.
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}})
+ }
+
+ return ep1, ep2
+}
+
+func TestForwardingWithStaticResolver(t *testing.T) {
+ // Create a network protocol with a static resolver.
+ proto := &fwdTestNetworkProtocol{
+ onResolveStaticAddress:
+ // The network address 3 is resolved to the link address "c".
+ func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "\x03" {
+ return "c", true
+ }
+ return "", false
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ default:
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the static address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithFakeResolver(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any address will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithNoResolver(t *testing.T) {
+ // Create a network protocol without a resolver.
+ proto := &fwdTestNetworkProtocol{}
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ select {
+ case <-ep2.C:
+ t.Fatal("Packet should not be forwarded")
+ case <-time.After(time.Second):
+ }
+}
+
+func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Only packets to address 3 will be resolved to the
+ // link address "c".
+ if addr == "\x03" {
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ }
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 4 on NIC 1. This packet should
+ // not be forwarded.
+ buf := buffer.NewView(30)
+ buf[0] = 4
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf = buffer.NewView(30)
+ buf[0] = 3
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ b := p.Pkt.Header.View()
+ if b[0] != 3 {
+ t.Fatalf("got b[0] = %d, want = 3", b[0])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject two inbound packets to address 3 on NIC 1.
+ for i := 0; i < 2; i++ {
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ }
+
+ for i := 0; i < 2; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ b := p.Pkt.Header.View()
+ if b[0] != 3 {
+ t.Fatalf("got b[0] = %d, want = 3", b[0])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
+
+func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
+ // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ // Set the packet sequence number.
+ binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ }
+
+ for i := 0; i < maxPendingPacketsPerResolution; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ b := p.Pkt.Header.View()
+ if b[0] != 3 {
+ t.Fatalf("got b[0] = %d, want = 3", b[0])
+ }
+ // The first 5 packets should not be forwarded so the the
+ // sequemnce number should start with 5.
+ want := uint16(i + 5)
+ if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want {
+ t.Fatalf("got the packet #%d, want = #%d", n, want)
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
+
+func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ for i := 0; i < maxPendingResolutions+5; i++ {
+ // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
+ // Each packet has a different destination address (3 to
+ // maxPendingResolutions + 7).
+ buf := buffer.NewView(30)
+ buf[0] = byte(3 + i)
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+ }
+
+ for i := 0; i < maxPendingResolutions; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // The first 5 packets (address 3 to 7) should not be forwarded
+ // because their address resolutions are interrupted.
+ b := p.Pkt.Header.View()
+ if b[0] < 8 {
+ t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/stack/iptables.go
index dbaccbb36..6c0a4b24d 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -12,14 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package iptables supports packet filtering and manipulation via the iptables
-// tool.
-package iptables
+package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -135,6 +132,27 @@ func EmptyFilterTable() Table {
}
}
+// EmptyNatTable returns a Table with no rules and the filter table chains
+// mapped to HookUnset.
+func EmptyNatTable() Table {
+ return Table{
+ Rules: []Rule{},
+ BuiltinChains: map[Hook]int{
+ Prerouting: HookUnset,
+ Input: HookUnset,
+ Output: HookUnset,
+ Postrouting: HookUnset,
+ },
+ Underflows: map[Hook]int{
+ Prerouting: HookUnset,
+ Input: HookUnset,
+ Output: HookUnset,
+ Postrouting: HookUnset,
+ },
+ UserChains: map[string]int{},
+ }
+}
+
// A chainVerdict is what a table decides should be done with a packet.
type chainVerdict int
@@ -155,7 +173,7 @@ const (
// dropped.
//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
+func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool {
// Go through each table containing the hook.
for _, tablename := range it.Priorities[hook] {
table := it.Tables[tablename]
@@ -191,8 +209,25 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
return true
}
+// CheckPackets runs pkts through the rules for hook and returns a map of packets that
+// should not go forward.
+//
+// NOTE: unlike the Check API the returned map contains packets that should be
+// dropped.
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if ok := it.Check(hook, *pkt); !ok {
+ if drop == nil {
+ drop = make(map[*PacketBuffer]struct{})
+ }
+ drop[pkt] = struct{}{}
+ }
+ }
+ return drop
+}
+
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
@@ -237,12 +272,17 @@ func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, r
}
// Precondition: pk.NetworkHeader is set.
-func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
- // First check whether the packet matches the IP header filter.
- // TODO(gvisor.dev/issue/170): Support other fields of the filter.
- if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() {
+ // If pkt.NetworkHeader hasn't been set yet, it will be contained in
+ // pkt.Data.First().
+ if pkt.NetworkHeader == nil {
+ pkt.NetworkHeader = pkt.Data.First()
+ }
+
+ // Check whether the packet matches the IP header filter.
+ if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -263,3 +303,26 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru
// All the matchers matched, so run the target.
return rule.Target.Action(pkt)
}
+
+func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool {
+ // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+ // Check the transport protocol.
+ if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() {
+ return false
+ }
+
+ // Check the destination IP.
+ dest := hdr.DestinationAddress()
+ matches := true
+ for i := range filter.Dst {
+ if dest[i]&filter.DstMask[i] != filter.Dst[i] {
+ matches = false
+ break
+ }
+ }
+ if matches == filter.DstInvert {
+ return false
+ }
+
+ return true
+}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
new file mode 100644
index 000000000..7b4543caf
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -0,0 +1,144 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// AcceptTarget accepts packets.
+type AcceptTarget struct{}
+
+// Action implements Target.Action.
+func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+ return RuleAccept, 0
+}
+
+// DropTarget drops packets.
+type DropTarget struct{}
+
+// Action implements Target.Action.
+func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+ return RuleDrop, 0
+}
+
+// ErrorTarget logs an error and drops the packet. It represents a target that
+// should be unreachable.
+type ErrorTarget struct{}
+
+// Action implements Target.Action.
+func (ErrorTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+ log.Debugf("ErrorTarget triggered.")
+ return RuleDrop, 0
+}
+
+// UserChainTarget marks a rule as the beginning of a user chain.
+type UserChainTarget struct {
+ Name string
+}
+
+// Action implements Target.Action.
+func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) {
+ panic("UserChainTarget should never be called.")
+}
+
+// ReturnTarget returns from the current chain. If the chain is a built-in, the
+// hook's underflow should be called.
+type ReturnTarget struct{}
+
+// Action implements Target.Action.
+func (ReturnTarget) Action(PacketBuffer) (RuleVerdict, int) {
+ return RuleReturn, 0
+}
+
+// RedirectTarget redirects the packet by modifying the destination port/IP.
+// Min and Max values for IP and Ports in the struct indicate the range of
+// values which can be used to redirect.
+type RedirectTarget struct {
+ // TODO(gvisor.dev/issue/170): Other flags need to be added after
+ // we support them.
+ // RangeProtoSpecified flag indicates single port is specified to
+ // redirect.
+ RangeProtoSpecified bool
+
+ // Min address used to redirect.
+ MinIP tcpip.Address
+
+ // Max address used to redirect.
+ MaxIP tcpip.Address
+
+ // Min port used to redirect.
+ MinPort uint16
+
+ // Max port used to redirect.
+ MaxPort uint16
+}
+
+// Action implements Target.Action.
+// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
+// implementation only works for PREROUTING and calls pkt.Clone(), neither
+// of which should be the case.
+func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
+ newPkt := pkt.Clone()
+
+ // Set network header.
+ headerView := newPkt.Data.First()
+ netHeader := header.IPv4(headerView)
+ newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize]
+
+ hlen := int(netHeader.HeaderLength())
+ tlen := int(netHeader.TotalLength())
+ newPkt.Data.TrimFront(hlen)
+ newPkt.Data.CapLength(tlen - hlen)
+
+ // TODO(gvisor.dev/issue/170): Change destination address to
+ // loopback or interface address on which the packet was
+ // received.
+
+ // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
+ // we need to change dest address (for OUTPUT chain) or ports.
+ switch protocol := netHeader.TransportProtocol(); protocol {
+ case header.UDPProtocolNumber:
+ var udpHeader header.UDP
+ if newPkt.TransportHeader != nil {
+ udpHeader = header.UDP(newPkt.TransportHeader)
+ } else {
+ if len(pkt.Data.First()) < header.UDPMinimumSize {
+ return RuleDrop, 0
+ }
+ udpHeader = header.UDP(newPkt.Data.First())
+ }
+ udpHeader.SetDestinationPort(rt.MinPort)
+ case header.TCPProtocolNumber:
+ var tcpHeader header.TCP
+ if newPkt.TransportHeader != nil {
+ tcpHeader = header.TCP(newPkt.TransportHeader)
+ } else {
+ if len(pkt.Data.First()) < header.TCPMinimumSize {
+ return RuleDrop, 0
+ }
+ tcpHeader = header.TCP(newPkt.TransportHeader)
+ }
+ // TODO(gvisor.dev/issue/170): Need to recompute checksum
+ // and implement nat connection tracking to support TCP.
+ tcpHeader.SetDestinationPort(rt.MinPort)
+ default:
+ return RuleDrop, 0
+ }
+
+ return RuleAccept, 0
+}
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/stack/iptables_types.go
index 7d032fd23..2ffb55f2a 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package iptables
+package stack
import (
"gvisor.dev/gvisor/pkg/tcpip"
@@ -144,6 +144,18 @@ type Rule struct {
type IPHeaderFilter struct {
// Protocol matches the transport protocol.
Protocol tcpip.TransportProtocolNumber
+
+ // Dst matches the destination IP address.
+ Dst tcpip.Address
+
+ // DstMask masks bits of the destination IP address when comparing with
+ // Dst.
+ DstMask tcpip.Address
+
+ // DstInvert inverts the meaning of the destination IP check, i.e. when
+ // true the filter will match packets that fail the destination
+ // comparison.
+ DstInvert bool
}
// A Matcher is the interface for matching packets.
@@ -156,7 +168,7 @@ type Matcher interface {
// used for suspicious packets.
//
// Precondition: packet.NetworkHeader is set.
- Match(hook Hook, packet tcpip.PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+ Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
// A Target is the interface for taking an action for a packet.
@@ -164,5 +176,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(packet tcpip.PacketBuffer) (RuleVerdict, int)
+ Action(packet PacketBuffer) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index f651871ce..8140c6dd4 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -15,6 +15,7 @@
package stack
import (
+ "fmt"
"log"
"math/rand"
"time"
@@ -304,6 +305,15 @@ type NDPConfigurations struct {
// lifetime(s) of the generated address changes; this option only
// affects the generation of new addresses as part of SLAAC.
AutoGenGlobalAddresses bool
+
+ // AutoGenAddressConflictRetries determines how many times to attempt to retry
+ // generation of a permanent auto-generated address in response to DAD
+ // conflicts.
+ //
+ // If the method used to generate the address does not support creating
+ // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's
+ // MAC address), then no attempt will be made to resolve the conflict.
+ AutoGenAddressConflictRetries uint8
}
// DefaultNDPConfigurations returns an NDPConfigurations populated with
@@ -361,16 +371,16 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
+ // The timer used to send the next router solicitation message.
+ rtrSolicitTimer *time.Timer
+
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
- // The timer used to send the next router solicitation message.
- // If routers are being solicited, rtrSolicitTimer MUST NOT be nil.
- rtrSolicitTimer *time.Timer
-
- // The addresses generated by SLAAC.
- autoGenAddresses map[tcpip.Address]autoGenAddressState
+ // The SLAAC prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ slaacPrefixes map[tcpip.Subnet]slaacPrefixState
// The last learned DHCPv6 configuration from an NDP RA.
dhcpv6Configuration DHCPv6ConfigurationFromNDPRA
@@ -402,18 +412,31 @@ type onLinkPrefixState struct {
invalidationTimer tcpip.CancellableTimer
}
-// autoGenAddressState holds data associated with an address generated via
-// SLAAC.
-type autoGenAddressState struct {
- // A reference to the referencedNetworkEndpoint that this autoGenAddressState
- // is holding state for.
- ref *referencedNetworkEndpoint
-
+// slaacPrefixState holds state associated with a SLAAC prefix.
+type slaacPrefixState struct {
deprecationTimer tcpip.CancellableTimer
invalidationTimer tcpip.CancellableTimer
// Nonzero only when the address is not valid forever.
validUntil time.Time
+
+ // Nonzero only when the address is not preferred forever.
+ preferredUntil time.Time
+
+ // The prefix's permanent address endpoint.
+ //
+ // May only be nil when a SLAAC address is being (re-)generated. Otherwise,
+ // must not be nil as all SLAAC prefixes must have a SLAAC address.
+ ref *referencedNetworkEndpoint
+
+ // The number of times a permanent address has been generated for the prefix.
+ //
+ // Addresses may be regenerated in reseponse to a DAD conflicts.
+ generationAttempts uint8
+
+ // The maximum number of times to attempt regeneration of a permanent SLAAC
+ // address in response to DAD conflicts.
+ maxGenerationAttempts uint8
}
// startDuplicateAddressDetection performs Duplicate Address Detection.
@@ -430,7 +453,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
if ref.getKind() != permanentTentative {
// The endpoint should be marked as tentative since we are starting DAD.
- log.Fatalf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
}
// Should not attempt to perform DAD on an address that is currently in the
@@ -442,7 +465,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// address, or its reference count would have been increased without doing
// the work that would have been done for an address that was brand new.
// See NIC.addAddressLocked.
- log.Fatalf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
}
remaining := ndp.configs.DupAddrDetectTransmits
@@ -478,7 +501,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
if ref.getKind() != permanentTentative {
// The endpoint should still be marked as tentative since we are still
// performing DAD on it.
- log.Fatalf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()))
}
dadDone := remaining == 0
@@ -548,9 +571,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
// Route should resolve immediately since snmc is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
- log.Fatalf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err))
} else if c != nil {
- log.Fatalf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
}
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
@@ -566,7 +589,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, tcpip.PacketBuffer{Header: hdr},
+ }, PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
return err
@@ -688,7 +711,8 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
continue
}
- ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime())
+ addrs, _ := opt.Addresses()
+ ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -899,23 +923,15 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
prefix := pi.Subnet()
- // Check if we already have an auto-generated address for prefix.
- for addr, addrState := range ndp.autoGenAddresses {
- refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()}
- if refAddrWithPrefix.Subnet() != prefix {
- continue
- }
-
- // At this point, we know we are refreshing a SLAAC generated IPv6 address
- // with the prefix prefix. Do the work as outlined by RFC 4862 section
- // 5.5.3.e.
- ndp.refreshAutoGenAddressLifetimes(addr, pl, vl)
+ // Check if we already maintain SLAAC state for prefix.
+ if _, ok := ndp.slaacPrefixes[prefix]; ok {
+ // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes.
+ ndp.refreshSLAACPrefixLifetimes(prefix, pl, vl)
return
}
- // We do not already have an address with the prefix prefix. Do the
- // work as outlined by RFC 4862 section 5.5.3.d if n is configured
- // to auto-generate global addresses by SLAAC.
+ // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section
+ // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC.
if !ndp.configs.AutoGenGlobalAddresses {
return
}
@@ -927,6 +943,8 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
// for prefix.
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// If we do not already have an address for this prefix and the valid
// lifetime is 0, no need to do anything further, as per RFC 4862
@@ -942,10 +960,83 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
return
}
+ state := slaacPrefixState{
+ deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
+ }
+
+ ndp.deprecateSLAACAddress(state.ref)
+ }),
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
+ }
+
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }),
+ maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1,
+ }
+
+ now := time.Now()
+
+ // The time an address is preferred until is needed to properly generate the
+ // address.
+ if pl < header.NDPInfiniteLifetime {
+ state.preferredUntil = now.Add(pl)
+ }
+
+ if !ndp.generateSLAACAddr(prefix, &state) {
+ // We were unable to generate an address for the prefix, we do not nothing
+ // further as there is no reason to maintain state or timers for a prefix we
+ // do not have an address for.
+ return
+ }
+
+ // Setup the initial timers to deprecate and invalidate prefix.
+
+ if pl < header.NDPInfiniteLifetime && pl != 0 {
+ state.deprecationTimer.Reset(pl)
+ }
+
+ if vl < header.NDPInfiniteLifetime {
+ state.invalidationTimer.Reset(vl)
+ state.validUntil = now.Add(vl)
+ }
+
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// generateSLAACAddr generates a SLAAC address for prefix.
+//
+// Returns true if an address was successfully generated.
+//
+// Panics if the prefix is not a SLAAC prefix or it already has an address.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool {
+ if r := state.ref; r != nil {
+ panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix()))
+ }
+
+ // If we have already reached the maximum address generation attempts for the
+ // prefix, do not generate another address.
+ if state.generationAttempts == state.maxGenerationAttempts {
+ return false
+ }
+
addrBytes := []byte(prefix.ID())
if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
- addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey)
- } else {
+ addrBytes = header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name),
+ state.generationAttempts,
+ oIID.SecretKey,
+ )
+ } else if state.generationAttempts == 0 {
// Only attempt to generate an interface-specific IID if we have a valid
// link address.
//
@@ -953,137 +1044,140 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// LinkEndpoint.LinkAddress) before reaching this point.
linkAddr := ndp.nic.linkEP.LinkAddress()
if !header.IsValidUnicastEthernetAddress(linkAddr) {
- return
+ return false
}
// Generate an address within prefix from the modified EUI-64 of ndp's NIC's
// Ethernet MAC address.
header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ } else {
+ // We have no way to regenerate an address when addresses are not generated
+ // with opaque IIDs.
+ return false
}
- addr := tcpip.Address(addrBytes)
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: validPrefixLenForAutoGen,
+
+ generatedAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: validPrefixLenForAutoGen,
+ },
}
// If the nic already has this address, do nothing further.
- if ndp.nic.hasPermanentAddrLocked(addr) {
- return
+ if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) {
+ return false
}
// Inform the integrator that we have a new SLAAC address.
ndpDisp := ndp.nic.stack.ndpDisp
if ndpDisp == nil {
- return
+ return false
}
- if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) {
+
+ if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) {
// Informed by the integrator not to add the address.
- return
+ return false
}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addrWithPrefix,
- }
- // If the preferred lifetime is zero, then the address should be considered
- // deprecated.
- deprecated := pl == 0
- ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
+ deprecated := time.Since(state.preferredUntil) >= 0
+ ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
if err != nil {
- log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err)
- }
-
- state := autoGenAddressState{
- ref: ref,
- deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
- addrState, ok := ndp.autoGenAddresses[addr]
- if !ok {
- log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)
- }
- addrState.ref.deprecated = true
- ndp.notifyAutoGenAddressDeprecated(addr)
- }),
- invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
- ndp.invalidateAutoGenAddress(addr)
- }),
+ panic(fmt.Sprintf("ndp: error when adding address %+v: %s", generatedAddr, err))
}
- // Setup the initial timers to deprecate and invalidate this newly generated
- // address.
+ state.generationAttempts++
+ state.ref = ref
+ return true
+}
- if !deprecated && pl < header.NDPInfiniteLifetime {
- state.deprecationTimer.Reset(pl)
+// regenerateSLAACAddr regenerates an address for a SLAAC prefix.
+//
+// If generating a new address for the prefix fails, the prefix will be
+// invalidated.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix))
}
- if vl < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(vl)
- state.validUntil = time.Now().Add(vl)
+ if ndp.generateSLAACAddr(prefix, &state) {
+ ndp.slaacPrefixes[prefix] = state
+ return
}
- ndp.autoGenAddresses[addr] = state
+ // We were unable to generate a permanent address for the SLAAC prefix so
+ // invalidate the prefix as there is no reason to maintain state for a
+ // SLAAC prefix we do not have an address for.
+ ndp.invalidateSLAACPrefix(prefix, state)
}
-// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated
-// address addr.
+// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix.
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
-func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) {
- addrState, ok := ndp.autoGenAddresses[addr]
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
- log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr)
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix))
}
- defer func() { ndp.autoGenAddresses[addr] = addrState }()
+ defer func() { ndp.slaacPrefixes[prefix] = prefixState }()
- // If the preferred lifetime is zero, then the address should be considered
- // deprecated.
+ // If the preferred lifetime is zero, then the prefix should be deprecated.
deprecated := pl == 0
- wasDeprecated := addrState.ref.deprecated
- addrState.ref.deprecated = deprecated
-
- // Only send the deprecation event if the deprecated status for addr just
- // changed from non-deprecated to deprecated.
- if !wasDeprecated && deprecated {
- ndp.notifyAutoGenAddressDeprecated(addr)
+ if deprecated {
+ ndp.deprecateSLAACAddress(prefixState.ref)
+ } else {
+ prefixState.ref.deprecated = false
}
- // If addr was preferred for some finite lifetime before, stop the deprecation
- // timer so it can be reset.
- addrState.deprecationTimer.StopLocked()
+ // If prefix was preferred for some finite lifetime before, stop the
+ // deprecation timer so it can be reset.
+ prefixState.deprecationTimer.StopLocked()
+
+ now := time.Now()
- // Reset the deprecation timer if addr has a finite preferred lifetime.
- if !deprecated && pl < header.NDPInfiniteLifetime {
- addrState.deprecationTimer.Reset(pl)
+ // Reset the deprecation timer if prefix has a finite preferred lifetime.
+ if pl < header.NDPInfiniteLifetime {
+ if !deprecated {
+ prefixState.deprecationTimer.Reset(pl)
+ }
+ prefixState.preferredUntil = now.Add(pl)
+ } else {
+ prefixState.preferredUntil = time.Time{}
}
- // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address
- //
+ // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
//
// 1) If the received Valid Lifetime is greater than 2 hours or greater than
- // RemainingLifetime, set the valid lifetime of the address to the
+ // RemainingLifetime, set the valid lifetime of the prefix to the
// advertised Valid Lifetime.
//
// 2) If RemainingLifetime is less than or equal to 2 hours, ignore the
// advertised Valid Lifetime.
//
- // 3) Otherwise, reset the valid lifetime of the address to 2 hours.
+ // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
// Handle the infinite valid lifetime separately as we do not keep a timer in
// this case.
if vl >= header.NDPInfiniteLifetime {
- addrState.invalidationTimer.StopLocked()
- addrState.validUntil = time.Time{}
+ prefixState.invalidationTimer.StopLocked()
+ prefixState.validUntil = time.Time{}
return
}
var effectiveVl time.Duration
var rl time.Duration
- // If the address was originally set to be valid forever, assume the remaining
+ // If the prefix was originally set to be valid forever, assume the remaining
// time to be the maximum possible value.
- if addrState.validUntil == (time.Time{}) {
+ if prefixState.validUntil == (time.Time{}) {
rl = header.NDPInfiniteLifetime
} else {
- rl = time.Until(addrState.validUntil)
+ rl = time.Until(prefixState.validUntil)
}
if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
@@ -1094,58 +1188,78 @@ func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl t
effectiveVl = MinPrefixInformationValidLifetimeForUpdate
}
- addrState.invalidationTimer.StopLocked()
- addrState.invalidationTimer.Reset(effectiveVl)
- addrState.validUntil = time.Now().Add(effectiveVl)
+ prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.validUntil = now.Add(effectiveVl)
}
-// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr
-// has been deprecated.
-func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) {
+// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP
+// dispatcher that ref has been deprecated.
+//
+// deprecateSLAACAddress does nothing if ref is already deprecated.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) {
+ if ref.deprecated {
+ return
+ }
+
+ ref.deprecated = true
if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: validPrefixLenForAutoGen,
- })
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix())
}
}
-// invalidateAutoGenAddress invalidates an auto-generated address.
+// invalidateSLAACPrefix invalidates a SLAAC prefix.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) {
- if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) {
- return
+func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) {
+ if r := state.ref; r != nil {
+ // Since we are already invalidating the prefix, do not invalidate the
+ // prefix when removing the address.
+ if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACPrefixInvalidation */); err != nil {
+ panic(fmt.Sprintf("ndp: removePermanentIPv6EndpointLocked(%s, false): %s", r.addrWithPrefix(), err))
+ }
}
- ndp.nic.removePermanentAddressLocked(addr)
+ ndp.cleanupSLAACPrefixResources(prefix, state)
}
-// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated
-// address's resources from ndp. If the stack has an NDP dispatcher, it will
-// be notified that addr has been invalidated.
-//
-// Returns true if ndp had resources for addr to cleanup.
+// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's
+// resources.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool {
- state, ok := ndp.autoGenAddresses[addr]
- if !ok {
- return false
+func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
}
- state.deprecationTimer.StopLocked()
- state.invalidationTimer.StopLocked()
- delete(ndp.autoGenAddresses, addr)
+ prefix := addr.Subnet()
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok || state.ref == nil || addr.Address != state.ref.ep.ID().LocalAddress {
+ return
+ }
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: validPrefixLenForAutoGen,
- })
+ if !invalidatePrefix {
+ // If the prefix is not being invalidated, disassociate the address from the
+ // prefix and do nothing further.
+ state.ref = nil
+ ndp.slaacPrefixes[prefix] = state
+ return
}
- return true
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+}
+
+// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry.
+//
+// Panics if the SLAAC prefix is not known.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) {
+ state.deprecationTimer.StopLocked()
+ state.invalidationTimer.StopLocked()
+ delete(ndp.slaacPrefixes, prefix)
}
// cleanupState cleans up ndp's state.
@@ -1163,21 +1277,21 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupState(hostOnly bool) {
linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
- linkLocalAddrs := 0
- for addr := range ndp.autoGenAddresses {
+ linkLocalPrefixes := 0
+ for prefix, state := range ndp.slaacPrefixes {
// RFC 4862 section 5 states that routers are also expected to generate a
// link-local address so we do not invalidate them if we are cleaning up
// host-only state.
- if hostOnly && linkLocalSubnet.Contains(addr) {
- linkLocalAddrs++
+ if hostOnly && prefix == linkLocalSubnet {
+ linkLocalPrefixes++
continue
}
- ndp.invalidateAutoGenAddress(addr)
+ ndp.invalidateSLAACPrefix(prefix, state)
}
- if got := len(ndp.autoGenAddresses); got != linkLocalAddrs {
- log.Fatalf("ndp: still have non-linklocal auto-generated addresses after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalAddrs)
+ if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
+ panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
}
for prefix := range ndp.onLinkPrefixes {
@@ -1185,7 +1299,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
}
if got := len(ndp.onLinkPrefixes); got != 0 {
- log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got)
+ panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got))
}
for router := range ndp.defaultRouters {
@@ -1193,7 +1307,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
}
if got := len(ndp.defaultRouters); got != 0 {
- log.Fatalf("ndp: still have discovered default routers after cleaning up; found = %d", got)
+ panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got))
}
}
@@ -1220,24 +1334,45 @@ func (ndp *ndpState) startSolicitingRouters() {
}
ndp.rtrSolicitTimer = time.AfterFunc(delay, func() {
- // Send an RS message with the unspecified source address.
- ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
- r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ // As per RFC 4861 section 4.1, the source of the RS is an address assigned
+ // to the sending interface, or the unspecified address if no address is
+ // assigned to the sending interface.
+ ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress)
+ if ref == nil {
+ ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
+ }
+ localAddr := ref.ep.ID().LocalAddress
+ r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
// Route should resolve immediately since
// header.IPv6AllRoutersMulticastAddress is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
- log.Fatalf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err))
} else if c != nil {
- log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()))
}
- payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
+ // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
+ // link-layer address option if the source address of the NDP RS is
+ // specified. This option MUST NOT be included if the source address is
+ // unspecified.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ var optsSerializer header.NDPOptionsSerializer
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+ optsSerializer = header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+ }
+ }
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize)
pkt := header.ICMPv6(hdr.Prepend(payloadSize))
pkt.SetType(header.ICMPv6RouterSolicit)
+ rs := header.NDPRouterSolicit(pkt.NDPPayload())
+ rs.Options().Serialize(optsSerializer)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
sent := r.Stats().ICMP.V6PacketsSent
@@ -1246,7 +1381,7 @@ func (ndp *ndpState) startSolicitingRouters() {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, tcpip.PacketBuffer{Header: hdr},
+ }, PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 6e9306d09..6562a2d22 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -406,8 +406,7 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
}
- // Address should not be considered bound to the NIC yet
- // (DAD ongoing).
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
@@ -416,10 +415,9 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
- // Wait for the remaining time - some delta (500ms), to
- // make sure the address is still not resolved.
- const delta = 500 * time.Millisecond
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
+ // Make sure the address does not resolve before the resolution time has
+ // passed.
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout)
addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
@@ -430,13 +428,7 @@ func TestDADResolve(t *testing.T) {
// Wait for DAD to resolve.
select {
- case <-time.After(2 * delta):
- // We should get a resolution event after 500ms
- // (delta) since we wait for 500ms less than the
- // expected resolution time above to make sure
- // that the address did not yet resolve. Waiting
- // for 1s (2x delta) without a resolution event
- // means something is wrong.
+ case <-time.After(2 * defaultAsyncEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
@@ -476,7 +468,7 @@ func TestDADResolve(t *testing.T) {
// As per RFC 4861 section 4.3, a possible option is the Source Link
// Layer option, but this option MUST NOT be included when the source
// address of the packet is the unspecified address.
- checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(),
+ checker.IPv6(t, p.Pkt.Header.View(),
checker.SrcAddr(header.IPv6Any),
checker.DstAddr(snmc),
checker.TTL(header.NDPHopLimit),
@@ -602,7 +594,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate multiple nodes owning or
// attempting to own the same address.
hdr := test.makeBuf(addr1)
- e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -631,6 +623,12 @@ func TestDADFail(t *testing.T) {
if want := (tcpip.AddressWithPrefix{}); addr != want {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
+
+ // Attempting to add the address again should not fail if the address's
+ // state was cleaned up when DAD failed.
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
})
}
}
@@ -639,8 +637,9 @@ func TestDADStop(t *testing.T) {
const nicID = 1
tests := []struct {
- name string
- stopFn func(t *testing.T, s *stack.Stack)
+ name string
+ stopFn func(t *testing.T, s *stack.Stack)
+ skipFinalAddrCheck bool
}{
// Tests to make sure that DAD stops when an address is removed.
{
@@ -661,6 +660,19 @@ func TestDADStop(t *testing.T) {
}
},
},
+
+ // Tests to make sure that DAD stops when the NIC is removed.
+ {
+ name: "Remove NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("RemoveNIC(%d): %s", nicID, err)
+ }
+ },
+ // The NIC is removed so we can't check its addresses after calling
+ // stopFn.
+ skipFinalAddrCheck: true,
+ },
}
for _, test := range tests {
@@ -710,12 +722,15 @@ func TestDADStop(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+
+ if !test.skipFinalAddrCheck {
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
}
// Should not have sent more than 1 NS message.
@@ -901,7 +916,7 @@ func TestSetNDPConfigurations(t *testing.T) {
// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
// and DHCPv6 configurations specified.
-func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
@@ -936,14 +951,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
DstAddr: header.IPv6AllNodesMulticastAddress,
})
- return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()}
+ return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
}
// raBufWithOpts returns a valid NDP Router Advertisement with options.
//
// Note, raBufWithOpts does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
}
@@ -952,7 +967,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
//
// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
// DHCPv6 related ones.
-func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer {
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{})
}
@@ -960,7 +975,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo
//
// Note, raBuf does not populate any of the RA fields other than the
// Router Lifetime.
-func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer {
+func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer {
return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
}
@@ -969,7 +984,7 @@ func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer {
//
// Note, raBufWithPI does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer {
+func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer {
flags := uint8(0)
if onLink {
// The OnLink flag is the 7th bit in the flags byte.
@@ -1017,8 +1032,6 @@ func TestNoRouterDiscovery(t *testing.T) {
forwarding := i&4 == 0
t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
}
@@ -1057,8 +1070,6 @@ func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) str
// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
// remember a discovered router when the dispatcher asks it not to.
func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
}
@@ -1099,8 +1110,6 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
}
func TestRouterDiscovery(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
rememberRouter: true,
@@ -1202,8 +1211,6 @@ func TestRouterDiscovery(t *testing.T) {
// TestRouterDiscoveryMaxRouters tests that only
// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
func TestRouterDiscoveryMaxRouters(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
rememberRouter: true,
@@ -1270,8 +1277,6 @@ func TestNoPrefixDiscovery(t *testing.T) {
forwarding := i&4 == 0
t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
prefixC: make(chan ndpPrefixEvent, 1),
}
@@ -1311,8 +1316,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st
// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
// remember a discovered on-link prefix when the dispatcher asks it not to.
func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
- t.Parallel()
-
prefix, subnet, _ := prefixSubnetAddr(0, "")
ndpDisp := ndpDispatcher{
@@ -1356,8 +1359,6 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
}
func TestPrefixDiscovery(t *testing.T) {
- t.Parallel()
-
prefix1, subnet1, _ := prefixSubnetAddr(0, "")
prefix2, subnet2, _ := prefixSubnetAddr(1, "")
prefix3, subnet3, _ := prefixSubnetAddr(2, "")
@@ -1546,8 +1547,6 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// TestPrefixDiscoveryMaxRouters tests that only
// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
rememberPrefix: true,
@@ -1642,8 +1641,6 @@ func TestNoAutoGenAddr(t *testing.T) {
forwarding := i&4 == 0
t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
@@ -1968,7 +1965,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// addr2 is deprecated but if explicitly requested, it should be used.
fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address)
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
}
// Another PI w/ 0 preferred lifetime should not result in a deprecation
@@ -1981,7 +1978,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
}
expectPrimaryAddr(addr1)
if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address)
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
}
// Refresh lifetimes of addr generated from prefix2.
@@ -2093,7 +2090,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// addr1 is deprecated but if explicitly requested, it should be used.
fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
}
// Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
@@ -2106,7 +2103,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
}
expectPrimaryAddr(addr2)
if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
}
// Refresh lifetimes for addr of prefix1.
@@ -2130,7 +2127,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// addr2 should be the primary endpoint now since it is not deprecated.
expectPrimaryAddr(addr2)
if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
}
// Wait for addr of prefix1 to be invalidated.
@@ -2393,8 +2390,6 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
},
}
- const delta = 500 * time.Millisecond
-
// This Run will not return until the parallel tests finish.
//
// We need this because we need to do some teardown work after the
@@ -2447,24 +2442,21 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
// to test.evl.
//
- // Make sure we do not get any invalidation
- // events until atleast 500ms (delta) before
- // test.evl.
+ // The address should not be invalidated until the effective valid
+ // lifetime has passed.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(time.Duration(test.evl)*time.Second - delta):
+ case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncEventTimeout):
}
- // Wait for another second (2x delta), but now
- // we expect the invalidation event.
+ // Wait for the invalidation event.
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
-
- case <-time.After(2 * delta):
+ case <-time.After(2 * defaultAsyncEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -2476,8 +2468,6 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
// by the user, its resources will be cleaned up and an invalidation event will
// be sent to the integrator.
func TestAutoGenAddrRemoval(t *testing.T) {
- t.Parallel()
-
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
ndpDisp := ndpDispatcher{
@@ -2534,8 +2524,6 @@ func TestAutoGenAddrRemoval(t *testing.T) {
// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously
// assigned to the NIC but is in the permanentExpired state.
func TestAutoGenAddrAfterRemoval(t *testing.T) {
- t.Parallel()
-
const nicID = 1
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
@@ -2582,7 +2570,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
AddressWithPrefix: addr2,
}
if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
}
// addr2 should be more preferred now since it is at the front of the primary
// list.
@@ -2647,8 +2635,6 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
// is already assigned to the NIC, the static address remains.
func TestAutoGenAddrStaticConflict(t *testing.T) {
- t.Parallel()
-
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
ndpDisp := ndpDispatcher{
@@ -2704,8 +2690,6 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use
// opaque interface identifiers when configured to do so.
func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
- t.Parallel()
-
const nicID = 1
const nicName = "nic1"
var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
@@ -2805,12 +2789,465 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
}
}
+// TestAutoGenAddrWithOpaqueIIDDADRetries tests the regeneration of an
+// auto-generated IPv6 address in response to a DAD conflict.
+func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) {
+ const nicID = 1
+ const nicName = "nic"
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const maxMaxRetries = 3
+ const lifetimeSeconds = 10
+
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ {
+ for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ {
+ addrTypes := []struct {
+ name string
+ ndpConfigs stack.NDPConfigurations
+ autoGenLinkLocal bool
+ subnet tcpip.Subnet
+ triggerSLAACFn func(e *channel.Endpoint)
+ }{
+ {
+ name: "Global address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ subnet: subnet,
+ triggerSLAACFn: func(e *channel.Endpoint) {
+ // Receive an RA with prefix1 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+
+ },
+ },
+ {
+ name: "LinkLocal address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ autoGenLinkLocal: true,
+ subnet: header.IPv6LinkLocalPrefix.Subnet(),
+ triggerSLAACFn: func(e *channel.Endpoint) {},
+ },
+ }
+
+ for _, addrType := range addrTypes {
+ maxRetries := maxRetries
+ numFailures := numFailures
+ addrType := addrType
+
+ t.Run(fmt.Sprintf("%s with %d max retries and %d failures", addrType.name, maxRetries, numFailures), func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ addrType.triggerSLAACFn(e)
+
+ // Simulate DAD conflicts so the address is regenerated.
+ for i := uint8(0); i < numFailures; i++ {
+ addrBytes := []byte(addrType.subnet.ID())
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, i, secretKey)),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Should not have any addresses assigned to the NIC.
+ mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); mainAddr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want)
+ }
+
+ // Simulate a DAD conflict.
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+
+ // Attempting to add the address manually should not fail if the
+ // address's state was cleaned up when DAD failed.
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
+ }
+ if err := s.RemoveAddress(nicID, addr.Address); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
+ }
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+ }
+
+ // Should not have any addresses assigned to the NIC.
+ mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); mainAddr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want)
+ }
+
+ // If we had less failures than generation attempts, we should have an
+ // address after DAD resolves.
+ if maxRetries+1 > numFailures {
+ addrBytes := []byte(addrType.subnet.ID())
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, numFailures, secretKey)),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+
+ mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+ if mainAddr != addr {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, addr)
+ }
+ }
+
+ // Should not attempt address regeneration again.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
+ case <-time.After(defaultAsyncEventTimeout):
+ }
+ })
+ }
+ }
+ }
+}
+
+// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is
+// not made for SLAAC addresses generated with an IID based on the NIC's link
+// address.
+func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
+ const nicID = 1
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const maxRetries = 3
+ const lifetimeSeconds = 10
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ addrTypes := []struct {
+ name string
+ ndpConfigs stack.NDPConfigurations
+ autoGenLinkLocal bool
+ subnet tcpip.Subnet
+ triggerSLAACFn func(e *channel.Endpoint)
+ }{
+ {
+ name: "Global address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ subnet: subnet,
+ triggerSLAACFn: func(e *channel.Endpoint) {
+ // Receive an RA with prefix1 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+
+ },
+ },
+ {
+ name: "LinkLocal address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ autoGenLinkLocal: true,
+ subnet: header.IPv6LinkLocalPrefix.Subnet(),
+ triggerSLAACFn: func(e *channel.Endpoint) {},
+ },
+ }
+
+ for _, addrType := range addrTypes {
+ addrType := addrType
+
+ t.Run(addrType.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ addrType.triggerSLAACFn(e)
+
+ addrBytes := []byte(addrType.subnet.ID())
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:])
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Simulate a DAD conflict.
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+
+ // Should not attempt address regeneration.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
+ case <-time.After(defaultAsyncEventTimeout):
+ }
+ })
+ }
+}
+
+// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address
+// generation in response to DAD conflicts does not refresh the lifetimes.
+func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
+ const nicID = 1
+ const nicName = "nic"
+ const dadTransmits = 1
+ const retransmitTimer = 2 * time.Second
+ const failureTimer = time.Second
+ const maxRetries = 1
+ const lifetimeSeconds = 5
+
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+
+ addrBytes := []byte(subnet.ID())
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Simulate a DAD conflict after some time has passed.
+ time.Sleep(failureTimer)
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+
+ // Let the next address resolve.
+ addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey))
+ expectAutoGenAddrEvent(addr, newAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+
+ // Address should be deprecated/invalidated after the lifetime expires.
+ //
+ // Note, the remaining lifetime is calculated from when the PI was first
+ // processed. Since we wait for some time before simulating a DAD conflict
+ // and more time for the new address to resolve, the new address is only
+ // expected to be valid for the remaining time. The DAD conflict should
+ // not have reset the lifetimes.
+ //
+ // We expect either just the invalidation event or the deprecation event
+ // followed by the invalidation event.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if e.eventType == deprecatedAddr {
+ if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation")
+ }
+ } else {
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for auto gen addr event")
+ }
+}
+
// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
// to the integrator when an RA is received with the NDP Recursive DNS Server
// option with at least one valid address.
func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
- t.Parallel()
-
tests := []struct {
name string
opt header.NDPRecursiveDNSServer
@@ -2902,11 +3339,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
for _, test := range tests {
- test := test
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
// We do not expect more than a single RDNSS
// event at any time for this test.
@@ -2956,8 +3389,6 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
// TestCleanupNDPState tests that all discovered routers and prefixes, and
// auto-generated addresses are invalidated when a NIC becomes a router.
func TestCleanupNDPState(t *testing.T) {
- t.Parallel()
-
const (
lifetimeSeconds = 5
maxRouterAndPrefixEvents = 4
@@ -2983,11 +3414,12 @@ func TestCleanupNDPState(t *testing.T) {
cleanupFn func(t *testing.T, s *stack.Stack)
keepAutoGenLinkLocal bool
maxAutoGenAddrEvents int
+ skipFinalAddrCheck bool
}{
// A NIC should still keep its auto-generated link-local address when
// becoming a router.
{
- name: "Forwarding Enable",
+ name: "Enable forwarding",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
s.SetForwarding(true)
@@ -2998,7 +3430,7 @@ func TestCleanupNDPState(t *testing.T) {
// A NIC should cleanup all NDP state when it is disabled.
{
- name: "NIC Disable",
+ name: "Disable NIC",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
@@ -3012,6 +3444,26 @@ func TestCleanupNDPState(t *testing.T) {
keepAutoGenLinkLocal: false,
maxAutoGenAddrEvents: 6,
},
+
+ // A NIC should cleanup all NDP state when it is removed.
+ {
+ name: "Remove NIC",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ if err := s.RemoveNIC(nicID1); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err)
+ }
+ if err := s.RemoveNIC(nicID2); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ // The NICs are removed so we can't check their addresses after calling
+ // stopFn.
+ skipFinalAddrCheck: true,
+ },
}
for _, test := range tests {
@@ -3230,35 +3682,37 @@ func TestCleanupNDPState(t *testing.T) {
t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
}
- // Make sure the auto-generated addresses got removed.
- nicinfo = s.NICInfo()
- nic1Addrs = nicinfo[nicID1].ProtocolAddresses
- nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
- if test.keepAutoGenLinkLocal {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- } else {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ if !test.skipFinalAddrCheck {
+ // Make sure the auto-generated addresses got removed.
+ nicinfo = s.NICInfo()
+ nic1Addrs = nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs = nicinfo[nicID2].ProtocolAddresses
+ if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
}
- }
- if containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
- if test.keepAutoGenLinkLocal {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- } else {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ if containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ }
+ if containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
- }
- if containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
// Should not get any more events (invalidation timers should have been
@@ -3377,13 +3831,15 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// TestRouterSolicitation tests the initial Router Solicitations that are sent
// when a NIC newly becomes enabled.
func TestRouterSolicitation(t *testing.T) {
- t.Parallel()
-
const nicID = 1
tests := []struct {
name string
linkHeaderLen uint16
+ linkAddr tcpip.LinkAddress
+ nicAddr tcpip.Address
+ expectedSrcAddr tcpip.Address
+ expectedNDPOpts []header.NDPOption
maxRtrSolicit uint8
rtrSolicitInt time.Duration
effectiveRtrSolicitInt time.Duration
@@ -3391,34 +3847,54 @@ func TestRouterSolicitation(t *testing.T) {
effectiveMaxRtrSolicitDelay time.Duration
}{
{
- name: "Single RS with delay",
+ name: "Single RS with 2s delay and interval",
+ expectedSrcAddr: header.IPv6Any,
maxRtrSolicit: 1,
- rtrSolicitInt: time.Second,
- effectiveRtrSolicitInt: time.Second,
- maxRtrSolicitDelay: time.Second,
- effectiveMaxRtrSolicitDelay: time.Second,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
+ maxRtrSolicitDelay: 2 * time.Second,
+ effectiveMaxRtrSolicitDelay: 2 * time.Second,
+ },
+ {
+ name: "Single RS with 4s delay and interval",
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 1,
+ rtrSolicitInt: 4 * time.Second,
+ effectiveRtrSolicitInt: 4 * time.Second,
+ maxRtrSolicitDelay: 4 * time.Second,
+ effectiveMaxRtrSolicitDelay: 4 * time.Second,
},
{
name: "Two RS with delay",
linkHeaderLen: 1,
+ nicAddr: llAddr1,
+ expectedSrcAddr: llAddr1,
maxRtrSolicit: 2,
- rtrSolicitInt: time.Second,
- effectiveRtrSolicitInt: time.Second,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
maxRtrSolicitDelay: 500 * time.Millisecond,
effectiveMaxRtrSolicitDelay: 500 * time.Millisecond,
},
{
- name: "Single RS without delay",
- linkHeaderLen: 2,
+ name: "Single RS without delay",
+ linkHeaderLen: 2,
+ linkAddr: linkAddr1,
+ nicAddr: llAddr1,
+ expectedSrcAddr: llAddr1,
+ expectedNDPOpts: []header.NDPOption{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ },
maxRtrSolicit: 1,
- rtrSolicitInt: time.Second,
- effectiveRtrSolicitInt: time.Second,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
maxRtrSolicitDelay: 0,
effectiveMaxRtrSolicitDelay: 0,
},
{
name: "Two RS without delay and invalid zero interval",
linkHeaderLen: 3,
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
maxRtrSolicit: 2,
rtrSolicitInt: 0,
effectiveRtrSolicitInt: 4 * time.Second,
@@ -3427,6 +3903,8 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Three RS without delay",
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
maxRtrSolicit: 3,
rtrSolicitInt: 500 * time.Millisecond,
effectiveRtrSolicitInt: 500 * time.Millisecond,
@@ -3435,6 +3913,8 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS with invalid negative delay",
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
maxRtrSolicit: 2,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3456,14 +3936,16 @@ func TestRouterSolicitation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
+
e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1),
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
headerLength: test.linkHeaderLen,
}
e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
waitForPkt := func(timeout time.Duration) {
t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), timeout)
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
p, ok := e.ReadContext(ctx)
if !ok {
t.Fatal("timed out waiting for packet")
@@ -3481,10 +3963,10 @@ func TestRouterSolicitation(t *testing.T) {
checker.IPv6(t,
p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
+ checker.SrcAddr(test.expectedSrcAddr),
checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
checker.TTL(header.NDPHopLimit),
- checker.NDPRS(),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
)
if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
@@ -3493,7 +3975,8 @@ func TestRouterSolicitation(t *testing.T) {
}
waitForNothing := func(timeout time.Duration) {
t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), timeout)
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet")
}
@@ -3510,23 +3993,33 @@ func TestRouterSolicitation(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- // Make sure each RS got sent at the right
- // times.
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ }
+ }
+
+ // Make sure each RS is sent at the right time.
remaining := test.maxRtrSolicit
if remaining > 0 {
waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
remaining--
}
+
for ; remaining > 0; remaining-- {
- waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout)
- waitForPkt(defaultAsyncEventTimeout)
+ if test.effectiveRtrSolicitInt > defaultAsyncEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncEventTimeout)
+ waitForPkt(2 * defaultAsyncEventTimeout)
+ } else {
+ waitForPkt(test.effectiveRtrSolicitInt * defaultAsyncEventTimeout)
+ }
}
// Make sure no more RS.
if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultTimeout)
+ waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncEventTimeout)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultTimeout)
+ waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
}
// Make sure the counter got properly
@@ -3540,27 +4033,27 @@ func TestRouterSolicitation(t *testing.T) {
}
func TestStopStartSolicitingRouters(t *testing.T) {
- t.Parallel()
-
const nicID = 1
+ const delay = 0
const interval = 500 * time.Millisecond
- const delay = time.Second
const maxRtrSolicitations = 3
tests := []struct {
name string
startFn func(t *testing.T, s *stack.Stack)
- stopFn func(t *testing.T, s *stack.Stack)
+ // first is used to tell stopFn that it is being called for the first time
+ // after router solicitations were last enabled.
+ stopFn func(t *testing.T, s *stack.Stack, first bool)
}{
// Tests that when forwarding is enabled or disabled, router solicitations
// are stopped or started, respectively.
{
- name: "Forwarding enabled and disabled",
+ name: "Enable and disable forwarding",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
s.SetForwarding(false)
},
- stopFn: func(t *testing.T, s *stack.Stack) {
+ stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
s.SetForwarding(true)
},
@@ -3569,7 +4062,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Tests that when a NIC is enabled or disabled, router solicitations
// are started or stopped, respectively.
{
- name: "NIC disabled and enabled",
+ name: "Enable and disable NIC",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
@@ -3577,7 +4070,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
}
},
- stopFn: func(t *testing.T, s *stack.Stack) {
+ stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
if err := s.DisableNIC(nicID); err != nil {
@@ -3585,6 +4078,25 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
},
},
+
+ // Tests that when a NIC is removed, router solicitations are stopped. We
+ // cannot start router solications on a removed NIC.
+ {
+ name: "Remove NIC",
+ stopFn: func(t *testing.T, s *stack.Stack, first bool) {
+ t.Helper()
+
+ // Only try to remove the NIC the first time stopFn is called since it's
+ // impossible to remove an already removed NIC.
+ if !first {
+ return
+ }
+
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
for _, test := range tests {
@@ -3598,7 +4110,6 @@ func TestStopStartSolicitingRouters(t *testing.T) {
p, ok := e.ReadContext(ctx)
if !ok {
t.Fatal("timed out waiting for packet")
- return
}
if p.Proto != header.IPv6ProtocolNumber {
@@ -3623,12 +4134,12 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
// Stop soliciting routers.
- test.stopFn(t, s)
- ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ test.stopFn(t, s, true /* first */)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
- // A single RS may have been sent before forwarding was enabled.
- ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout)
+ // A single RS may have been sent before solicitations were stopped.
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
defer cancel()
if _, ok = e.ReadContext(ctx); ok {
t.Fatal("should not have sent more than one RS message")
@@ -3637,19 +4148,24 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stopping router solicitations after it has already been stopped should
// do nothing.
- test.stopFn(t, s)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ test.stopFn(t, s, false /* first */)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
}
+ // If test.startFn is nil, there is no way to restart router solications.
+ if test.startFn == nil {
+ return
+ }
+
// Start soliciting routers.
test.startFn(t, s)
waitForPkt(delay + defaultAsyncEventTimeout)
waitForPkt(interval + defaultAsyncEventTimeout)
waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
@@ -3658,7 +4174,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Starting router solicitations after it has already completed should do
// nothing.
test.startFn(t, s)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after finishing router solicitations")
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 46d3a6646..016dbe15e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -15,7 +15,7 @@
package stack
import (
- "log"
+ "fmt"
"reflect"
"sort"
"strings"
@@ -54,7 +54,7 @@ type NIC struct {
primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
addressRanges []tcpip.Subnet
- mcastJoins map[NetworkEndpointID]int32
+ mcastJoins map[NetworkEndpointID]uint32
// packetEPs is protected by mu, but the contained PacketEndpoint
// values are not.
packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
@@ -121,15 +121,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
}
nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
- nic.mu.mcastJoins = make(map[NetworkEndpointID]int32)
+ nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
nic.mu.ndp = ndpState{
- nic: nic,
- configs: stack.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- autoGenAddresses: make(map[tcpip.Address]autoGenAddressState),
+ nic: nic,
+ configs: stack.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
}
// Register supported packet endpoint protocols.
@@ -165,8 +165,17 @@ func (n *NIC) disable() *tcpip.Error {
}
n.mu.Lock()
- defer n.mu.Unlock()
+ err := n.disableLocked()
+ n.mu.Unlock()
+ return err
+}
+// disableLocked disables n.
+//
+// It undoes the work done by enable.
+//
+// n MUST be locked.
+func (n *NIC) disableLocked() *tcpip.Error {
if !n.mu.enabled {
return nil
}
@@ -189,7 +198,7 @@ func (n *NIC) disable() *tcpip.Error {
}
// The NIC may have already left the multicast group.
- if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
}
@@ -305,24 +314,33 @@ func (n *NIC) remove() *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- // Detach from link endpoint, so no packet comes in.
- n.linkEP.Attach(nil)
+ n.disableLocked()
+
+ // TODO(b/151378115): come up with a better way to pick an error than the
+ // first one.
+ var err *tcpip.Error
+
+ // Forcefully leave multicast groups.
+ for nid := range n.mu.mcastJoins {
+ if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil {
+ err = tempErr
+ }
+ }
// Remove permanent and permanentTentative addresses, so no packet goes out.
- var errs []*tcpip.Error
for nid, ref := range n.mu.endpoints {
switch ref.getKind() {
case permanentTentative, permanent:
- if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil {
- errs = append(errs, err)
+ if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil {
+ err = tempErr
}
}
}
- if len(errs) > 0 {
- return errs[0]
- }
- return nil
+ // Detach from link endpoint, so no packet comes in.
+ n.linkEP.Attach(nil)
+
+ return err
}
// becomeIPv6Router transitions n into an IPv6 router.
@@ -451,7 +469,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
for _, r := range primaryAddrs {
// If r is not valid for outgoing connections, it is not a valid endpoint.
- if !r.isValidForOutgoing() {
+ if !r.isValidForOutgoingRLocked() {
continue
}
@@ -461,7 +479,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
// Should never happen as we got r from the primary IPv6 endpoint list and
// ScopeForIPv6Address only returns an error if addr is not an IPv6
// address.
- log.Fatalf("header.ScopeForIPv6Address(%s): %s", addr, err)
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
}
cs = append(cs, ipv6AddrCandidate{
@@ -473,7 +491,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
if err != nil {
// primaryIPv6Endpoint should never be called with an invalid IPv6 address.
- log.Fatalf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
}
// Sort the addresses as per RFC 6724 section 5 rules 1-3.
@@ -969,6 +987,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
for i, ref := range refs {
if ref == r {
n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ refs[len(refs)-1] = nil
break
}
}
@@ -993,35 +1012,42 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadLocalAddress
}
- isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr)
+ switch r.protocol {
+ case header.IPv6ProtocolNumber:
+ return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAAPrefixInvalidation */)
+ default:
+ r.expireLocked()
+ return nil
+ }
+}
+
+func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACPrefixInvalidation bool) *tcpip.Error {
+ addr := r.addrWithPrefix()
+
+ isIPv6Unicast := header.IsV6UnicastAddress(addr.Address)
if isIPv6Unicast {
- // If we are removing a tentative IPv6 unicast address, stop
- // DAD.
- if kind == permanentTentative {
- n.mu.ndp.stopDuplicateAddressDetection(addr)
- }
+ n.mu.ndp.stopDuplicateAddressDetection(addr.Address)
// If we are removing an address generated via SLAAC, cleanup
// its SLAAC resources and notify the integrator.
if r.configType == slaac {
- n.mu.ndp.cleanupAutoGenAddrResourcesAndNotify(addr)
+ n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACPrefixInvalidation)
}
}
- r.setKind(permanentExpired)
- if !r.decRefLocked() {
- // The endpoint still has references to it.
- return nil
- }
+ r.expireLocked()
// At this point the endpoint is deleted.
// If we are removing an IPv6 unicast address, leave the solicited-node
// multicast address.
+ //
+ // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node
+ // multicast group may be left by user action.
if isIPv6Unicast {
- snmc := header.SolicitedNodeAddr(addr)
- if err := n.leaveGroupLocked(snmc); err != nil {
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
}
@@ -1081,26 +1107,31 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- return n.leaveGroupLocked(addr)
+ return n.leaveGroupLocked(addr, false /* force */)
}
// leaveGroupLocked decrements the count for the given multicast address, and
// when it reaches zero removes the endpoint for this address. n MUST be locked
// before leaveGroupLocked is called.
-func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+//
+// If force is true, then the count for the multicast addres is ignored and the
+// endpoint will be removed immediately.
+func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error {
id := NetworkEndpointID{addr}
- joins := n.mu.mcastJoins[id]
- switch joins {
- case 0:
+ joins, ok := n.mu.mcastJoins[id]
+ if !ok {
// There are no joins with this address on this NIC.
return tcpip.ErrBadLocalAddress
- case 1:
- // This is the last one, clean up.
- if err := n.removePermanentAddressLocked(addr); err != nil {
- return err
- }
}
- n.mu.mcastJoins[id] = joins - 1
+
+ joins--
+ if force || joins == 0 {
+ // There are no outstanding joins or we are forced to leave, clean up.
+ delete(n.mu.mcastJoins, id)
+ return n.removePermanentAddressLocked(addr)
+ }
+
+ n.mu.mcastJoins[id] = joins
return nil
}
@@ -1113,9 +1144,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
return joins != 0
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) {
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
+
ref.ep.HandlePacket(&r, pkt)
ref.decRef()
}
@@ -1126,7 +1158,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
n.mu.RLock()
enabled := n.mu.enabled
// If the NIC is not yet enabled, don't receive any packets.
@@ -1186,6 +1218,16 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
+
+ // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
+ if protocol == header.IPv4ProtocolNumber {
+ ipt := n.stack.IPTables()
+ if ok := ipt.Check(Prerouting, pkt); !ok {
+ // iptables is telling us to drop the packet.
+ return
+ }
+ }
+
if ref := n.getRef(protocol, dst); ref != nil {
handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt)
return
@@ -1201,10 +1243,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
return
}
- defer r.Release()
-
- r.LocalLinkAddress = n.linkEP.LinkAddress()
- r.RemoteLinkAddress = remote
// Found a NIC.
n := r.ref.nic
@@ -1213,24 +1251,33 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef()
n.mu.RUnlock()
if ok {
+ r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.RemoteLinkAddress = remote
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
ref.ep.HandlePacket(&r, pkt)
ref.decRef()
- } else {
- // n doesn't have a destination endpoint.
- // Send the packet out of n.
- pkt.Header = buffer.NewPrependableFromView(pkt.Data.First())
- pkt.Data.RemoveFirst()
-
- // TODO(b/128629022): use route.WritePacket.
- if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- } else {
- n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
+ r.Release()
+ return
+ }
+
+ // n doesn't have a destination endpoint.
+ // Send the packet out of n.
+ // TODO(b/128629022): move this logic to route.WritePacket.
+ if ch, err := r.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
+ // forwarder will release route.
+ return
}
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
+ r.Release()
+ return
}
+
+ // The link-address resolution finished immediately.
+ n.forwardPacket(&r, protocol, pkt)
+ r.Release()
return
}
@@ -1240,9 +1287,38 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
}
+func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+
+ firstData := pkt.Data.First()
+ pkt.Data.RemoveFirst()
+
+ if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 {
+ pkt.Header = buffer.NewPrependableFromView(firstData)
+ } else {
+ firstDataLen := len(firstData)
+
+ // pkt.Header should have enough capacity to hold n.linkEP's headers.
+ pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen)
+
+ // TODO(b/151227689): avoid copying the packet when forwarding
+ if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen))
+ }
+ }
+
+ if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return
+ }
+
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
+}
+
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -1288,7 +1364,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// DeliverTransportControlPacket delivers control packets to the appropriate
// transport protocol endpoint.
-func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) {
state, ok := n.stack.transportProtocols[trans]
if !ok {
return
@@ -1351,10 +1427,12 @@ func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
return ref.getKind() == permanentTentative
}
-// dupTentativeAddrDetected attempts to inform n that a tentative addr
-// is a duplicate on a link.
+// dupTentativeAddrDetected attempts to inform n that a tentative addr is a
+// duplicate on a link.
//
-// dupTentativeAddrDetected will delete the tentative address if it exists.
+// dupTentativeAddrDetected will remove the tentative address if it exists. If
+// the address was generated via SLAAC, an attempt will be made to generate a
+// new address.
func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -1368,7 +1446,17 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- return n.removePermanentAddressLocked(addr)
+ // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a
+ // new address will be generated for it.
+ if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACPrefixInvalidation */); err != nil {
+ return err
+ }
+
+ if ref.configType == slaac {
+ n.mu.ndp.regenerateSLAACAddr(ref.addrWithPrefix().Subnet())
+ }
+
+ return nil
}
// setNDPConfigs sets the NDP configurations for n.
@@ -1496,6 +1584,13 @@ type referencedNetworkEndpoint struct {
deprecated bool
}
+func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix {
+ return tcpip.AddressWithPrefix{
+ Address: r.ep.ID().LocalAddress,
+ PrefixLen: r.ep.PrefixLen(),
+ }
+}
+
func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind)))
}
@@ -1523,6 +1618,13 @@ func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing)
}
+// expireLocked decrements the reference count and marks the permanent endpoint
+// as expired.
+func (r *referencedNetworkEndpoint) expireLocked() {
+ r.setKind(permanentExpired)
+ r.decRefLocked()
+}
+
// decRef decrements the ref count and cleans up the endpoint once it reaches
// zero.
func (r *referencedNetworkEndpoint) decRef() {
@@ -1532,14 +1634,11 @@ func (r *referencedNetworkEndpoint) decRef() {
}
// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
-// locked. Returns true if the endpoint was removed.
-func (r *referencedNetworkEndpoint) decRefLocked() bool {
+// locked.
+func (r *referencedNetworkEndpoint) decRefLocked() {
if atomic.AddInt32(&r.refs, -1) == 0 {
r.nic.removeEndpointLocked(r)
- return true
}
-
- return false
}
// incRef increments the ref count. It must only be called when the caller is
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index edaee3b86..d672fc157 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -17,7 +17,6 @@ package stack
import (
"testing"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -45,7 +44,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
t.FailNow()
}
- nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+ nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index ab24372e7..dc125f25e 100644
--- a/pkg/tcpip/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -11,18 +11,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package tcpip
+package stack
-import "gvisor.dev/gvisor/pkg/tcpip/buffer"
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
// A PacketBuffer contains all the data of a network packet.
//
// As a PacketBuffer traverses up the stack, it may be necessary to pass it to
// multiple endpoints. Clone() should be called in such cases so that
// modifications to the Data field do not affect other copies.
-//
-// +stateify savable
type PacketBuffer struct {
+ // PacketBufferEntry is used to build an intrusive list of
+ // PacketBuffers.
+ PacketBufferEntry
+
// Data holds the payload of the packet. For inbound packets, it also
// holds the headers, which are consumed as the packet moves up the
// stack. Headers are guaranteed not to be split across views.
@@ -31,14 +36,6 @@ type PacketBuffer struct {
// or otherwise modified.
Data buffer.VectorisedView
- // DataOffset is used for GSO output. It is the offset into the Data
- // field where the payload of this packet starts.
- DataOffset int
-
- // DataSize is used for GSO output. It is the size of this packet's
- // payload.
- DataSize int
-
// Header holds the headers of outbound packets. As a packet is passed
// down the stack, each layer adds to Header.
Header buffer.Prependable
@@ -55,6 +52,14 @@ type PacketBuffer struct {
LinkHeader buffer.View
NetworkHeader buffer.View
TransportHeader buffer.View
+
+ // Hash is the transport layer hash of this packet. A value of zero
+ // indicates no valid hash has been set.
+ Hash uint32
+
+ // Owner is implemented by task to get the uid and gid.
+ // Only set for locally generated packets.
+ Owner tcpip.PacketOwner
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go
new file mode 100644
index 000000000..421fb5c15
--- /dev/null
+++ b/pkg/tcpip/stack/rand.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ mathrand "math/rand"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// lockedRandomSource provides a threadsafe rand.Source.
+type lockedRandomSource struct {
+ mu sync.Mutex
+ src mathrand.Source
+}
+
+func (r *lockedRandomSource) Int63() (n int64) {
+ r.mu.Lock()
+ n = r.src.Int63()
+ r.mu.Unlock()
+ return n
+}
+
+func (r *lockedRandomSource) Seed(seed int64) {
+ r.mu.Lock()
+ r.src.Seed(seed)
+ r.mu.Unlock()
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index f9fd8f18f..23ca9ee03 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -67,12 +67,12 @@ type TransportEndpoint interface {
// this transport endpoint. It sets pkt.TransportHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer)
+ HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer)
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
// HandleControlPacket takes ownership of pkt.
- HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer)
// Abort initiates an expedited endpoint teardown. It puts the endpoint
// in a closed state and frees all resources associated with it. This
@@ -100,7 +100,7 @@ type RawTransportEndpoint interface {
// layer up.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt tcpip.PacketBuffer)
+ HandlePacket(r *Route, pkt PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -118,7 +118,7 @@ type PacketEndpoint interface {
// should construct its own ethernet header for applications.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
+ HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer)
}
// TransportProtocol is the interface that needs to be implemented by transport
@@ -150,7 +150,7 @@ type TransportProtocol interface {
// stats purposes only).
//
// HandleUnknownDestinationPacket takes ownership of pkt.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -180,7 +180,7 @@ type TransportDispatcher interface {
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
// DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer)
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -189,7 +189,7 @@ type TransportDispatcher interface {
// DeliverTransportControlPacket.
//
// DeliverTransportControlPacket takes ownership of pkt.
- DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
@@ -242,15 +242,15 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have
// already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length.
- WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
- WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error
+ WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
@@ -265,7 +265,7 @@ type NetworkEndpoint interface {
// this network endpoint. It sets pkt.NetworkHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt tcpip.PacketBuffer)
+ HandlePacket(r *Route, pkt PacketBuffer)
// Close is called when the endpoint is reomved from a stack.
Close()
@@ -322,7 +322,7 @@ type NetworkDispatcher interface {
// packets sent via loopback), and won't have the field set.
//
// DeliverNetworkPacket takes ownership of pkt.
- DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
+ DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -354,7 +354,7 @@ const (
// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
// ethernet, loopback, raw) and used by network layer protocols to send packets
// out through the implementer's data link endpoint. When a link header exists,
-// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the
+// it sets each PacketBuffer's LinkHeader field before passing it up the
// stack.
type LinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
@@ -385,7 +385,7 @@ type LinkEndpoint interface {
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error
// WritePackets writes packets with the given protocol through the
// given route. pkts must not be zero length.
@@ -393,7 +393,7 @@ type LinkEndpoint interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, it may
// require to change syscall filters.
- WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
// WriteRawPacket writes a packet directly to the link. The packet
// should already have an ethernet header.
@@ -401,6 +401,9 @@ type LinkEndpoint interface {
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
+ //
+ // Attach will be called with a nil dispatcher if the receiver's associated
+ // NIC is being removed.
Attach(dispatcher NetworkDispatcher)
// IsAttached returns whether a NetworkDispatcher is attached to the
@@ -423,7 +426,7 @@ type InjectableLinkEndpoint interface {
LinkEndpoint
// InjectInbound injects an inbound packet.
- InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
// InjectOutbound writes a fully formed outbound packet directly to the
// link.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index f565aafb2..a0e5e0300 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -153,7 +153,7 @@ func (r *Route) IsResolutionRequired() bool {
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
@@ -168,29 +168,32 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack
return err
}
-// WritePackets writes the set of packets through the given route.
-func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
+// WritePackets writes a list of n packets through the given route and returns
+// the number of packets written.
+func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
if !r.ref.isValidForOutgoing() {
return 0, tcpip.ErrInvalidEndpointState
}
n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
}
r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
- payloadSize := 0
- for i := 0; i < n; i++ {
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength()))
- payloadSize += pkts[i].DataSize
+
+ writtenBytes := 0
+ for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
+ writtenBytes += pb.Header.UsedLength()
+ writtenBytes += pb.Data.Size()
}
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize))
+
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
return n, err
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error {
+func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index ebb6c5e3b..41398a1b6 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -20,7 +20,9 @@
package stack
import (
+ "bytes"
"encoding/binary"
+ mathrand "math/rand"
"sync/atomic"
"time"
@@ -31,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/waiter"
@@ -51,7 +52,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool
+ defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -428,7 +429,7 @@ type Stack struct {
// tables are the iptables packet filtering and manipulation rules. The are
// protected by tablesMu.`
- tables iptables.IPTables
+ tables IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
// stack is being restored.
@@ -462,6 +463,14 @@ type Stack struct {
// opaqueIIDOpts hold the options for generating opaque interface identifiers
// (IIDs) as outlined by RFC 7217.
opaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // forwarder holds the packets that wait for their link-address resolutions
+ // to complete, and forwards them when each resolution is done.
+ forwarder *forwardQueue
+
+ // randomGenerator is an injectable pseudo random generator that can be
+ // used when a random number is required.
+ randomGenerator *mathrand.Rand
}
// UniqueID is an abstract generator of unique identifiers.
@@ -522,9 +531,16 @@ type Options struct {
// this is non-nil.
RawFactory RawFactory
- // OpaqueIIDOpts hold the options for generating opaque interface identifiers
- // (IIDs) as outlined by RFC 7217.
+ // OpaqueIIDOpts hold the options for generating opaque interface
+ // identifiers (IIDs) as outlined by RFC 7217.
OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // RandSource is an optional source to use to generate random
+ // numbers. If omitted it defaults to a Source seeded by the data
+ // returned by rand.Read().
+ //
+ // RandSource must be thread-safe.
+ RandSource mathrand.Source
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -551,11 +567,13 @@ type TransportEndpointInfo struct {
RegisterNICID tcpip.NICID
}
-// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address
-// and returns the network protocol number to be used to communicate with the
-// specified address. It returns an error if the passed address is incompatible
-// with the receiver.
-func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6
+// address and returns the network protocol number to be used to communicate
+// with the specified address. It returns an error if the passed address is
+// incompatible with the receiver.
+//
+// Preconditon: the parent endpoint mu must be held while calling this method.
+func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto := e.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
@@ -618,6 +636,13 @@ func New(opts Options) *Stack {
opts.UniqueID = new(uniqueIDGenerator)
}
+ randSrc := opts.RandSource
+ if randSrc == nil {
+ // Source provided by mathrand.NewSource is not thread-safe so
+ // we wrap it in a simple thread-safe version.
+ randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
+ }
+
// Make sure opts.NDPConfigs contains valid values only.
opts.NDPConfigs.validate()
@@ -639,6 +664,8 @@ func New(opts Options) *Stack {
uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
opaqueIIDOpts: opts.OpaqueIIDOpts,
+ forwarder: newForwardQueue(),
+ randomGenerator: mathrand.New(randSrc),
}
// Add specified network protocols.
@@ -731,7 +758,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -1694,7 +1721,7 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool,
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() iptables.IPTables {
+func (s *Stack) IPTables() IPTables {
s.tablesMu.RLock()
t := s.tables
s.tablesMu.RUnlock()
@@ -1702,7 +1729,7 @@ func (s *Stack) IPTables() iptables.IPTables {
}
// SetIPTables sets the stack's iptables.
-func (s *Stack) SetIPTables(ipt iptables.IPTables) {
+func (s *Stack) SetIPTables(ipt IPTables) {
s.tablesMu.Lock()
s.tables = ipt
s.tablesMu.Unlock()
@@ -1812,6 +1839,12 @@ func (s *Stack) Seed() uint32 {
return s.seed
}
+// Rand returns a reference to a pseudo random generator that can be used
+// to generate random numbers as required.
+func (s *Stack) Rand() *mathrand.Rand {
+ return s.randomGenerator
+}
+
func generateRandUint32() uint32 {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
@@ -1819,3 +1852,16 @@ func generateRandUint32() uint32 {
}
return binary.LittleEndian.Uint32(b)
}
+
+func generateRandInt64() int64 {
+ b := make([]byte, 8)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ buf := bytes.NewReader(b)
+ var v int64
+ if err := binary.Read(buf, binary.LittleEndian, &v); err != nil {
+ panic(err)
+ }
+ return v
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index e15db40fb..c7634ceb1 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -90,7 +90,7 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
@@ -126,7 +126,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.ep.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
- f.HandlePacket(r, tcpip.PacketBuffer{
+ f.HandlePacket(r, stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
}
@@ -153,11 +153,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -255,7 +255,7 @@ type linkEPWithMockedAttach struct {
// Attach implements stack.LinkEndpoint.Attach.
func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
l.LinkEndpoint.Attach(d)
- l.attached = true
+ l.attached = d != nil
}
func (l *linkEPWithMockedAttach) isAttached() bool {
@@ -287,7 +287,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet with wrong address is not delivered.
buf[0] = 3
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 0 {
@@ -299,7 +299,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to first endpoint.
buf[0] = 1
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -311,7 +311,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to second endpoint.
buf[0] = 2
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -322,7 +322,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -334,7 +334,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -356,7 +356,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
})
@@ -414,7 +414,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b
func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got := fakeNet.PacketCount(localAddrByte); got != want {
@@ -566,7 +566,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
}
if !e.isAttached() {
- t.Fatalf("link endpoint not attached to a network disatcher")
+ t.Fatal("link endpoint not attached to a network dispatcher")
}
})
}
@@ -631,196 +631,240 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
checkNIC(false)
}
-func TestRoutesWithDisabledNIC(t *testing.T) {
- const unspecifiedNIC = 0
- const nicID1 = 1
- const nicID2 = 2
-
+func TestRemoveUnknownNIC(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
- ep1 := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
}
+}
- addr1 := tcpip.Address("\x01")
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
- }
+func TestRemoveNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- ep2 := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID2, ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addr2 := tcpip.Address("\x02")
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ // NIC should be present in NICInfo and attached to a NetworkDispatcher.
+ allNICInfo := s.NICInfo()
+ if _, ok := allNICInfo[nicID]; !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ }
+ if !e.isAttached() {
+ t.Fatal("link endpoint not attached to a network dispatcher")
}
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
- {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
- })
+ // Removing a NIC should remove it from NICInfo and e should be detached from
+ // the NetworkDispatcher.
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
+ }
+ if nicInfo, ok := s.NICInfo()[nicID]; ok {
+ t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
+ }
+ if e.isAttached() {
+ t.Error("link endpoint for removed NIC still attached to a network dispatcher")
}
+}
- // Test routes to odd address.
- testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
- testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
- testRoute(t, s, nicID1, addr1, "\x05", addr1)
+func TestRouteWithDownNIC(t *testing.T) {
+ tests := []struct {
+ name string
+ downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ }{
+ {
+ name: "Disabled NIC",
+ downFn: (*stack.Stack).DisableNIC,
+ upFn: (*stack.Stack).EnableNIC,
+ },
+
+ // Once a NIC is removed, it cannot be brought up.
+ {
+ name: "Removed NIC",
+ downFn: (*stack.Stack).RemoveNIC,
+ },
+ }
- // Test routes to even address.
- testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
- testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
- testRoute(t, s, nicID2, addr2, "\x06", addr2)
-
- // Disabling NIC1 should result in no routes to odd addresses. Routes to even
- // addresses should continue to be available as NIC2 is still enabled.
- if err := s.DisableNIC(nicID1); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
- }
- nic1Dst := tcpip.Address("\x05")
- testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
- testNoRoute(t, s, nicID1, addr1, nic1Dst)
- nic2Dst := tcpip.Address("\x06")
- testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
- testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
- testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
-
- // Disabling NIC2 should result in no routes to even addresses. No route
- // should be available to any address as routes to odd addresses were made
- // unavailable by disabling NIC1 above.
- if err := s.DisableNIC(nicID2); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
- }
- testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
- testNoRoute(t, s, nicID1, addr1, nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
- testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
- testNoRoute(t, s, nicID2, addr2, nic2Dst)
-
- // Enabling NIC1 should make routes to odd addresses available again. Routes
- // to even addresses should continue to be unavailable as NIC2 is still
- // disabled.
- if err := s.EnableNIC(nicID1); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
- }
- testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
- testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
- testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
- testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
- testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
- testNoRoute(t, s, nicID2, addr2, nic2Dst)
-}
-
-func TestRouteWritePacketWithDisabledNIC(t *testing.T) {
const unspecifiedNIC = 0
const nicID1 = 1
const nicID2 = 2
+ const addr1 = tcpip.Address("\x01")
+ const addr2 = tcpip.Address("\x02")
+ const nic1Dst = tcpip.Address("\x05")
+ const nic2Dst = tcpip.Address("\x06")
+
+ setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
- ep1 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
- addr1 := tcpip.Address("\x01")
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
- }
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
- ep2 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID2, ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
- addr2 := tcpip.Address("\x02")
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ return s, ep1, ep2
}
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
+ // Tests that routes through a down NIC are not used when looking up a route
+ // for a destination.
+ t.Run("Find", func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, _, _ := setup(t)
+
+ // Test routes to odd address.
+ testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
+ testRoute(t, s, nicID1, addr1, "\x05", addr1)
+
+ // Test routes to even address.
+ testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
+ testRoute(t, s, nicID2, addr2, "\x06", addr2)
+
+ // Bringing NIC1 down should result in no routes to odd addresses. Routes to
+ // even addresses should continue to be available as NIC2 is still up.
+ if err := test.downFn(s, nicID1); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
+ testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
+
+ // Bringing NIC2 down should result in no routes to even addresses. No
+ // route should be available to any address as routes to odd addresses
+ // were made unavailable by bringing NIC1 down above.
+ if err := test.downFn(s, nicID2); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+
+ if upFn := test.upFn; upFn != nil {
+ // Bringing NIC1 up should make routes to odd addresses available
+ // again. Routes to even addresses should continue to be unavailable
+ // as NIC2 is still down.
+ if err := upFn(s, nicID1); err != nil {
+ t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
+ }
+ testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
+ testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+ }
+ })
}
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
- {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
- })
- }
+ })
- nic1Dst := tcpip.Address("\x05")
- r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
- }
- defer r1.Release()
+ // Tests that writing a packet using a Route through a down NIC fails.
+ t.Run("WritePacket", func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep1, ep2 := setup(t)
- nic2Dst := tcpip.Address("\x06")
- r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
- }
- defer r2.Release()
+ r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
+ }
+ defer r1.Release()
- // If we failed to get routes r1 or r2, we cannot proceed with the test.
- if t.Failed() {
- t.FailNow()
- }
+ r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
+ }
+ defer r2.Release()
- buf := buffer.View([]byte{1})
- testSend(t, r1, ep1, buf)
- testSend(t, r2, ep2, buf)
+ // If we failed to get routes r1 or r2, we cannot proceed with the test.
+ if t.Failed() {
+ t.FailNow()
+ }
- // Writes with Routes that use the disabled NIC1 should fail.
- if err := s.DisableNIC(nicID1); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
- }
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
- testSend(t, r2, ep2, buf)
+ buf := buffer.View([]byte{1})
+ testSend(t, r1, ep1, buf)
+ testSend(t, r2, ep2, buf)
- // Writes with Routes that use the disabled NIC2 should fail.
- if err := s.DisableNIC(nicID2); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
- }
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ // Writes with Routes that use NIC1 after being brought down should fail.
+ if err := test.downFn(s, nicID1); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testSend(t, r2, ep2, buf)
- // Writes with Routes that use the re-enabled NIC1 should succeed.
- // TODO(b/147015577): Should we instead completely invalidate all Routes that
- // were bound to a disabled NIC at some point?
- if err := s.EnableNIC(nicID1); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
- }
- testSend(t, r1, ep1, buf)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ // Writes with Routes that use NIC2 after being brought down should fail.
+ if err := test.downFn(s, nicID2); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+
+ if upFn := test.upFn; upFn != nil {
+ // Writes with Routes that use NIC1 after being brought up should
+ // succeed.
+ //
+ // TODO(b/147015577): Should we instead completely invalidate all
+ // Routes that were bound to a NIC that was brought down at some
+ // point?
+ if err := upFn(s, nicID1); err != nil {
+ t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
+ }
+ testSend(t, r1, ep1, buf)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ }
+ })
+ }
+ })
}
func TestRoutes(t *testing.T) {
@@ -1401,19 +1445,19 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}}
if err := s.AddProtocolAddress(1, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err)
+ t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
- t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
// If the NIC doesn't exist, it won't work.
if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
}
}
@@ -1439,12 +1483,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
}
nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr}
if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err)
+ t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
}
nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr}
if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
- t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err)
+ t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err)
}
// Set the initial route table.
@@ -1459,10 +1503,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// When an interface is given, the route for a broadcast goes through it.
r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
- t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
// When an interface is not given, it consults the route table.
@@ -2213,7 +2257,7 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
@@ -2240,56 +2284,84 @@ func TestNICStats(t *testing.T) {
}
func TestNICForwarding(t *testing.T) {
- // Create a stack with the fake network protocol, two NICs, each with
- // an address.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- s.SetForwarding(true)
+ const nicID1 = 1
+ const nicID2 = 2
+ const dstAddr = tcpip.Address("\x03")
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC #1 failed:", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
+ tests := []struct {
+ name string
+ headerLen uint16
+ }{
+ {
+ name: "Zero header length",
+ },
+ {
+ name: "Non-zero header length",
+ headerLen: 16,
+ },
}
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatal("CreateNIC #2 failed:", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ s.SetForwarding(true)
- // Route all packets to address 3 to NIC 2.
- {
- subnet, err := tcpip.NewSubnet("\x03", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}})
- }
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
+ }
- // Send a packet to address 3.
- buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ ep2 := channelLinkWithHeaderLength{
+ Endpoint: channel.New(10, defaultMTU, ""),
+ headerLength: test.headerLen,
+ }
+ if err := s.CreateNIC(nicID2, &ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
+ }
- if _, ok := ep2.Read(); !ok {
- t.Fatal("Packet not forwarded")
- }
+ // Route all packets to dstAddr to NIC 2.
+ {
+ subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
+ }
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
+ // Send a packet to dstAddr.
+ buf := buffer.NewView(30)
+ buf[0] = dstAddr[0]
+ ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
- if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ pkt, ok := ep2.Read()
+ if !ok {
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the link's MaxHeaderLength is honoured.
+ if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want {
+ t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want)
+ }
+
+ // Test that forwarding increments Tx stats correctly.
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
+ t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
+ t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ }
+ })
}
}
@@ -2327,7 +2399,7 @@ func TestNICContextPreservation(t *testing.T) {
t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos)
}
if got, want := nicinfo.Context == test.want, true; got != want {
- t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
+ t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
}
})
}
@@ -2696,7 +2768,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
if err != nil {
- t.Fatalf("NewSubnet failed:", err)
+ t.Fatalf("NewSubnet failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
@@ -2710,11 +2782,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// permanentExpired kind.
r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false)
if err != nil {
- t.Fatal("FindRoute failed:", err)
+ t.Fatalf("FindRoute failed: %v", err)
}
defer r.Release()
if err := s.RemoveAddress(1, "\x01"); err != nil {
- t.Fatalf("RemoveAddress failed:", err)
+ t.Fatalf("RemoveAddress failed: %v", err)
}
//
@@ -2726,7 +2798,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// Add some other address with peb set to
// FirstPrimaryEndpoint.
if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
+ t.Fatalf("AddAddressWithOptions failed: %v", err)
}
@@ -2734,7 +2806,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// make sure the new peb was respected.
// (The address should just be promoted now).
if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
+ t.Fatalf("AddAddressWithOptions failed: %v", err)
}
var primaryAddrs []tcpip.Address
for _, pa := range s.NICInfo()[1].ProtocolAddresses {
@@ -2767,11 +2839,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// GetMainNICAddress; else, our original address
// should be returned.
if err := s.RemoveAddress(1, "\x03"); err != nil {
- t.Fatalf("RemoveAddress failed:", err)
+ t.Fatalf("RemoveAddress failed: %v", err)
}
addr, err = s.GetMainNICAddress(1, fakeNetNumber)
if err != nil {
- t.Fatal("s.GetMainNICAddress failed:", err)
+ t.Fatalf("s.GetMainNICAddress failed: %v", err)
}
if ps == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
@@ -3010,6 +3082,50 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
}
}
+// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6
+// address after leaving its solicited node multicast address does not result in
+// an error.
+func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ }
+
+ // The NIC should have joined addr1's solicited node multicast address.
+ snmc := header.SolicitedNodeAddr(addr1)
+ in, err := s.IsInGroup(nicID, snmc)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
+ }
+ if !in {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc)
+ }
+
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err)
+ }
+ in, err = s.IsInGroup(nicID, snmc)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
+ }
+ if in {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc)
+ }
+
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
+ }
+}
+
func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
const nicID = 1
@@ -3060,8 +3176,6 @@ func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
// was disabled have DAD performed on them when the NIC is enabled.
func TestDoDADWhenNICEnabled(t *testing.T) {
- t.Parallel()
-
const dadTransmits = 1
const retransmitTimer = time.Second
const nicID = 1
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 778c0a4d6..9a33ed375 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -15,9 +15,9 @@
package stack
import (
+ "container/heap"
"fmt"
"math/rand"
- "sort"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -35,7 +35,7 @@ type protocolIDs struct {
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
- endpoints map[TransportEndpointID]*endpointsByNic
+ endpoints map[TransportEndpointID]*endpointsByNIC
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
rawEndpoints []RawTransportEndpoint
@@ -46,11 +46,11 @@ type transportEndpoints struct {
func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
- epsByNic, ok := eps.endpoints[id]
+ epsByNIC, ok := eps.endpoints[id]
if !ok {
return
}
- if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ if !epsByNIC.unregisterEndpoint(bindToDevice, ep) {
return
}
delete(eps.endpoints, id)
@@ -66,18 +66,85 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
return es
}
-type endpointsByNic struct {
+// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in
+// descending order of match quality. If a call to yield returns false,
+// iterEndpointsLocked stops iteration and returns immediately.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
+ // Try to find a match with the id as provided.
+ if ep, ok := eps.endpoints[id]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with the id minus the local address.
+ nid := id
+
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with the id minus the remote part.
+ nid.LocalAddress = id.LocalAddress
+ nid.RemoteAddress = ""
+ nid.RemotePort = 0
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with only the local port.
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+}
+
+// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
+// descending order of match quality.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
+ var matchedEPs []*endpointsByNIC
+ eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
+ matchedEPs = append(matchedEPs, ep)
+ return true
+ })
+ return matchedEPs
+}
+
+// findEndpointLocked returns the endpoint that most closely matches the given id.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
+ var matchedEP *endpointsByNIC
+ eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
+ matchedEP = ep
+ return false
+ })
+ return matchedEP
+}
+
+type endpointsByNIC struct {
mu sync.RWMutex
endpoints map[tcpip.NICID]*multiPortEndpoint
// seed is a random secret for a jenkins hash.
seed uint32
}
-func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
- epsByNic.mu.RLock()
- defer epsByNic.mu.RUnlock()
+func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
var eps []TransportEndpoint
- for _, ep := range epsByNic.endpoints {
+ for _, ep := range epsByNIC.endpoints {
eps = append(eps, ep.transportEndpoints()...)
}
return eps
@@ -85,13 +152,13 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
- epsByNic.mu.RLock()
+func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) {
+ epsByNIC.mu.RLock()
- mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
if !ok {
- if mpep, ok = epsByNic.endpoints[0]; !ok {
- epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ if mpep, ok = epsByNIC.endpoints[0]; !ok {
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
}
@@ -100,29 +167,29 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p
// endpoints bound to the right device.
if isMulticastOrBroadcast(id.LocalAddress) {
mpep.handlePacketAll(r, id, pkt)
- epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
// multiPortEndpoints are guaranteed to have at least one element.
- transEP := selectEndpoint(id, mpep, epsByNic.seed)
+ transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(r, transEP, id, pkt)
- epsByNic.mu.RUnlock()
+ epsByNIC.mu.RUnlock()
return
}
transEP.HandlePacket(r, id, pkt)
- epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) {
- epsByNic.mu.RLock()
- defer epsByNic.mu.RUnlock()
+func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
- mpep, ok := epsByNic.endpoints[n.ID()]
+ mpep, ok := epsByNIC.endpoints[n.ID()]
if !ok {
- mpep, ok = epsByNic.endpoints[0]
+ mpep, ok = epsByNIC.endpoints[0]
}
if !ok {
return
@@ -132,40 +199,41 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt)
+ selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
- epsByNic.mu.Lock()
- defer epsByNic.mu.Unlock()
+func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNIC.mu.Lock()
+ defer epsByNIC.mu.Unlock()
- if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok {
- // There was already a bind.
- return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
+ if !ok {
+ multiPortEp = &multiPortEndpoint{
+ demux: d,
+ netProto: netProto,
+ transProto: transProto,
+ reuse: reusePort,
+ }
+ epsByNIC.endpoints[bindToDevice] = multiPortEp
}
- // This is a new binding.
- multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto}
- multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
- multiPortEp.reuse = reusePort
- epsByNic.endpoints[bindToDevice] = multiPortEp
return multiPortEp.singleRegisterEndpoint(t, reusePort)
}
-// unregisterEndpoint returns true if endpointsByNic has to be unregistered.
-func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
- epsByNic.mu.Lock()
- defer epsByNic.mu.Unlock()
- multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+// unregisterEndpoint returns true if endpointsByNIC has to be unregistered.
+func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+ epsByNIC.mu.Lock()
+ defer epsByNIC.mu.Unlock()
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
if !ok {
return false
}
if multiPortEp.unregisterEndpoint(t) {
- delete(epsByNic.endpoints, bindToDevice)
+ delete(epsByNIC.endpoints, bindToDevice)
}
- return len(epsByNic.endpoints) == 0
+ return len(epsByNIC.endpoints) == 0
}
// transportDemuxer demultiplexes packets targeted at a transport endpoint
@@ -183,7 +251,7 @@ type transportDemuxer struct {
// the dispatcher to delivery packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer)
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
@@ -197,7 +265,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
for proto := range stack.transportProtocols {
protoIDs := protocolIDs{netProto, proto}
d.protocol[protoIDs] = &transportEndpoints{
- endpoints: make(map[TransportEndpointID]*endpointsByNic),
+ endpoints: make(map[TransportEndpointID]*endpointsByNIC),
}
qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
if isQueued {
@@ -222,6 +290,35 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
return nil
}
+type transportEndpointHeap []TransportEndpoint
+
+var _ heap.Interface = (*transportEndpointHeap)(nil)
+
+func (h *transportEndpointHeap) Len() int {
+ return len(*h)
+}
+
+func (h *transportEndpointHeap) Less(i, j int) bool {
+ return (*h)[i].UniqueID() < (*h)[j].UniqueID()
+}
+
+func (h *transportEndpointHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
+}
+
+func (h *transportEndpointHeap) Push(x interface{}) {
+ *h = append(*h, x.(TransportEndpoint))
+}
+
+func (h *transportEndpointHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ old[n-1] = nil
+ *h = old[:n-1]
+ return x
+}
+
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element.
@@ -237,15 +334,14 @@ type multiPortEndpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
- endpointsArr []TransportEndpoint
- endpointsMap map[TransportEndpoint]int
+ endpoints transportEndpointHeap
// reuse indicates if more than one endpoint is allowed.
reuse bool
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ eps := append([]TransportEndpoint(nil), ep.endpoints...)
ep.mu.RUnlock()
return eps
}
@@ -262,8 +358,8 @@ func reciprocalScale(val, n uint32) uint32 {
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
- if len(mpep.endpointsArr) == 1 {
- return mpep.endpointsArr[0]
+ if len(mpep.endpoints) == 1 {
+ return mpep.endpoints[0]
}
payload := []byte{
@@ -279,29 +375,26 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr)))
- return mpep.endpointsArr[idx]
+ idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
+ return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
- for i, endpoint := range ep.endpointsArr {
- // HandlePacket takes ownership of pkt, so each endpoint needs
- // its own copy except for the final one.
- if i == len(ep.endpointsArr)-1 {
- if mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt)
- break
- }
- endpoint.HandlePacket(r, id, pkt)
- break
- }
+ // HandlePacket takes ownership of pkt, so each endpoint needs
+ // its own copy except for the final one.
+ for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
if mustQueue {
queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
- continue
+ } else {
+ endpoint.HandlePacket(r, id, pkt.Clone())
}
- endpoint.HandlePacket(r, id, pkt.Clone())
+ }
+ if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ } else {
+ endpoint.HandlePacket(r, id, pkt)
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -312,26 +405,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo
ep.mu.Lock()
defer ep.mu.Unlock()
- if len(ep.endpointsArr) > 0 {
+ if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
if !ep.reuse || !reusePort {
return tcpip.ErrPortInUse
}
}
- // A new endpoint is added into endpointsArr and its index there is saved in
- // endpointsMap. This will allow us to remove endpoint from the array fast.
- ep.endpointsMap[t] = len(ep.endpointsArr)
- ep.endpointsArr = append(ep.endpointsArr, t)
+ heap.Push(&ep.endpoints, t)
- // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints
- // can be restored in the same order.
- sort.Slice(ep.endpointsArr, func(i, j int) bool {
- return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID()
- })
- for i, e := range ep.endpointsArr {
- ep.endpointsMap[e] = i
- }
return nil
}
@@ -340,21 +422,13 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
ep.mu.Lock()
defer ep.mu.Unlock()
- idx, ok := ep.endpointsMap[t]
- if !ok {
- return false
- }
- delete(ep.endpointsMap, t)
- l := len(ep.endpointsArr)
- if l > 1 {
- // The last endpoint in endpointsArr is moved instead of the deleted one.
- lastEp := ep.endpointsArr[l-1]
- ep.endpointsArr[idx] = lastEp
- ep.endpointsMap[lastEp] = idx
- ep.endpointsArr = ep.endpointsArr[0 : l-1]
- return false
+ for i, endpoint := range ep.endpoints {
+ if endpoint == t {
+ heap.Remove(&ep.endpoints, i)
+ break
+ }
}
- return true
+ return len(ep.endpoints) == 0
}
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
@@ -371,19 +445,16 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.mu.Lock()
defer eps.mu.Unlock()
- if epsByNic, ok := eps.endpoints[id]; ok {
- // There was already a binding.
- return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
- }
-
- // This is a new binding.
- epsByNic := &endpointsByNic{
- endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: rand.Uint32(),
+ epsByNIC, ok := eps.endpoints[id]
+ if !ok {
+ epsByNIC = &endpointsByNIC{
+ endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
+ seed: rand.Uint32(),
+ }
+ eps.endpoints[id] = epsByNIC
}
- eps.endpoints[id] = epsByNic
- return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
+ return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
@@ -396,84 +467,60 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
}
}
-var loopbackSubnet = func() tcpip.Subnet {
- sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
- if err != nil {
- panic(err)
- }
- return sn
-}()
-
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
}
- eps.mu.RLock()
-
- // Determine which transport endpoint or endpoints to deliver this packet to.
// If the packet is a UDP broadcast or multicast, then find all matching
- // transport endpoints. If the packet is a TCP packet with a non-unicast
- // source or destination address, then do nothing further and instruct
- // the caller to do the same.
- var destEps []*endpointsByNic
- switch protocol {
- case header.UDPProtocolNumber:
- if isMulticastOrBroadcast(id.LocalAddress) {
- destEps = d.findAllEndpointsLocked(eps, id)
- break
- }
-
- if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
+ // transport endpoints.
+ if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ eps.mu.RLock()
+ destEPs := eps.findAllEndpointsLocked(id)
+ eps.mu.RUnlock()
+ // Fail if we didn't find at least one matching transport endpoint.
+ if len(destEPs) == 0 {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ return false
}
-
- case header.TCPProtocolNumber:
- if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) {
- // TCP can only be used to communicate between a single
- // source and a single destination; the addresses must
- // be unicast.
- eps.mu.RUnlock()
- r.Stats().TCP.InvalidSegmentsReceived.Increment()
- return true
+ // handlePacket takes ownership of pkt, so each endpoint needs its own
+ // copy except for the final one.
+ for _, ep := range destEPs[:len(destEPs)-1] {
+ ep.handlePacket(r, id, pkt.Clone())
}
+ destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ return true
+ }
- fallthrough
-
- default:
- if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
- }
+ // If the packet is a TCP packet with a non-unicast source or destination
+ // address, then do nothing further and instruct the caller to do the same.
+ if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ // TCP can only be used to communicate between a single source and a
+ // single destination; the addresses must be unicast.
+ r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ return true
}
+ eps.mu.RLock()
+ ep := eps.findEndpointLocked(id)
eps.mu.RUnlock()
-
- // Fail if we didn't find at least one matching transport endpoint.
- if len(destEps) == 0 {
- // UDP packet could not be delivered to an unknown destination port.
+ if ep == nil {
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
}
return false
}
-
- // HandlePacket takes ownership of pkt, so each endpoint needs its own
- // copy except for the final one.
- for _, ep := range destEps[:len(destEps)-1] {
- ep.handlePacket(r, id, pkt.Clone())
- }
- destEps[len(destEps)-1].handlePacket(r, id, pkt)
-
+ ep.handlePacket(r, id, pkt)
return true
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool {
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -497,99 +544,53 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
}
- // Try to find the endpoint.
eps.mu.RLock()
- ep := d.findEndpointLocked(eps, id)
+ ep := eps.findEndpointLocked(id)
eps.mu.RUnlock()
-
- // Fail if we didn't find one.
if ep == nil {
return false
}
- // Deliver the packet.
ep.handleControlPacket(n, id, typ, extra, pkt)
-
return true
}
-func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
- var matchedEPs []*endpointsByNic
- // Try to find a match with the id as provided.
- if ep, ok := eps.endpoints[id]; ok {
- matchedEPs = append(matchedEPs, ep)
- }
-
- // Try to find a match with the id minus the local address.
- nid := id
-
- nid.LocalAddress = ""
- if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
- }
-
- // Try to find a match with the id minus the remote part.
- nid.LocalAddress = id.LocalAddress
- nid.RemoteAddress = ""
- nid.RemotePort = 0
- if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
- }
-
- // Try to find a match with only the local port.
- nid.LocalAddress = ""
- if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
- }
- return matchedEPs
-}
-
// findTransportEndpoint find a single endpoint that most closely matches the provided id.
func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return nil
}
- // Try to find the endpoint.
+
eps.mu.RLock()
- epsByNic := d.findEndpointLocked(eps, id)
- // Fail if we didn't find one.
- if epsByNic == nil {
+ epsByNIC := eps.findEndpointLocked(id)
+ if epsByNIC == nil {
eps.mu.RUnlock()
return nil
}
- epsByNic.mu.RLock()
+ epsByNIC.mu.RLock()
eps.mu.RUnlock()
- mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
if !ok {
- if mpep, ok = epsByNic.endpoints[0]; !ok {
- epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ if mpep, ok = epsByNIC.endpoints[0]; !ok {
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return nil
}
}
- ep := selectEndpoint(id, mpep, epsByNic.seed)
- epsByNic.mu.RUnlock()
+ ep := selectEndpoint(id, mpep, epsByNIC.seed)
+ epsByNIC.mu.RUnlock()
return ep
}
-// findEndpointLocked returns the endpoint that most closely matches the given
-// id.
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic {
- if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 {
- return matchedEPs[0]
- }
- return nil
-}
-
// registerRawEndpoint registers the given endpoint with the dispatcher such
// that packets of the appropriate protocol are delivered to it. A single
// packet can be sent to one or more raw endpoints along with a non-raw
@@ -601,8 +602,8 @@ func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNum
}
eps.mu.Lock()
- defer eps.mu.Unlock()
eps.rawEndpoints = append(eps.rawEndpoints, ep)
+ eps.mu.Unlock()
return nil
}
@@ -616,13 +617,16 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
}
eps.mu.Lock()
- defer eps.mu.Unlock()
for i, rawEP := range eps.rawEndpoints {
if rawEP == ep {
- eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
- return
+ lastIdx := len(eps.rawEndpoints) - 1
+ eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx]
+ eps.rawEndpoints[lastIdx] = nil
+ eps.rawEndpoints = eps.rawEndpoints[:lastIdx]
+ break
}
}
+ eps.mu.Unlock()
}
func isMulticastOrBroadcast(addr tcpip.Address) bool {
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 5e9237de9..2474a7db3 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -31,84 +31,58 @@ import (
)
const (
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testPort = 4096
+ testSrcAddrV4 = "\x0a\x00\x00\x01"
+ testDstAddrV4 = "\x0a\x00\x00\x02"
+
+ testDstPort = 1234
+ testSrcPort = 4096
)
type testContext struct {
- t *testing.T
linkEps map[tcpip.NICID]*channel.Endpoint
s *stack.Stack
-
- ep tcpip.Endpoint
- wq waiter.Queue
-}
-
-func (c *testContext) cleanup() {
- if c.ep != nil {
- c.ep.Close()
- }
-}
-
-func (c *testContext) createV6Endpoint(v6only bool) {
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
+ wq waiter.Queue
}
// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
linkEps := make(map[tcpip.NICID]*channel.Endpoint)
for _, linkEpID := range linkEpIDs {
channelEp := channel.New(256, mtu, "")
if err := s.CreateNIC(linkEpID, channelEp); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatalf("CreateNIC failed: %s", err)
}
linkEps[linkEpID] = channelEp
- if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress IPv4 failed: %v", err)
+ if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
+ t.Fatalf("AddAddress IPv4 failed: %s", err)
}
- if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress IPv6 failed: %v", err)
+ if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
+ t.Fatalf("AddAddress IPv6 failed: %s", err)
}
}
s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
+ {Destination: header.IPv4EmptySubnet, NIC: 1},
+ {Destination: header.IPv6EmptySubnet, NIC: 1},
})
return &testContext{
- t: t,
s: s,
linkEps: linkEps,
}
}
type headers struct {
- srcPort uint16
- dstPort uint16
+ srcPort, dstPort uint16
}
func newPayload() []byte {
@@ -119,6 +93,47 @@ func newPayload() []byte {
return b
}
+func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: 0x80,
+ TotalLength: uint16(len(buf)),
+ TTL: 65,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: testSrcAddrV4,
+ DstAddr: testDstAddrV4,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
+ })
+}
+
func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
@@ -130,8 +145,8 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: testV6Addr,
- DstAddr: stackV6Addr,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@@ -143,15 +158,17 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
})
}
@@ -167,38 +184,48 @@ func TestTransportDemuxerRegister(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
- t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tEP, ok := ep.(stack.TransportEndpoint)
+ if !ok {
+ t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
+ }
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want {
+ t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
}
})
}
}
-// TestReuseBindToDevice injects varied packets on input devices and checks that
+// TestBindToDeviceDistribution injects varied packets on input devices and checks that
// the distribution of packets received matches expectations.
-func TestDistribution(t *testing.T) {
+func TestBindToDeviceDistribution(t *testing.T) {
type endpointSockopts struct {
- reuse int
+ reuse bool
bindToDevice tcpip.NICID
}
for _, test := range []struct {
name string
// endpoints will received the inject packets.
endpoints []endpointSockopts
- // wantedDistribution is the wanted ratio of packets received on each
+ // wantDistributions is the want ratio of packets received on each
// endpoint for each NIC on which packets are injected.
- wantedDistributions map[tcpip.NICID][]float64
+ wantDistributions map[tcpip.NICID][]float64
}{
{
"BindPortReuse",
// 5 endpoints that all have reuse set.
[]endpointSockopts{
- {1, 0},
- {1, 0},
- {1, 0},
- {1, 0},
- {1, 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
@@ -209,9 +236,9 @@ func TestDistribution(t *testing.T) {
"BindToDevice",
// 3 endpoints with various bindings.
[]endpointSockopts{
- {0, 1},
- {0, 2},
- {0, 3},
+ {reuse: false, bindToDevice: 1},
+ {reuse: false, bindToDevice: 2},
+ {reuse: false, bindToDevice: 3},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
@@ -226,12 +253,12 @@ func TestDistribution(t *testing.T) {
"ReuseAndBindToDevice",
// 6 endpoints with various bindings.
[]endpointSockopts{
- {1, 1},
- {1, 1},
- {1, 2},
- {1, 2},
- {1, 2},
- {1, 0},
+ {reuse: true, bindToDevice: 1},
+ {reuse: true, bindToDevice: 1},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 0},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
@@ -245,17 +272,17 @@ func TestDistribution(t *testing.T) {
},
},
} {
- t.Run(test.name, func(t *testing.T) {
- for device, wantedDistribution := range test.wantedDistributions {
- t.Run(string(device), func(t *testing.T) {
+ for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{
+ "IPv4": ipv4.ProtocolNumber,
+ "IPv6": ipv6.ProtocolNumber,
+ } {
+ for device, wantDistribution := range test.wantDistributions {
+ t.Run(test.name+protoName+string(device), func(t *testing.T) {
var devices []tcpip.NICID
- for d := range test.wantedDistributions {
+ for d := range test.wantDistributions {
devices = append(devices, d)
}
c := newDualTestContextMultiNIC(t, defaultMTU, devices)
- defer c.cleanup()
-
- c.createV6Endpoint(false)
eps := make(map[tcpip.Endpoint]int)
@@ -269,9 +296,9 @@ func TestDistribution(t *testing.T) {
defer close(ch)
var err *tcpip.Error
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq)
if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
eps[ep] = i
@@ -282,22 +309,31 @@ func TestDistribution(t *testing.T) {
}(ep)
defer ep.Close()
- reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
- if err := ep.SetSockOpt(reusePortOption); err != nil {
- c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
+ if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil {
+ t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
}
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
- c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
+ t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err)
+ }
+
+ var dstAddr tcpip.Address
+ switch netProtoNum {
+ case ipv4.ProtocolNumber:
+ dstAddr = testDstAddrV4
+ case ipv6.ProtocolNumber:
+ dstAddr = testDstAddrV6
+ default:
+ t.Fatalf("unexpected protocol number: %d", netProtoNum)
}
- if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
- t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
+ if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
+ t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
}
}
npackets := 100000
nports := 10000
- if got, want := len(test.endpoints), len(wantedDistribution); got != want {
+ if got, want := len(test.endpoints), len(wantDistribution); got != want {
t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
}
ports := make(map[uint16]tcpip.Endpoint)
@@ -306,17 +342,22 @@ func TestDistribution(t *testing.T) {
// Send a packet.
port := uint16(i % nports)
payload := newPayload()
- c.sendV6Packet(payload,
- &headers{
- srcPort: testPort + port,
- dstPort: stackPort},
- device)
+ hdrs := &headers{
+ srcPort: testSrcPort + port,
+ dstPort: testDstPort,
+ }
+ switch netProtoNum {
+ case ipv4.ProtocolNumber:
+ c.sendV4Packet(payload, hdrs, device)
+ case ipv6.ProtocolNumber:
+ c.sendV6Packet(payload, hdrs, device)
+ default:
+ t.Fatalf("unexpected protocol number: %d", netProtoNum)
+ }
- var addr tcpip.FullAddress
ep := <-pollChannel
- _, _, err := ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
+ if _, _, err := ep.Read(nil); err != nil {
+ t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
}
stats[ep]++
if i < nports {
@@ -332,17 +373,17 @@ func TestDistribution(t *testing.T) {
// Check that a packet distribution is as expected.
for ep, i := range eps {
- wantedRatio := wantedDistribution[i]
- wantedRecv := wantedRatio * float64(npackets)
+ wantRatio := wantDistribution[i]
+ wantRecv := wantRatio * float64(npackets)
actualRecv := stats[ep]
actualRatio := float64(stats[ep]) / float64(npackets)
// The deviation is less than 10%.
- if math.Abs(actualRatio-wantedRatio) > 0.05 {
- t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
+ if math.Abs(actualRatio-wantRatio) > 0.05 {
+ t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets)
}
}
})
}
- })
+ }
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5d1da2f8b..3084e6593 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -57,6 +56,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
return nil
}
+func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
+
func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
@@ -87,7 +88,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: buffer.View(v).ToVectorisedView(),
}); err != nil {
@@ -214,7 +215,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
@@ -231,7 +232,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
}
}
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) {
// Increment the number of received control packets.
f.proto.controlCount++
}
@@ -242,8 +243,8 @@ func (f *fakeTransportEndpoint) State() uint32 {
func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
-func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
- return iptables.IPTables{}, nil
+func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) {
+ return stack.IPTables{}, nil
}
func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
@@ -288,7 +289,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
return true
}
@@ -368,7 +369,7 @@ func TestTransportReceive(t *testing.T) {
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -379,7 +380,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -390,7 +391,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 1 {
@@ -445,7 +446,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -456,7 +457,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -467,7 +468,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 1 {
@@ -622,7 +623,7 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: req.ToVectorisedView(),
})
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 3dc5d87d6..1ca4088c9 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -336,6 +336,15 @@ type ControlMessages struct {
PacketInfo IPPacketInfo
}
+// PacketOwner is used to get UID and GID of the packet.
+type PacketOwner interface {
+ // UID returns UID of the packet.
+ UID() uint32
+
+ // GID returns GID of the packet.
+ GID() uint32
+}
+
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
// that exposes functionality like read, write, connect, etc. to users of the
// networking stack.
@@ -470,6 +479,9 @@ type Endpoint interface {
// Stats returns a reference to the endpoint stats.
Stats() EndpointStats
+
+ // SetOwner sets the task owner to the endpoint owner.
+ SetOwner(owner PacketOwner)
}
// EndpointInfo is the interface implemented by each endpoint info struct.
@@ -508,34 +520,90 @@ type WriteOptions struct {
type SockOptBool int
const (
+ // BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether
+ // datagram sockets are allowed to send packets to a broadcast address.
+ BroadcastOption SockOptBool = iota
+
+ // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
+ // held until segments are full by the TCP transport protocol.
+ CorkOption
+
+ // DelayOption is used by SetSockOpt/GetSockOpt to specify if data
+ // should be sent out immediately by the transport protocol. For TCP,
+ // it determines if the Nagle algorithm is on or off.
+ DelayOption
+
+ // KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
+ // TCP keepalive is enabled for this socket.
+ KeepaliveEnabledOption
+
+ // MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether
+ // multicast packets sent over a non-loopback interface will be looped back.
+ MulticastLoopOption
+
+ // PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
+ // SCM_CREDENTIALS socket control messages are enabled.
+ //
+ // Only supported on Unix sockets.
+ PasscredOption
+
+ // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
+ QuickAckOption
+
// ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the
// IPV6_TCLASS ancillary message is passed with incoming packets.
- ReceiveTClassOption SockOptBool = iota
+ ReceiveTClassOption
// ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS
// ancillary message is passed with incoming packets.
ReceiveTOSOption
- // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
- // socket is to be restricted to sending and receiving IPv6 packets only.
- V6OnlyOption
-
// ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify
// if more inforamtion is provided with incoming packets such
// as interface index and address.
ReceiveIPPacketInfoOption
- // TODO(b/146901447): convert existing bool socket options to be handled via
- // Get/SetSockOptBool
+ // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
+ // should allow reuse of local address.
+ ReuseAddressOption
+
+ // ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets
+ // to be bound to an identical socket address.
+ ReusePortOption
+
+ // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
+ // socket is to be restricted to sending and receiving IPv6 packets only.
+ V6OnlyOption
)
// SockOptInt represents socket options which values have the int type.
type SockOptInt int
const (
+ // KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number
+ // of un-ACKed TCP keepalives that will be sent before the connection is
+ // closed.
+ KeepaliveCountOption SockOptInt = iota
+
+ // IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS
+ // for all subsequent outgoing IPv4 packets from the endpoint.
+ IPv4TOSOption
+
+ // IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS
+ // for all subsequent outgoing IPv6 packets from the endpoint.
+ IPv6TrafficClassOption
+
+ // MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current
+ // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option.
+ MaxSegOption
+
+ // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
+ // TTL value for multicast messages. The default is 1.
+ MulticastTTLOption
+
// ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
// number of unread bytes in the input buffer should be returned.
- ReceiveQueueSizeOption SockOptInt = iota
+ ReceiveQueueSizeOption
// SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
// specify the send buffer size option.
@@ -549,44 +617,21 @@ const (
// number of unread bytes in the output buffer should be returned.
SendQueueSizeOption
- // DelayOption is used by SetSockOpt/GetSockOpt to specify if data
- // should be sent out immediately by the transport protocol. For TCP,
- // it determines if the Nagle algorithm is on or off.
- DelayOption
-
- // TODO(b/137664753): convert all int socket options to be handled via
- // GetSockOptInt.
+ // TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop
+ // limit value for unicast messages. The default is protocol specific.
+ //
+ // A zero value indicates the default.
+ TTLOption
)
// ErrorOption is used in GetSockOpt to specify that the last error reported by
// the endpoint should be cleared and returned.
type ErrorOption struct{}
-// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
-// held until segments are full by the TCP transport protocol.
-type CorkOption int
-
-// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
-// should allow reuse of local address.
-type ReuseAddressOption int
-
-// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets
-// to be bound to an identical socket address.
-type ReusePortOption int
-
// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
// should bind only on a specific NIC.
type BindToDeviceOption NICID
-// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
-type QuickAckOption int
-
-// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
-// SCM_CREDENTIALS socket control messages are enabled.
-//
-// Only supported on Unix sockets.
-type PasscredOption int
-
// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
//
// TODO(b/64800844): Add and populate stat fields.
@@ -595,10 +640,6 @@ type TCPInfoOption struct {
RTTVar time.Duration
}
-// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
-// TCP keepalive is enabled for this socket.
-type KeepaliveEnabledOption int
-
// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a
// connection must remain idle before the first TCP keepalive packet is sent.
// Once this time is reached, KeepaliveIntervalOption is used instead.
@@ -608,11 +649,6 @@ type KeepaliveIdleOption time.Duration
// interval between sending TCP keepalive packets.
type KeepaliveIntervalOption time.Duration
-// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number
-// of un-ACKed TCP keepalives that will be sent before the connection is
-// closed.
-type KeepaliveCountOption int
-
// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user
// specified timeout for a given TCP connection.
// See: RFC5482 for details.
@@ -626,20 +662,9 @@ type CongestionControlOption string
// control algorithms.
type AvailableCongestionControlOption string
-// ModerateReceiveBufferOption allows the caller to enable/disable TCP receive
// buffer moderation.
type ModerateReceiveBufferOption bool
-// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current
-// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option.
-type MaxSegOption int
-
-// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop
-// limit value for unicast messages. The default is protocol specific.
-//
-// A zero value indicates the default.
-type TTLOption uint8
-
// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state
// before being marked closed.
@@ -656,9 +681,14 @@ type TCPTimeWaitTimeoutOption time.Duration
// for a handshake till the specified timeout until a segment with data arrives.
type TCPDeferAcceptOption time.Duration
-// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
-// TTL value for multicast messages. The default is 1.
-type MulticastTTLOption uint8
+// TCPMinRTOOption is use by SetSockOpt/GetSockOpt to allow overriding
+// default MinRTO used by the Stack.
+type TCPMinRTOOption time.Duration
+
+// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
+// the number of endpoints that can be in SYN-RCVD state before the stack
+// switches to using SYN cookies.
+type TCPSynRcvdCountThresholdOption uint64
// MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a
// default interface for multicast.
@@ -667,10 +697,6 @@ type MulticastInterfaceOption struct {
InterfaceAddr Address
}
-// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether
-// multicast packets sent over a non-loopback interface will be looped back.
-type MulticastLoopOption bool
-
// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
// AddMembershipOption and RemoveMembershipOption.
type MembershipOption struct {
@@ -693,22 +719,10 @@ type RemoveMembershipOption MembershipOption
// TCP out-of-band data is delivered along with the normal in-band data.
type OutOfBandInlineOption int
-// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether
-// datagram sockets are allowed to send packets to a broadcast address.
-type BroadcastOption int
-
// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify
// a default TTL.
type DefaultTTLOption uint8
-// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS
-// for all subsequent outgoing IPv4 packets from the endpoint.
-type IPv4TOSOption uint8
-
-// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS
-// for all subsequent outgoing IPv6 packets from the endpoint.
-type IPv6TrafficClassOption uint8
-
// IPPacketInfo is the message struture for IP_PKTINFO.
//
// +stateify savable
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index 8c0aacffa..1c8e2bc34 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -218,7 +218,7 @@ func TestAddressWithPrefixSubnet(t *testing.T) {
gotSubnet := ap.Subnet()
wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask)
if err != nil {
- t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err)
+ t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err)
continue
}
if gotSubnet != wantSubnet {
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index ac18ec5b1..9ce625c17 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -31,7 +31,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 426da1ee6..feef8dca0 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -74,6 +73,9 @@ type endpoint struct {
route stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
@@ -134,8 +136,12 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
+
// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
+func (e *endpoint) IPTables() (stack.IPTables, error) {
return e.stack.IPTables(), nil
}
@@ -291,15 +297,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
- toCopy := *to
- to = &toCopy
- netProto, err := e.checkV4Mapped(to)
+ dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
}
- // Find the enpoint.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */)
+ // Find the endpoint.
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
if err != nil {
return 0, nil, err
}
@@ -324,7 +328,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
switch e.NetProto {
case header.IPv4ProtocolNumber:
- err = send4(route, e.ID.LocalPort, v, e.ttl)
+ err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
case header.IPv6ProtocolNumber:
err = send6(route, e.ID.LocalPort, v, e.ttl)
@@ -344,13 +348,6 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(o)
- e.mu.Unlock()
- }
-
return nil
}
@@ -361,12 +358,25 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt sets a socket option. Currently not supported.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
+ }
return nil
}
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -393,32 +403,29 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
+ case tcpip.TTLOption:
+ e.rcvMu.Lock()
+ v := int(e.ttl)
+ e.rcvMu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
}
- return -1, tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
+ switch opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
- case *tcpip.TTLOption:
- e.rcvMu.Lock()
- *o = tcpip.TTLOption(e.ttl)
- e.rcvMu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
+func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
return tcpip.ErrInvalidEndpointState
}
@@ -443,10 +450,11 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: data.ToVectorisedView(),
TransportHeader: buffer.View(icmpv4),
+ Owner: owner,
})
}
@@ -473,20 +481,21 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: dataVV,
TransportHeader: buffer.View(icmpv6),
})
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -517,7 +526,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -630,7 +639,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -734,7 +743,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
@@ -796,7 +805,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
}
// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 113d92901..3c47692b2 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
return true
}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
index d22de6b26..b989b1209 100644
--- a/pkg/tcpip/transport/packet/BUILD
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -31,7 +31,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 5722815e9..23158173d 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -76,6 +75,7 @@ type endpoint struct {
sndBufSize int
closed bool
stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
}
// NewEndpoint returns a new packet endpoint.
@@ -99,8 +99,8 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
}
// Abort implements stack.TransportEndpoint.Abort.
-func (e *endpoint) Abort() {
- e.Close()
+func (ep *endpoint) Abort() {
+ ep.Close()
}
// Close implements tcpip.Endpoint.Close.
@@ -125,6 +125,7 @@ func (ep *endpoint) Close() {
}
ep.closed = true
+ ep.bound = false
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -132,7 +133,7 @@ func (ep *endpoint) Close() {
func (ep *endpoint) ModerateRecvBuf(copied int) {}
// IPTables implements tcpip.Endpoint.IPTables.
-func (ep *endpoint) IPTables() (iptables.IPTables, error) {
+func (ep *endpoint) IPTables() (stack.IPTables, error) {
return ep.stack.IPTables(), nil
}
@@ -216,7 +217,24 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
// sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex."
// - packet(7).
- return tcpip.ErrNotSupported
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.bound {
+ return tcpip.ErrAlreadyBound
+ }
+
+ // Unregister endpoint with all the nics.
+ ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+
+ // Bind endpoint to receive packets from specific interface.
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ return err
+ }
+
+ ep.bound = true
+
+ return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
@@ -280,7 +298,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -374,3 +392,5 @@ func (ep *endpoint) Info() tcpip.EndpointInfo {
func (ep *endpoint) Stats() tcpip.EndpointStats {
return &ep.stats
}
+
+func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index c9baf4600..2eab09088 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -32,7 +32,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/packet",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 2ef5fac76..eee754a5a 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -30,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -81,6 +80,9 @@ type endpoint struct {
// Connect(), and is valid only when conneted is true.
route stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
}
// NewEndpoint returns a raw endpoint for the given protocols.
@@ -160,8 +162,12 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
+
// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
+func (e *endpoint) IPTables() (stack.IPTables, error) {
return e.stack.IPTables(), nil
}
@@ -342,17 +348,19 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
switch e.NetProto {
case header.IPv4ProtocolNumber:
if !e.associated {
- if err := route.WriteHeaderIncludedPacket(tcpip.PacketBuffer{
+ if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{
Data: buffer.View(payloadBytes).ToVectorisedView(),
}); err != nil {
return 0, nil, err
}
break
}
+
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: buffer.View(payloadBytes).ToVectorisedView(),
+ Owner: e.owner,
}); err != nil {
return 0, nil, err
}
@@ -525,14 +533,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
+ switch opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -540,7 +544,13 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -568,13 +578,13 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
}
-
- return -1, tcpip.ErrUnknownProtocolOption
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 272e8f570..edb7718a6 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -32,6 +32,7 @@ go_library(
srcs = [
"accept.go",
"connect.go",
+ "connect_unsafe.go",
"cubic.go",
"cubic_state.go",
"dispatcher.go",
@@ -65,12 +66,10 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
- "//pkg/tmutex",
"//pkg/waiter",
"@com_github_google_btree//:go_default_library",
],
@@ -88,7 +87,9 @@ go_test(
"tcp_timestamp_test.go",
],
# FIXME(b/68809571)
- tags = ["flaky"],
+ tags = [
+ "flaky",
+ ],
deps = [
":tcp",
"//pkg/sync",
@@ -105,5 +106,6 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp/testing/context",
"//pkg/waiter",
+ "//runsc/testutil",
],
)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 13e383ffc..5bb243e3b 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -17,6 +17,7 @@ package tcp
import (
"crypto/sha1"
"encoding/binary"
+ "fmt"
"hash"
"io"
"time"
@@ -25,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -49,17 +49,14 @@ const (
// timestamp and the current timestamp. If the difference is greater
// than maxTSDiff, the cookie is expired.
maxTSDiff = 2
-)
-var (
- // SynRcvdCountThreshold is the global maximum number of connections
- // that are allowed to be in SYN-RCVD state before TCP starts using SYN
- // cookies to accept connections.
- //
- // It is an exported variable only for testing, and should not otherwise
- // be used by importers of this package.
+ // SynRcvdCountThreshold is the default global maximum number of
+ // connections that are allowed to be in SYN-RCVD state before TCP
+ // starts using SYN cookies to accept connections.
SynRcvdCountThreshold uint64 = 1000
+)
+var (
// mssTable is a slice containing the possible MSS values that we
// encode in the SYN cookie with two bits.
mssTable = []uint16{536, 1300, 1440, 1460}
@@ -74,29 +71,42 @@ func encodeMSS(mss uint16) uint32 {
return 0
}
-// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
-// protected by a mutex so that we can increment only when it's guaranteed not
-// to go above a threshold.
-var synRcvdCount struct {
- sync.Mutex
- value uint64
- pending sync.WaitGroup
-}
-
// listenContext is used by a listening endpoint to store state used while
// listening for connections. This struct is allocated by the listen goroutine
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
- stack *stack.Stack
- rcvWnd seqnum.Size
- nonce [2][sha1.BlockSize]byte
+ stack *stack.Stack
+
+ // synRcvdCount is a reference to the stack level synRcvdCount.
+ synRcvdCount *synRcvdCounter
+
+ // rcvWnd is the receive window that is sent by this listening context
+ // in the initial SYN-ACK.
+ rcvWnd seqnum.Size
+
+ // nonce are random bytes that are initialized once when the context
+ // is created and used to seed the hash function when generating
+ // the SYN cookie.
+ nonce [2][sha1.BlockSize]byte
+
+ // listenEP is a reference to the listening endpoint associated with
+ // this context. Can be nil if the context is created by the forwarder.
listenEP *endpoint
+ // hasherMu protects hasher.
hasherMu sync.Mutex
- hasher hash.Hash
- v6only bool
+ // hasher is the hash function used to generate a SYN cookie.
+ hasher hash.Hash
+
+ // v6Only is true if listenEP is a dual stack socket and has the
+ // IPV6_V6ONLY option set.
+ v6only bool
+
+ // netProto indicates the network protocol(IPv4/v6) for the listening
+ // endpoint.
netProto tcpip.NetworkProtocolNumber
+
// pendingMu protects pendingEndpoints. This should only be accessed
// by the listening endpoint's worker goroutine.
//
@@ -115,44 +125,6 @@ func timeStamp() uint32 {
return uint32(time.Now().Unix()>>6) & tsMask
}
-// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD
-// state. It succeeds if the increment doesn't make the count go beyond the
-// threshold, and fails otherwise.
-func incSynRcvdCount() bool {
- synRcvdCount.Lock()
-
- if synRcvdCount.value >= SynRcvdCountThreshold {
- synRcvdCount.Unlock()
- return false
- }
-
- synRcvdCount.pending.Add(1)
- synRcvdCount.value++
-
- synRcvdCount.Unlock()
- return true
-}
-
-// decSynRcvdCount atomically decrements the global number of endpoints in
-// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount
-// succeeded.
-func decSynRcvdCount() {
- synRcvdCount.Lock()
-
- synRcvdCount.value--
- synRcvdCount.pending.Done()
- synRcvdCount.Unlock()
-}
-
-// synCookiesInUse() returns true if the synRcvdCount is greater than
-// SynRcvdCountThreshold.
-func synCookiesInUse() bool {
- synRcvdCount.Lock()
- v := synRcvdCount.value
- synRcvdCount.Unlock()
- return v >= SynRcvdCountThreshold
-}
-
// newListenContext creates a new listen context.
func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
@@ -164,6 +136,11 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size,
listenEP: listenEP,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
+ p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol)
+ if !ok {
+ panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk))
+ }
+ l.synRcvdCount = p.SynRcvdCounter()
rand.Read(l.nonce[0][:])
rand.Read(l.nonce[1][:])
@@ -221,7 +198,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
}
// createConnectingEndpoint creates a new endpoint in a connecting state, with
-// the connection parameters given by the arguments.
+// the connection parameters given by the arguments. The endpoint is returned
+// with n.mu held.
func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
@@ -236,27 +214,13 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
n.amss = mssForRoute(&n.route)
+ n.setEndpointState(StateConnecting)
n.maybeEnableTimestamp(rcvdSynOpts)
n.maybeEnableSACKPermitted(rcvdSynOpts)
n.initGSO()
- // Now inherit any socket options that should be inherited from the
- // listening endpoint.
- // In case of Forwarder listenEP will be nil and hence this check.
- if l.listenEP != nil {
- l.listenEP.propagateInheritableOptions(n)
- }
-
- // Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
- n.Close()
- return nil, err
- }
-
- n.isRegistered = true
-
// Create sender and receiver.
//
// The receiver at least temporarily has a zero receive window scale,
@@ -268,12 +232,28 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
+ // Lock the endpoint before registering to ensure that no out of
+ // band changes are possible due to incoming packets etc till
+ // the endpoint is done initializing.
+ n.mu.Lock()
+
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
+ n.mu.Unlock()
+ n.Close()
+ return nil, err
+ }
+
+ n.isRegistered = true
+
return n, nil
}
// createEndpointAndPerformHandshake creates a new endpoint in connected state
// and then performs the TCP 3-way handshake.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
+//
+// The new endpoint is returned with e.mu held.
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
@@ -281,6 +261,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if err != nil {
return nil, err
}
+ ep.owner = owner
// listenEP is nil when listenContext is used by tcp.Forwarder.
deferAccept := time.Duration(0)
@@ -288,9 +269,25 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
l.listenEP.mu.Unlock()
+ // Ensure we release any registrations done by the newly
+ // created endpoint.
+ ep.mu.Unlock()
+ ep.Close()
+
+ // Wake up any waiters. This is strictly not required normally
+ // as a socket that was never accepted can't really have any
+ // registered waiters except when stack.Wait() is called which
+ // waits for all registered endpoints to stop and expects an
+ // EventHUp.
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
+
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ l.listenEP.propagateInheritableOptionsLocked(ep)
+
deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
@@ -298,6 +295,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Perform the 3-way handshake.
h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
+ ep.mu.Unlock()
ep.Close()
// Wake up any waiters. This is strictly not required normally
// as a socket that was never accepted can't really have any
@@ -309,11 +307,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
+
+ ep.drainClosingSegmentQueue()
+
return nil, err
}
- ep.mu.Lock()
ep.isConnectNotified = true
- ep.mu.Unlock()
// Update the receive window scaling. We can't do it before the
// handshake because it's possible that the peer doesn't support window
@@ -347,30 +346,38 @@ func (l *listenContext) closeAllPendingEndpoints() {
}
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// endpoint has transitioned out of the listen state, the new endpoint is closed
-// instead.
+// endpoint has transitioned out of the listen state (acceptedChan is nil),
+// the new endpoint is closed instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.Lock()
- state := e.EndpointState()
e.pendingAccepted.Add(1)
- defer e.pendingAccepted.Done()
- acceptedChan := e.acceptedChan
e.mu.Unlock()
+ defer e.pendingAccepted.Done()
- if state == StateListen {
- acceptedChan <- n
- e.waiterQueue.Notify(waiter.EventIn)
- } else {
- n.Close()
+ e.acceptMu.Lock()
+ for {
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ n.notifyProtocolGoroutine(notifyReset)
+ return
+ }
+ select {
+ case e.acceptedChan <- n:
+ e.acceptMu.Unlock()
+ e.waiterQueue.Notify(waiter.EventIn)
+ return
+ default:
+ e.acceptCond.Wait()
+ }
}
}
-// propagateInheritableOptions propagates any options set on the listening
+// propagateInheritableOptionsLocked propagates any options set on the listening
// endpoint to the newly created endpoint.
-func (e *endpoint) propagateInheritableOptions(n *endpoint) {
- e.mu.Lock()
+//
+// Precondition: e.mu and n.mu must be held.
+func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
- e.mu.Unlock()
}
// handleSynSegment is called in its own goroutine once the listening endpoint
@@ -380,11 +387,15 @@ func (e *endpoint) propagateInheritableOptions(n *endpoint) {
// A limited number of these goroutines are allowed before TCP starts using SYN
// cookies to accept connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
- defer decSynRcvdCount()
- defer e.decSynRcvdCount()
+ defer ctx.synRcvdCount.dec()
+ defer func() {
+ e.mu.Lock()
+ e.decSynRcvdCount()
+ e.mu.Unlock()
+ }()
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{})
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
@@ -398,40 +409,39 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
func (e *endpoint) incSynRcvdCount() bool {
- e.mu.Lock()
- if e.synRcvdCount >= cap(e.acceptedChan) {
- e.mu.Unlock()
- return false
+ e.acceptMu.Lock()
+ canInc := e.synRcvdCount < cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ if canInc {
+ e.synRcvdCount++
}
- e.synRcvdCount++
- e.mu.Unlock()
- return true
+ return canInc
}
func (e *endpoint) decSynRcvdCount() {
- e.mu.Lock()
e.synRcvdCount--
- e.mu.Unlock()
}
func (e *endpoint) acceptQueueIsFull() bool {
- e.mu.Lock()
- if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c {
- e.mu.Unlock()
- return true
- }
- e.mu.Unlock()
- return false
+ e.acceptMu.Lock()
+ full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ return full
}
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
- if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) {
+ e.rcvListMu.Lock()
+ rcvClosed := e.rcvClosed
+ e.rcvListMu.Unlock()
+ if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) {
+ // If the endpoint is shutdown, reply with reset.
+ //
// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
// must be sent in response to a SYN-ACK while in the listen
// state to prevent completing a handshake from an old SYN.
- e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil)
+ replyWithReset(s, e.sendTOS, e.ttl)
return
}
@@ -441,7 +451,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
switch {
case s.flags == header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
- if incSynRcvdCount() {
+ if ctx.synRcvdCount.inc() {
// Only handle the syn if the following conditions hold
// - accept queue is not full.
// - number of connections in synRcvd state is less than the
@@ -451,7 +461,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
return
}
- decSynRcvdCount()
+ ctx.synRcvdCount.dec()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -479,7 +489,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
TSEcr: opts.TSVal,
MSS: mssForRoute(&s.route),
}
- e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
+ e.sendSynTCP(&s.route, tcpFields{
+ id: s.id,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: header.TCPFlagSyn | header.TCPFlagAck,
+ seq: cookie,
+ ack: s.sequenceNumber + 1,
+ rcvWnd: ctx.rcvWnd,
+ }, synOpts)
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
}
@@ -496,7 +514,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
- if !synCookiesInUse() {
+ if !ctx.synRcvdCount.synCookiesInUse() {
// When not using SYN cookies, as per RFC 793, section 3.9, page 64:
// Any acknowledgment is bad if it arrives on a connection still in
// the LISTEN state. An acceptable reset segment should be formed
@@ -512,7 +530,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// The only time we should reach here when a connection
// was opened and closed really quickly and a delayed
// ACK was received from the sender.
- replyWithReset(s)
+ replyWithReset(s, e.sendTOS, e.ttl)
return
}
@@ -558,6 +576,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ e.propagateInheritableOptionsLocked(n)
+
// clear the tsOffset for the newly created
// endpoint as the Timestamp was already
// randomly offset when the original SYN-ACK was
@@ -592,14 +614,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
v6only := e.v6only
- e.mu.Unlock()
ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
// handleSynSegment() from attempting to queue new connections
// to the endpoint.
- e.mu.Lock()
e.setEndpointState(StateClose)
// close any endpoints in SYN-RCVD state.
@@ -613,6 +633,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
}
e.mu.Unlock()
+ e.drainClosingSegmentQueue()
+
// Notify waiters that the endpoint is shutdown.
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
@@ -621,7 +643,10 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
for {
- switch index, _ := s.Fetch(true); index {
+ e.mu.Unlock()
+ index, _ := s.Fetch(true)
+ e.mu.Lock()
+ switch index {
case wakerForNotification:
n := e.fetchNotifications()
if n&notifyClose != 0 {
@@ -634,7 +659,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
s.decRef()
}
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
}
case wakerForNewSegment:
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 7730e6445..368865911 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -61,6 +61,9 @@ const (
)
// handshake holds the state used during a TCP 3-way handshake.
+//
+// NOTE: handshake.ep.mu is held during handshake processing. It is released if
+// we are going to block and reacquired when we start processing an event.
type handshake struct {
ep *endpoint
state handshakeState
@@ -209,9 +212,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.mss = opts.MSS
h.sndWndScale = opts.WS
h.deferAccept = deferAccept
- h.ep.mu.Lock()
h.ep.setEndpointState(StateSynRecv)
- h.ep.mu.Unlock()
}
// checkAck checks if the ACK number, if present, of a segment received during
@@ -241,9 +242,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// RFC 793, page 67, states that "If the RST bit is set [and] If the ACK
// was acceptable then signal the user "error: connection reset", drop
// the segment, enter CLOSED state, delete TCB, and return."
- h.ep.mu.Lock()
h.ep.workerCleanup = true
- h.ep.mu.Unlock()
// Although the RFC above calls out ECONNRESET, Linux actually returns
// ECONNREFUSED here so we do as well.
return tcpip.ErrConnectionRefused
@@ -281,9 +280,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
if s.flagIsSet(header.TCPFlagAck) {
h.state = handshakeCompleted
- h.ep.mu.Lock()
h.ep.transitionToStateEstablishedLocked(h)
- h.ep.mu.Unlock()
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
@@ -293,10 +290,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// but resend our own SYN and wait for it to be acknowledged in the
// SYN-RCVD state.
h.state = handshakeSynRcvd
- h.ep.mu.Lock()
ttl := h.ep.ttl
+ amss := h.ep.amss
h.ep.setEndpointState(StateSynRecv)
- h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
@@ -307,12 +303,20 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// permits SACK. This is not explicitly defined in the RFC but
// this is the behaviour implemented by Linux.
SACKPermitted: rcvSynOpts.SACKPermitted,
- MSS: h.ep.amss,
+ MSS: amss,
}
if ttl == 0 {
ttl = s.route.DefaultTTL()
}
- h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&s.route, tcpFields{
+ id: h.ep.ID,
+ ttl: ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
return nil
}
@@ -365,7 +369,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&s.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
return nil
}
@@ -394,15 +406,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
}
h.state = handshakeCompleted
- h.ep.mu.Lock()
h.ep.transitionToStateEstablishedLocked(h)
+
// If the segment has data then requeue it for the receiver
// to process it again once main loop is started.
if s.data.Size() > 0 {
s.incRef()
h.ep.enqueueSegment(s)
}
- h.ep.mu.Unlock()
return nil
}
@@ -488,7 +499,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
}
if n&notifyDrain != 0 {
close(h.ep.drainDone)
+ h.ep.mu.Unlock()
<-h.ep.undrain
+ h.ep.mu.Lock()
}
}
@@ -553,10 +566,23 @@ func (h *handshake) execute() *tcpip.Error {
synOpts.WS = -1
}
}
- h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
for h.state != handshakeCompleted {
- switch index, _ := s.Fetch(true); index {
+ h.ep.mu.Unlock()
+ index, _ := s.Fetch(true)
+ h.ep.mu.Lock()
+ switch index {
+
case wakerForResend:
timeOut *= 2
if timeOut > MaxRTO {
@@ -572,12 +598,20 @@ func (h *handshake) execute() *tcpip.Error {
// the connection with another ACK or data (as ACKs are never
// retransmitted on their own).
if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
- h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
}
case wakerForNotification:
n := h.ep.fetchNotifications()
- if n&notifyClose != 0 {
+ if (n&notifyClose)|(n&notifyAbort) != 0 {
return tcpip.ErrAborted
}
if n&notifyDrain != 0 {
@@ -593,7 +627,9 @@ func (h *handshake) execute() *tcpip.Error {
}
}
close(h.ep.drainDone)
+ h.ep.mu.Unlock()
<-h.ep.undrain
+ h.ep.mu.Lock()
}
case wakerForNewSegment:
@@ -617,17 +653,17 @@ func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
var optionPool = sync.Pool{
New: func() interface{} {
- return make([]byte, maxOptionSize)
+ return &[maxOptionSize]byte{}
},
}
func getOptions() []byte {
- return optionPool.Get().([]byte)
+ return (*optionPool.Get().(*[maxOptionSize]byte))[:]
}
func putOptions(options []byte) {
// Reslice to full capacity.
- optionPool.Put(options[0:cap(options)])
+ optionPool.Put(optionsToArray(options))
}
func makeSynOptions(opts header.TCPSynOptions) []byte {
@@ -683,18 +719,33 @@ func makeSynOptions(opts header.TCPSynOptions) []byte {
return options[:offset]
}
-func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
- options := makeSynOptions(opts)
+// tcpFields is a struct to carry different parameters required by the
+// send*TCP variant functions below.
+type tcpFields struct {
+ id stack.TransportEndpointID
+ ttl uint8
+ tos uint8
+ flags byte
+ seq seqnum.Value
+ ack seqnum.Value
+ rcvWnd seqnum.Size
+ opts []byte
+ txHash uint32
+}
+
+func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error {
+ tf.opts = makeSynOptions(opts)
// We ignore SYN send errors and let the callers re-attempt send.
- if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil {
+ if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil {
e.stats.SendErrors.SynSendToNetworkFailed.Increment()
}
- putOptions(options)
+ putOptions(tf.opts)
return nil
}
-func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
- if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil {
+func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
+ tf.txHash = e.txHash
+ if err := sendTCP(r, tf, data, gso, e.owner); err != nil {
e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
return err
}
@@ -702,24 +753,23 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu
return nil
}
-func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) {
- optLen := len(opts)
+func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
+ optLen := len(tf.opts)
hdr := &pkt.Header
- packetSize := pkt.DataSize
- off := pkt.DataOffset
+ packetSize := pkt.Data.Size()
// Initialize the header.
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
pkt.TransportHeader = buffer.View(tcp)
tcp.Encode(&header.TCPFields{
- SrcPort: id.LocalPort,
- DstPort: id.RemotePort,
- SeqNum: uint32(seq),
- AckNum: uint32(ack),
+ SrcPort: tf.id.LocalPort,
+ DstPort: tf.id.RemotePort,
+ SeqNum: uint32(tf.seq),
+ AckNum: uint32(tf.ack),
DataOffset: uint8(header.TCPMinimumSize + optLen),
- Flags: flags,
- WindowSize: uint16(rcvWnd),
+ Flags: tf.flags,
+ WindowSize: uint16(tf.rcvWnd),
})
- copy(tcp[header.TCPMinimumSize:], opts)
+ copy(tcp[header.TCPMinimumSize:], tf.opts)
length := uint16(hdr.UsedLength() + packetSize)
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
@@ -731,48 +781,49 @@ func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.Packet
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
} else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize)
+ xsum = header.ChecksumVV(pkt.Data, xsum)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
-
}
-func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
- optLen := len(opts)
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
+func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+ // We need to shallow clone the VectorisedView here as ReadToView will
+ // split the VectorisedView and Trim underlying views as it splits. Not
+ // doing the clone here will cause the underlying views of data itself
+ // to be altered.
+ data = data.Clone(nil)
+
+ optLen := len(tf.opts)
+ if tf.rcvWnd > 0xffff {
+ tf.rcvWnd = 0xffff
}
mss := int(gso.MSS)
n := (data.Size() + mss - 1) / mss
- // Allocate one big slice for all the headers.
- hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen
- buf := make([]byte, n*hdrSize)
- pkts := make([]tcpip.PacketBuffer, n)
- for i := range pkts {
- pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize])
- }
-
size := data.Size()
- off := 0
+ hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen
+ var pkts stack.PacketBufferList
for i := 0; i < n; i++ {
packetSize := mss
if packetSize > size {
packetSize = size
}
size -= packetSize
- pkts[i].DataOffset = off
- pkts[i].DataSize = packetSize
- pkts[i].Data = data
- buildTCPHdr(r, id, &pkts[i], flags, seq, ack, rcvWnd, opts, gso)
- off += packetSize
- seq = seq.Add(seqnum.Size(packetSize))
+ var pkt stack.PacketBuffer
+ pkt.Header = buffer.NewPrependable(hdrSize)
+ pkt.Hash = tf.txHash
+ pkt.Owner = owner
+ data.ReadToVV(&pkt.Data, packetSize)
+ buildTCPHdr(r, tf, &pkt, gso)
+ tf.seq = tf.seq.Add(seqnum.Size(packetSize))
+ pkts.PushBack(&pkt)
}
- if ttl == 0 {
- ttl = r.DefaultTTL()
+
+ if tf.ttl == 0 {
+ tf.ttl = r.DefaultTTL()
}
- sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos})
+ sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos})
if err != nil {
r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent))
}
@@ -782,33 +833,33 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect
// sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
- optLen := len(opts)
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
+func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+ optLen := len(tf.opts)
+ if tf.rcvWnd > 0xffff {
+ tf.rcvWnd = 0xffff
}
if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
- return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso)
+ return sendTCPBatch(r, tf, data, gso, owner)
}
- pkt := tcpip.PacketBuffer{
- Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
- DataOffset: 0,
- DataSize: data.Size(),
- Data: data,
+ pkt := stack.PacketBuffer{
+ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
+ Data: data,
+ Hash: tf.txHash,
+ Owner: owner,
}
- buildTCPHdr(r, id, &pkt, flags, seq, ack, rcvWnd, opts, gso)
+ buildTCPHdr(r, tf, &pkt, gso)
- if ttl == 0 {
- ttl = r.DefaultTTL()
+ if tf.ttl == 0 {
+ tf.ttl = r.DefaultTTL()
}
- if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, pkt); err != nil {
+ if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
return err
}
r.Stats().TCP.SegmentsSent.Increment()
- if (flags & header.TCPFlagRst) != 0 {
+ if (tf.flags & header.TCPFlagRst) != 0 {
r.Stats().TCP.ResetsSent.Increment()
}
return nil
@@ -860,7 +911,16 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso)
+ err := e.sendTCP(&e.route, tcpFields{
+ id: e.ID,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: flags,
+ seq: seq,
+ ack: ack,
+ rcvWnd: rcvWnd,
+ opts: options,
+ }, data, e.gso)
putOptions(options)
return err
}
@@ -875,7 +935,6 @@ func (e *endpoint) handleWrite() *tcpip.Error {
first := e.sndQueue.Front()
if first != nil {
e.snd.writeList.PushBackList(&e.sndQueue)
- e.snd.sndNxtList.UpdateForward(e.sndBufInQueue)
e.sndBufInQueue = 0
}
@@ -994,22 +1053,40 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
}
if ep == nil {
- replyWithReset(s)
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
s.decRef()
return
}
+
+ if e == ep {
+ panic("current endpoint not removed from demuxer, enqueing segments to itself")
+ }
+
if ep.(*endpoint).enqueueSegment(s) {
ep.(*endpoint).newSegmentWaker.Assert()
}
}
+// Drain segment queue from the endpoint and try to re-match the segment to a
+// different endpoint. This is used when the current endpoint is transitioned to
+// StateClose and has been unregistered from the transport demuxer.
+func (e *endpoint) drainClosingSegmentQueue() {
+ for {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ break
+ }
+
+ e.tryDeliverSegmentFromClosedEndpoint(s)
+ }
+}
+
func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
if e.rcv.acceptable(s.sequenceNumber, 0) {
// RFC 793, page 37 states that "in all states
// except SYN-SENT, all reset (RST) segments are
// validated by checking their SEQ-fields." So
// we only process it if it's acceptable.
- e.mu.Lock()
switch e.EndpointState() {
// In case of a RST in CLOSE-WAIT linux moves
// the socket to closed state with an error set
@@ -1033,11 +1110,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
case StateCloseWait:
e.transitionToStateCloseLocked()
e.HardError = tcpip.ErrAborted
- e.mu.Unlock()
e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
- e.mu.Unlock()
// RFC 793, page 37 states that "in all states
// except SYN-SENT, all reset (RST) segments are
// validated by checking their SEQ-fields." So
@@ -1150,9 +1225,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
- e.mu.RLock()
state := e.state
- e.mu.RUnlock()
if state == StateClose {
// When we get into StateClose while processing from the queue,
// return immediately and let the protocolMainloop handle it.
@@ -1175,9 +1248,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
- e.mu.RLock()
userTimeout := e.userTimeout
- e.mu.RUnlock()
e.keepalive.Lock()
if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
@@ -1241,6 +1312,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// goroutine and is responsible for sending segments and handling received
// segments.
func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
+ e.mu.Lock()
var closeTimer *time.Timer
var closeWaker sleep.Waker
@@ -1262,7 +1334,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
e.mu.Unlock()
- e.workMu.Unlock()
+
+ e.drainClosingSegmentQueue()
+
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -1273,16 +1347,13 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// completion.
initialRcvWnd := e.initialReceiveWindow()
h := newHandshake(e, seqnum.Size(initialRcvWnd))
- e.mu.Lock()
h.ep.setEndpointState(StateSynSent)
- e.mu.Unlock()
if err := h.execute(); err != nil {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
- e.mu.Lock()
e.setEndpointState(StateError)
e.HardError = err
@@ -1295,9 +1366,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.keepalive.timer.init(&e.keepalive.waker)
defer e.keepalive.timer.cleanup()
- e.mu.Lock()
drained := e.drainDone != nil
- e.mu.Unlock()
if drained {
close(e.drainDone)
<-e.undrain
@@ -1323,10 +1392,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// This means the socket is being closed due
// to the TCP-FIN-WAIT2 timeout was hit. Just
// mark the socket as closed.
- e.mu.Lock()
e.transitionToStateCloseLocked()
e.workerCleanup = true
- e.mu.Unlock()
return nil
},
},
@@ -1381,7 +1448,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
if n&notifyClose != 0 && closeTimer == nil {
- e.mu.Lock()
if e.EndpointState() == StateFinWait2 && e.closed {
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
@@ -1390,7 +1456,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
})
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
- e.mu.Unlock()
}
if n&notifyKeepaliveChanged != 0 {
@@ -1410,7 +1475,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
}
}
@@ -1453,7 +1520,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
e.rcvListMu.Unlock()
- e.mu.Lock()
if e.workerCleanup {
e.notifyProtocolGoroutine(notifyClose)
}
@@ -1461,7 +1527,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
cleanupOnError := func(err *tcpip.Error) {
- e.mu.Lock()
e.workerCleanup = true
if err != nil {
e.resetConnectionLocked(err)
@@ -1473,16 +1538,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
loop:
for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError {
e.mu.Unlock()
- e.workMu.Unlock()
v, _ := s.Fetch(true)
- e.workMu.Lock()
+ e.mu.Lock()
- // We need to double check here because the notification maybe
+ // We need to double check here because the notification may be
// stale by the time we got around to processing it.
- //
- // NOTE: since we now hold the workMu the processors cannot
- // change the state of the endpoint so it's safe to proceed
- // after this check.
switch e.EndpointState() {
case StateError:
// If the endpoint has already transitioned to an ERROR
@@ -1495,21 +1555,17 @@ loop:
case StateTimeWait:
fallthrough
case StateClose:
- e.mu.Lock()
break loop
default:
if err := funcs[v].f(); err != nil {
cleanupOnError(err)
return nil
}
- e.mu.Lock()
}
}
- state := e.EndpointState()
- e.mu.Unlock()
var reuseTW func()
- if state == StateTimeWait {
+ if e.EndpointState() == StateTimeWait {
// Disable close timer as we now entering real TIME_WAIT.
if closeTimer != nil {
closeTimer.Stop()
@@ -1519,14 +1575,11 @@ loop:
s.Done()
// Wake up any waiters before we enter TIME_WAIT.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
- e.mu.Lock()
e.workerCleanup = true
- e.mu.Unlock()
reuseTW = e.doTimeWait()
}
// Mark endpoint as closed.
- e.mu.Lock()
if e.EndpointState() != StateError {
e.transitionToStateCloseLocked()
}
@@ -1534,19 +1587,6 @@ loop:
// Lock released below.
epilogue()
- // epilogue removes the endpoint from the transport-demuxer and
- // unlocks e.mu. Now that no new segments can get enqueued to this
- // endpoint, try to re-match the segment to a different endpoint
- // as the current endpoint is closed.
- for {
- s := e.segmentQueue.dequeue()
- if s == nil {
- break
- }
-
- e.tryDeliverSegmentFromClosedEndpoint(s)
- }
-
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
if reuseTW != nil {
@@ -1632,6 +1672,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
const timeWaitDone = 3
s := sleep.Sleeper{}
+ defer s.Done()
s.AddWaker(&e.newSegmentWaker, newSegment)
s.AddWaker(&e.notificationWaker, notification)
@@ -1641,9 +1682,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
defer timeWaitTimer.Stop()
for {
- e.workMu.Unlock()
+ e.mu.Unlock()
v, _ := s.Fetch(true)
- e.workMu.Lock()
+ e.mu.Lock()
switch v {
case newSegment:
extendTimeWait, reuseTW := e.handleTimeWaitSegments()
@@ -1666,7 +1707,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
e.handleTimeWaitSegments()
}
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
return nil
}
case timeWaitDone:
diff --git a/pkg/tcpip/transport/tcp/connect_unsafe.go b/pkg/tcpip/transport/tcp/connect_unsafe.go
new file mode 100644
index 000000000..cfc304616
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/connect_unsafe.go
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// optionsToArray converts a slice of capacity >-= maxOptionSize to an array.
+//
+// optionsToArray panics if the capacity of options is smaller than
+// maxOptionSize.
+func optionsToArray(options []byte) *[maxOptionSize]byte {
+ // Reslice to full capacity.
+ options = options[0:maxOptionSize]
+ return (*[maxOptionSize]byte)(unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&options)).Data))
+}
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index d792b07d6..6062ca916 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -18,7 +18,6 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -128,7 +127,7 @@ func (p *processor) handleSegments() {
continue
}
- if !ep.workMu.TryLock() {
+ if !ep.mu.TryLock() {
ep.newSegmentWaker.Assert()
continue
}
@@ -138,12 +137,10 @@ func (p *processor) handleSegments() {
if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
// Send any active resets if required.
if err != nil {
- ep.mu.Lock()
ep.resetConnectionLocked(err)
- ep.mu.Unlock()
}
ep.notifyProtocolGoroutine(notifyTickleWorker)
- ep.workMu.Unlock()
+ ep.mu.Unlock()
continue
}
@@ -151,7 +148,7 @@ func (p *processor) handleSegments() {
p.epQ.enqueue(ep)
}
- ep.workMu.Unlock()
+ ep.mu.Unlock()
}
}
}
@@ -189,7 +186,7 @@ func (d *dispatcher) wait() {
}
}
-func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt)
if !s.parse() {
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 4f361b226..804e95aea 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -568,11 +568,10 @@ func TestV4AcceptOnV4(t *testing.T) {
func testV4ListenClose(t *testing.T, c *context.Context) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 0
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err)
+ }
+
const n = uint16(32)
// Start listening.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index f1ad19dac..7e8def82d 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -18,6 +18,7 @@ import (
"encoding/binary"
"fmt"
"math"
+ "runtime"
"strings"
"sync/atomic"
"time"
@@ -29,11 +30,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tmutex"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -283,6 +282,38 @@ func (*EndpointInfo) IsEndpointInfo() {}
// synchronized. The protocol implementation, however, runs in a single
// goroutine.
//
+// Each endpoint has a few mutexes:
+//
+// e.mu -> Primary mutex for an endpoint must be held for all operations except
+// in e.Readiness where acquiring it will result in a deadlock in epoll
+// implementation.
+//
+// The following three mutexes can be acquired independent of e.mu but if
+// acquired with e.mu then e.mu must be acquired first.
+//
+// e.acceptMu -> protects acceptedChan.
+// e.rcvListMu -> Protects the rcvList and associated fields.
+// e.sndBufMu -> Protects the sndQueue and associated fields.
+// e.lastErrorMu -> Protects the lastError field.
+//
+// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different
+// based on the context in which the lock is acquired. In the syscall context
+// e.LockUser/e.UnlockUser should be used and when doing background processing
+// e.mu.Lock/e.mu.Unlock should be used. The distinction is described below
+// in brief.
+//
+// The reason for this locking behaviour is to avoid wakeups to handle packets.
+// In cases where the endpoint is already locked the background processor can
+// queue the packet up and go its merry way and the lock owner will eventually
+// process the backlog when releasing the lock. Similarly when acquiring the
+// lock from say a syscall goroutine we can implement a bit of spinning if we
+// know that the lock is not held by another syscall goroutine. Background
+// processors should never hold the lock for long and we can avoid an expensive
+// sleep/wakeup by spinning for a shortwhile.
+//
+// For more details please see the detailed documentation on
+// e.LockUser/e.UnlockUser methods.
+//
// +stateify savable
type endpoint struct {
EndpointInfo
@@ -299,12 +330,6 @@ type endpoint struct {
// Precondition: epQueue.mu must be held to read/write this field..
pendingProcessing bool `state:"nosave"`
- // workMu is used to arbitrate which goroutine may perform protocol
- // work. Only the main protocol goroutine is expected to call Lock() on
- // it, but other goroutines (e.g., send) may call TryLock() to eagerly
- // perform work without having to wait for the main one to wake up.
- workMu tmutex.Mutex `state:"nosave"`
-
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
@@ -330,15 +355,11 @@ type endpoint struct {
rcvBufSize int
rcvBufUsed int
rcvAutoParams rcvBufAutoTuneParams
- // zeroWindow indicates that the window was closed due to receive buffer
- // space being filled up. This is set by the worker goroutine before
- // moving a segment to the rcvList. This setting is cleared by the
- // endpoint when a Read() call reads enough data for the new window to
- // be non-zero.
- zeroWindow bool
- // The following fields are protected by the mutex.
- mu sync.RWMutex `state:"nosave"`
+ // mu protects all endpoint fields unless documented otherwise. mu must
+ // be acquired before interacting with the endpoint fields.
+ mu sync.Mutex `state:"nosave"`
+ ownedByUser uint32
// state must be read/set using the EndpointState()/setEndpointState() methods.
state EndpointState `state:".(EndpointState)"`
@@ -513,6 +534,23 @@ type endpoint struct {
// to the acceptedChan below terminate before we close acceptedChan.
pendingAccepted sync.WaitGroup `state:"nosave"`
+ // acceptMu protects acceptedChan.
+ acceptMu sync.Mutex `state:"nosave"`
+
+ // acceptCond is a condition variable that can be used to block on when
+ // acceptedChan is full and an endpoint is ready to be delivered.
+ //
+ // This condition variable is required because just blocking on sending
+ // to acceptedChan does not work in cases where endpoint.Listen is
+ // called twice with different backlog values. In such cases the channel
+ // is closed and a new one created. Any pending goroutines blocking on
+ // the write to the channel will panic.
+ //
+ // We use this condition variable to block/unblock goroutines which
+ // tried to deliver an endpoint but couldn't because accept backlog was
+ // full ( See: endpoint.deliverAccepted ).
+ acceptCond *sync.Cond `state:"nosave"`
+
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
@@ -561,6 +599,13 @@ type endpoint struct {
// endpoint and at this point the endpoint is only around
// to complete the TCP shutdown.
closed bool
+
+ // txHash is the transport layer hash to be set on outbound packets
+ // emitted by this endpoint.
+ txHash uint32
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -583,14 +628,93 @@ func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
return maxMSS
}
+// LockUser tries to lock e.mu and if it fails it will check if the lock is held
+// by another syscall goroutine. If yes, then it will goto sleep waiting for the
+// lock to be released, if not then it will spin till it acquires the lock or
+// another syscall goroutine acquires it in which case it will goto sleep as
+// described above.
+//
+// The assumption behind spinning here being that background packet processing
+// should not be holding the lock for long and spinning reduces latency as we
+// avoid an expensive sleep/wakeup of of the syscall goroutine).
+func (e *endpoint) LockUser() {
+ for {
+ // Try first if the sock is locked then check if it's owned
+ // by another user goroutine if not then we spin, otherwise
+ // we just goto sleep on the Lock() and wait.
+ if !e.mu.TryLock() {
+ // If socket is owned by the user then just goto sleep
+ // as the lock could be held for a reasonably long time.
+ if atomic.LoadUint32(&e.ownedByUser) == 1 {
+ e.mu.Lock()
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+ // Spin but yield the processor since the lower half
+ // should yield the lock soon.
+ runtime.Gosched()
+ continue
+ }
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+}
+
+// UnlockUser will check if there are any segments already queued for processing
+// and process any such segments before unlocking e.mu. This is required because
+// we when packets arrive and endpoint lock is already held then such packets
+// are queued up to be processed. If the lock is held by the endpoint goroutine
+// then it will process these packets but if the lock is instead held by the
+// syscall goroutine then we can have the syscall goroutine process the backlog
+// before unlocking.
+//
+// This avoids an unnecessary wakeup of the endpoint protocol goroutine for the
+// endpoint. It's also required eventually when we get rid of the endpoint
+// protocol goroutine altogether.
+//
+// Precondition: e.LockUser() must have been called before calling e.UnlockUser()
+func (e *endpoint) UnlockUser() {
+ // Lock segment queue before checking so that we avoid a race where
+ // segments can be queued between the time we check if queue is empty
+ // and actually unlock the endpoint mutex.
+ for {
+ e.segmentQueue.mu.Lock()
+ if e.segmentQueue.emptyLocked() {
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ e.segmentQueue.mu.Unlock()
+ return
+ }
+ e.segmentQueue.mu.Unlock()
+
+ switch e.EndpointState() {
+ case StateEstablished:
+ if err := e.handleSegments(true /* fastPath */); err != nil {
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
+ default:
+ // Since we are waking the endpoint goroutine here just unlock
+ // and let it process the queued segments.
+ e.newSegmentWaker.Assert()
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ return
+ }
+ }
+}
+
// StopWork halts packet processing. Only to be used in tests.
func (e *endpoint) StopWork() {
- e.workMu.Lock()
+ e.mu.Lock()
}
// ResumeWork resumes packet processing. Only to be used in tests.
func (e *endpoint) ResumeWork() {
- e.workMu.Unlock()
+ e.mu.Unlock()
}
// setEndpointState updates the state of the endpoint to state atomically. This
@@ -672,6 +796,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
count: 9,
},
uniqueID: s.UniqueID(),
+ txHash: s.Rand().Uint32(),
}
var ss SendBufferSizeOption
@@ -696,7 +821,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
var de DelayEnabled
if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
- e.SetSockOptInt(tcpip.DelayOption, 1)
+ e.SetSockOptBool(tcpip.DelayOption, true)
}
var tcpLT tcpip.TCPLingerTimeoutOption
@@ -709,9 +834,8 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- e.workMu.Init()
- e.workMu.Lock()
e.tsOffset = timeStampOffset()
+ e.acceptCond = sync.NewCond(&e.acceptMu)
return e
}
@@ -721,9 +845,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result := waiter.EventMask(0)
- e.mu.RLock()
- defer e.mu.RUnlock()
-
switch e.EndpointState() {
case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
@@ -735,9 +856,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
case StateListen:
// Check if there's anything in the accepted channel.
if (mask & waiter.EventIn) != 0 {
+ e.acceptMu.Lock()
if len(e.acceptedChan) > 0 {
result |= waiter.EventIn
}
+ e.acceptMu.Unlock()
}
}
if e.EndpointState().connected() {
@@ -798,7 +921,21 @@ func (e *endpoint) Abort() {
// If the endpoint disconnected after the check, nothing needs to be
// done, so sending a notification which will potentially be ignored is
// fine.
- if e.EndpointState().connected() {
+ //
+ // If the endpoint connecting finishes after the check, the endpoint
+ // is either in a connected state (where we would notifyAbort anyway),
+ // SYN-RECV (where we would also notifyAbort anyway), or in an error
+ // state where nothing is required and the notification can be safely
+ // ignored.
+ //
+ // Endpoints where a Close during connecting or SYN-RECV state would be
+ // problematic are set to state connecting before being registered (and
+ // thus possible to be Aborted). They are never available in initial
+ // state.
+ //
+ // Endpoints transitioning from initial to connecting state may be
+ // safely either closed or sent notifyAbort.
+ if s := e.EndpointState(); s == StateConnecting || s == StateSynRecv || s.connected() {
e.notifyProtocolGoroutine(notifyAbort)
return
}
@@ -809,25 +946,22 @@ func (e *endpoint) Abort() {
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
func (e *endpoint) Close() {
- e.mu.Lock()
- closed := e.closed
- e.mu.Unlock()
- if closed {
+ e.LockUser()
+ defer e.UnlockUser()
+ if e.closed {
return
}
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
- e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
- e.closeNoShutdown()
+ e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ e.closeNoShutdownLocked()
}
// closeNoShutdown closes the endpoint without doing a full shutdown. This is
// used when a connection needs to be aborted with a RST and we want to skip
// a full 4 way TCP shutdown.
-func (e *endpoint) closeNoShutdown() {
- e.mu.Lock()
-
+func (e *endpoint) closeNoShutdownLocked() {
// For listening sockets, we always release ports inline so that they
// are immediately available for reuse after Close() is called. If also
// registered, we unregister as well otherwise the next user would fail
@@ -846,66 +980,55 @@ func (e *endpoint) closeNoShutdown() {
// Mark endpoint as closed.
e.closed = true
+
+ switch e.EndpointState() {
+ case StateClose, StateError:
+ return
+ }
+
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
- tcpip.AddDanglingEndpoint(e)
- switch e.EndpointState() {
- // Sockets in StateSynRecv state(passive connections) are closed when
- // the handshake fails or if the listening socket is closed while
- // handshake was in progress. In such cases the handshake goroutine
- // is already gone by the time Close is called and we need to cleanup
- // here.
- case StateInitial, StateBound, StateSynRecv:
- e.cleanupLocked()
- e.setEndpointState(StateClose)
- case StateError, StateClose:
- // do nothing.
- default:
+ if e.workerRunning {
e.workerCleanup = true
+ tcpip.AddDanglingEndpoint(e)
+ // Worker will remove the dangling endpoint when the endpoint
+ // goroutine terminates.
e.notifyProtocolGoroutine(notifyClose)
+ } else {
+ e.transitionToStateCloseLocked()
}
-
- e.mu.Unlock()
}
// closePendingAcceptableConnections closes all connections that have completed
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
- done := make(chan struct{})
- // Spin a goroutine up as ranging on e.acceptedChan will just block when
- // there are no more connections in the channel. Using a non-blocking
- // select does not work as it can potentially select the default case
- // even when there are pending writes but that are not yet written to
- // the channel.
- go func() {
- defer close(done)
- for n := range e.acceptedChan {
- n.notifyProtocolGoroutine(notifyReset)
- // close all connections that have completed but
- // not accepted by the application.
- n.Close()
- }
- }()
- // pendingAccepted(see endpoint.deliverAccepted) tracks the number of
- // endpoints which have completed handshake but are not yet written to
- // the e.acceptedChan. We wait here till the goroutine above can drain
- // all such connections from e.acceptedChan.
- e.pendingAccepted.Wait()
+ e.acceptMu.Lock()
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ return
+ }
close(e.acceptedChan)
- <-done
+ ch := e.acceptedChan
e.acceptedChan = nil
+ e.acceptCond.Broadcast()
+ e.acceptMu.Unlock()
+
+ // Reset all connections that are waiting to be accepted.
+ for n := range ch {
+ n.notifyProtocolGoroutine(notifyReset)
+ }
+ // Wait for reset of all endpoints that are still waiting to be delivered to
+ // the now closed acceptedChan.
+ e.pendingAccepted.Wait()
}
// cleanupLocked frees all resources associated with the endpoint. It is called
// after Close() is called and the worker goroutine (if any) is done with its
// work.
func (e *endpoint) cleanupLocked() {
-
// Close all endpoints that might have been accepted by TCP but not by
// the client.
- if e.acceptedChan != nil {
- e.closePendingAcceptableConnectionsLocked()
- }
+ e.closePendingAcceptableConnectionsLocked()
e.workerCleanup = false
@@ -945,6 +1068,9 @@ func (e *endpoint) initialReceiveWindow() int {
// ModerateRecvBuf adjusts the receive buffer and the advertised window
// based on the number of bytes copied to user space.
func (e *endpoint) ModerateRecvBuf(copied int) {
+ e.LockUser()
+ defer e.UnlockUser()
+
e.rcvListMu.Lock()
if e.rcvAutoParams.disabled {
e.rcvListMu.Unlock()
@@ -994,7 +1120,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvBufSize = rcvWnd
availAfter := e.receiveBufferAvailableLocked()
mask := uint32(notifyReceiveWindowChanged)
- if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.notifyProtocolGoroutine(mask)
@@ -1011,14 +1137,18 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvListMu.Unlock()
}
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
+
// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
+func (e *endpoint) IPTables() (stack.IPTables, error) {
return e.stack.IPTables(), nil
}
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- e.mu.RLock()
+ e.LockUser()
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data. Also note that a RST being received
// would cause the state to become StateError so we should allow the
@@ -1028,7 +1158,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
- e.mu.RUnlock()
+ e.UnlockUser()
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
@@ -1038,8 +1168,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
v, err := e.readLocked()
e.rcvListMu.Unlock()
-
- e.mu.RUnlock()
+ e.UnlockUser()
if err == tcpip.ErrClosedForReceive {
e.stats.ReadErrors.ReadClosed.Increment()
@@ -1071,7 +1200,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1112,13 +1241,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
- e.mu.RLock()
+ e.LockUser()
e.sndBufMu.Lock()
avail, err := e.isEndpointWritableLocked()
if err != nil {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
e.stats.WriteErrors.WriteClosed.Increment()
return 0, nil, err
}
@@ -1130,113 +1259,68 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// are copying data in.
if !opts.Atomic {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
}
// Fetch data.
v, perr := p.Payload(avail)
if perr != nil || len(v) == 0 {
- if opts.Atomic { // See above.
+ // Note that perr may be nil if len(v) == 0.
+ if opts.Atomic {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
}
- // Note that perr may be nil if len(v) == 0.
return 0, nil, perr
}
- if opts.Atomic {
+ queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) {
// Add data to the send queue.
s := newSegmentFromView(&e.route, e.ID, v)
e.sndBufUsed += len(v)
e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
- }
-
- if e.workMu.TryLock() {
- // Since we released locks in between it's possible that the
- // endpoint transitioned to a CLOSED/ERROR states so make
- // sure endpoint is still writable before trying to write.
- if !opts.Atomic { // See above.
- e.mu.RLock()
- e.sndBufMu.Lock()
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
-
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
- // Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
-
- }
// Do the work inline.
e.handleWrite()
- e.workMu.Unlock()
- } else {
- if !opts.Atomic { // See above.
- e.mu.RLock()
- e.sndBufMu.Lock()
+ e.UnlockUser()
+ return int64(len(v)), nil, nil
+ }
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
+ if opts.Atomic {
+ // Locks released in queueAndSend()
+ return queueAndSend()
+ }
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
- // Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
+ e.LockUser()
+ e.sndBufMu.Lock()
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.UnlockUser()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
- }
- // Let the protocol goroutine do the work.
- e.sndWaker.Assert()
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
}
- return int64(len(v)), nil, nil
+ // Locks released in queueAndSend()
+ return queueAndSend()
}
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data.
@@ -1289,9 +1373,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
-// windowCrossedACKThreshold checks if the receive window to be announced now
-// would be under aMSS or under half receive buffer, whichever smaller. This is
-// useful as a receive side silly window syndrome prevention mechanism. If
+// windowCrossedACKThresholdLocked checks if the receive window to be announced
+// now would be under aMSS or under half receive buffer, whichever smaller. This
+// is useful as a receive side silly window syndrome prevention mechanism. If
// window grows to reasonable value, we should send ACK to the sender to inform
// the rx space is now large. We also want ensure a series of small read()'s
// won't trigger a flood of spurious tiny ACK's.
@@ -1302,7 +1386,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// crossed will be true if the window size crossed the ACK threshold.
// above will be true if the new window is >= ACK threshold and false
// otherwise.
-func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) {
+//
+// Precondition: e.mu and e.rcvListMu must be held.
+func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
newAvail := e.receiveBufferAvailableLocked()
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
@@ -1326,21 +1412,71 @@ func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, abo
// SetSockOptBool sets a socket option.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
+
+ case tcpip.BroadcastOption:
+ e.LockUser()
+ e.broadcast = v
+ e.UnlockUser()
+
+ case tcpip.CorkOption:
+ e.LockUser()
+ if !v {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ e.UnlockUser()
+
+ case tcpip.DelayOption:
+ if v {
+ atomic.StoreUint32(&e.delay, 1)
+ } else {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ }
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ e.keepalive.enabled = v
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.QuickAckOption:
+ o := uint32(1)
+ if v {
+ o = 0
+ }
+ atomic.StoreUint32(&e.slowAck, o)
+
+ case tcpip.ReuseAddressOption:
+ e.LockUser()
+ e.reuseAddr = v
+ e.UnlockUser()
+
+ case tcpip.ReusePortOption:
+ e.LockUser()
+ e.reusePort = v
+ e.UnlockUser()
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
return tcpip.ErrInvalidEndpointState
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
// We only allow this to be set when we're in the initial state.
if e.EndpointState() != StateInitial {
return tcpip.ErrInvalidEndpointState
}
+ e.LockUser()
e.v6only = v
+ e.UnlockUser()
}
return nil
@@ -1348,23 +1484,56 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt sets a socket option.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ // Lower 2 bits represents ECN bits. RFC 3168, section 23.1
+ const inetECNMask = 3
+
switch opt {
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ e.keepalive.count = v
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.IPv4TOSOption:
+ e.LockUser()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.UnlockUser()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.LockUser()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.UnlockUser()
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.LockUser()
+ e.userMSS = uint16(userMSS)
+ e.UnlockUser()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
var rs ReceiveBufferSizeOption
- size := int(v)
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
- if size < rs.Min {
- size = rs.Min
+ if v < rs.Min {
+ v = rs.Min
}
- if size > rs.Max {
- size = rs.Max
+ if v > rs.Max {
+ v = rs.Max
}
}
mask := uint32(notifyReceiveWindowChanged)
+ e.LockUser()
e.rcvListMu.Lock()
// Make sure the receive buffer size allows us to send a
@@ -1373,17 +1542,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
if e.rcv != nil {
scale = e.rcv.rcvWndScale
}
- if size>>scale == 0 {
- size = 1 << scale
+ if v>>scale == 0 {
+ v = 1 << scale
}
// Make sure 2*size doesn't overflow.
- if size > math.MaxInt32/2 {
- size = math.MaxInt32 / 2
+ if v > math.MaxInt32/2 {
+ v = math.MaxInt32 / 2
}
availBefore := e.receiveBufferAvailableLocked()
- e.rcvBufSize = size
+ e.rcvBufSize = v
availAfter := e.receiveBufferAvailableLocked()
e.rcvAutoParams.disabled = true
@@ -1391,151 +1560,71 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// Immediately send an ACK to uncork the sender silly window
// syndrome prevetion, when our available space grows above aMSS
// or half receive buffer, whichever smaller.
- if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
- e.rcvListMu.Unlock()
+ e.rcvListMu.Unlock()
+ e.UnlockUser()
e.notifyProtocolGoroutine(mask)
- return nil
case tcpip.SendBufferSizeOption:
// Make sure the send buffer size is within the min and max
// allowed.
- size := int(v)
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
- if size < ss.Min {
- size = ss.Min
+ if v < ss.Min {
+ v = ss.Min
}
- if size > ss.Max {
- size = ss.Max
+ if v > ss.Max {
+ v = ss.Max
}
}
e.sndBufMu.Lock()
- e.sndBufSize = size
+ e.sndBufSize = v
e.sndBufMu.Unlock()
- return nil
-
- case tcpip.DelayOption:
- if v == 0 {
- atomic.StoreUint32(&e.delay, 0)
- // Handle delayed data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.delay, 1)
- }
- return nil
+ case tcpip.TTLOption:
+ e.LockUser()
+ e.ttl = uint8(v)
+ e.UnlockUser()
- default:
- return nil
}
+ return nil
}
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- // Lower 2 bits represents ECN bits. RFC 3168, section 23.1
- const inetECNMask = 3
switch v := opt.(type) {
- case tcpip.CorkOption:
- if v == 0 {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- return nil
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.reuseAddr = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
- return nil
-
case tcpip.BindToDeviceOption:
id := tcpip.NICID(v)
if id != 0 && !e.stack.HasNIC(id) {
return tcpip.ErrUnknownDevice
}
- e.mu.Lock()
+ e.LockUser()
e.bindToDevice = id
- e.mu.Unlock()
- return nil
-
- case tcpip.QuickAckOption:
- if v == 0 {
- atomic.StoreUint32(&e.slowAck, 1)
- } else {
- atomic.StoreUint32(&e.slowAck, 0)
- }
- return nil
-
- case tcpip.MaxSegOption:
- userMSS := v
- if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
- }
- e.mu.Lock()
- e.userMSS = uint16(userMSS)
- e.mu.Unlock()
- e.notifyProtocolGoroutine(notifyMSSChanged)
- return nil
-
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
- return nil
-
- case tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- e.keepalive.enabled = v != 0
- e.keepalive.Unlock()
- e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
+ e.UnlockUser()
case tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
e.keepalive.idle = time.Duration(v)
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
case tcpip.KeepaliveIntervalOption:
e.keepalive.Lock()
e.keepalive.interval = time.Duration(v)
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
- case tcpip.KeepaliveCountOption:
- e.keepalive.Lock()
- e.keepalive.count = int(v)
- e.keepalive.Unlock()
- e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
+ case tcpip.OutOfBandInlineOption:
+ // We don't currently support disabling this option.
case tcpip.TCPUserTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
e.userTimeout = time.Duration(v)
- e.mu.Unlock()
- return nil
-
- case tcpip.BroadcastOption:
- e.mu.Lock()
- e.broadcast = v != 0
- e.mu.Unlock()
- return nil
+ e.UnlockUser()
case tcpip.CongestionControlOption:
// Query the available cc algorithms in the stack and
@@ -1548,22 +1637,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
availCC := strings.Split(string(avail), " ")
for _, cc := range availCC {
if v == tcpip.CongestionControlOption(cc) {
- // Acquire the work mutex as we may need to
- // reinitialize the congestion control state.
- e.mu.Lock()
+ e.LockUser()
state := e.EndpointState()
e.cc = v
- e.mu.Unlock()
switch state {
case StateEstablished:
- e.workMu.Lock()
- e.mu.Lock()
if e.EndpointState() == state {
e.snd.cc = e.snd.initCongestionControl(e.cc)
}
- e.mu.Unlock()
- e.workMu.Unlock()
}
+ e.UnlockUser()
return nil
}
}
@@ -1572,24 +1655,8 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// control algorithm is specified.
return tcpip.ErrNoSuchFile
- case tcpip.IPv4TOSOption:
- e.mu.Lock()
- // TODO(gvisor.dev/issue/995): ECN is not currently supported,
- // ignore the bits for now.
- e.sendTOS = uint8(v) & ^uint8(inetECNMask)
- e.mu.Unlock()
- return nil
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.Lock()
- // TODO(gvisor.dev/issue/995): ECN is not currently supported,
- // ignore the bits for now.
- e.sendTOS = uint8(v) & ^uint8(inetECNMask)
- e.mu.Unlock()
- return nil
-
case tcpip.TCPLingerTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
if v < 0 {
// Same as effectively disabling TCPLinger timeout.
v = 0
@@ -1607,27 +1674,26 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
v = stkTCPLingerTimeout
}
e.tcpLingerTimeout = time.Duration(v)
- e.mu.Unlock()
- return nil
+ e.UnlockUser()
case tcpip.TCPDeferAcceptOption:
- e.mu.Lock()
+ e.LockUser()
if time.Duration(v) > MaxRTO {
v = tcpip.TCPDeferAcceptOption(MaxRTO)
}
e.deferAccept = time.Duration(v)
- e.mu.Unlock()
- return nil
+ e.UnlockUser()
default:
return nil
}
+ return nil
}
// readyReceiveSize returns the number of bytes ready to be received.
func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
// The endpoint cannot be in listen state.
if e.EndpointState() == StateListen {
@@ -1643,25 +1709,89 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
+ case tcpip.BroadcastOption:
+ e.LockUser()
+ v := e.broadcast
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.CorkOption:
+ return atomic.LoadUint32(&e.cork) != 0, nil
+
+ case tcpip.DelayOption:
+ return atomic.LoadUint32(&e.delay) != 0, nil
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ v := e.keepalive.enabled
+ e.keepalive.Unlock()
+
+ return v, nil
+
+ case tcpip.QuickAckOption:
+ v := atomic.LoadUint32(&e.slowAck) == 0
+ return v, nil
+
+ case tcpip.ReuseAddressOption:
+ e.LockUser()
+ v := e.reuseAddr
+ e.UnlockUser()
+
+ return v, nil
+
+ case tcpip.ReusePortOption:
+ e.LockUser()
+ v := e.reusePort
+ e.UnlockUser()
+
+ return v, nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
return false, tcpip.ErrUnknownProtocolOption
}
- e.mu.Lock()
+ e.LockUser()
v := e.v6only
- e.mu.Unlock()
+ e.UnlockUser()
return v, nil
- }
- return false, tcpip.ErrUnknownProtocolOption
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ v := e.keepalive.count
+ e.keepalive.Unlock()
+ return v, nil
+
+ case tcpip.IPv4TOSOption:
+ e.LockUser()
+ v := int(e.sendTOS)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.LockUser()
+ v := int(e.sendTOS)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.MaxSegOption:
+ // This is just stubbed out. Linux never returns the user_mss
+ // value as it either returns the defaultMSS or returns the
+ // actual current MSS. Netstack just returns the defaultMSS
+ // always for now.
+ v := header.TCPDefaultMSS
+ return v, nil
+
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1677,12 +1807,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.rcvListMu.Unlock()
return v, nil
- case tcpip.DelayOption:
- var o int
- if v := atomic.LoadUint32(&e.delay); v != 0 {
- o = 1
- }
- return o, nil
+ case tcpip.TTLOption:
+ e.LockUser()
+ v := int(e.ttl)
+ e.UnlockUser()
+ return v, nil
default:
return -1, tcpip.ErrUnknownProtocolOption
@@ -1699,168 +1828,71 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.lastErrorMu.Unlock()
return err
- case *tcpip.MaxSegOption:
- // This is just stubbed out. Linux never returns the user_mss
- // value as it either returns the defaultMSS or returns the
- // actual current MSS. Netstack just returns the defaultMSS
- // always for now.
- *o = header.TCPDefaultMSS
- return nil
-
- case *tcpip.CorkOption:
- *o = 0
- if v := atomic.LoadUint32(&e.cork); v != 0 {
- *o = 1
- }
- return nil
-
- case *tcpip.ReuseAddressOption:
- e.mu.RLock()
- v := e.reuseAddr
- e.mu.RUnlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
- case *tcpip.ReusePortOption:
- e.mu.RLock()
- v := e.reusePort
- e.mu.RUnlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
case *tcpip.BindToDeviceOption:
- e.mu.RLock()
+ e.LockUser()
*o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.mu.RUnlock()
- return nil
-
- case *tcpip.QuickAckOption:
- *o = 1
- if v := atomic.LoadUint32(&e.slowAck); v != 0 {
- *o = 0
- }
- return nil
-
- case *tcpip.TTLOption:
- e.mu.Lock()
- *o = tcpip.TTLOption(e.ttl)
- e.mu.Unlock()
- return nil
+ e.UnlockUser()
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
- e.mu.RLock()
+ e.LockUser()
snd := e.snd
- e.mu.RUnlock()
+ e.UnlockUser()
if snd != nil {
snd.rtt.Lock()
o.RTT = snd.rtt.srtt
o.RTTVar = snd.rtt.rttvar
snd.rtt.Unlock()
}
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- v := e.keepalive.enabled
- e.keepalive.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
*o = tcpip.KeepaliveIdleOption(e.keepalive.idle)
e.keepalive.Unlock()
- return nil
case *tcpip.KeepaliveIntervalOption:
e.keepalive.Lock()
*o = tcpip.KeepaliveIntervalOption(e.keepalive.interval)
e.keepalive.Unlock()
- return nil
-
- case *tcpip.KeepaliveCountOption:
- e.keepalive.Lock()
- *o = tcpip.KeepaliveCountOption(e.keepalive.count)
- e.keepalive.Unlock()
- return nil
case *tcpip.TCPUserTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
*o = tcpip.TCPUserTimeoutOption(e.userTimeout)
- e.mu.Unlock()
- return nil
+ e.UnlockUser()
case *tcpip.OutOfBandInlineOption:
// We don't currently support disabling this option.
*o = 1
- return nil
-
- case *tcpip.BroadcastOption:
- e.mu.Lock()
- v := e.broadcast
- e.mu.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
case *tcpip.CongestionControlOption:
- e.mu.Lock()
+ e.LockUser()
*o = e.cc
- e.mu.Unlock()
- return nil
-
- case *tcpip.IPv4TOSOption:
- e.mu.RLock()
- *o = tcpip.IPv4TOSOption(e.sendTOS)
- e.mu.RUnlock()
- return nil
-
- case *tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
- *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
- e.mu.RUnlock()
- return nil
+ e.UnlockUser()
case *tcpip.TCPLingerTimeoutOption:
- e.mu.Lock()
+ e.LockUser()
*o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout)
- e.mu.Unlock()
- return nil
+ e.UnlockUser()
case *tcpip.TCPDeferAcceptOption:
- e.mu.Lock()
+ e.LockUser()
*o = tcpip.TCPDeferAcceptOption(e.deferAccept)
- e.mu.Unlock()
- return nil
+ e.UnlockUser()
default:
return tcpip.ErrUnknownProtocolOption
}
+ return nil
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -1885,12 +1917,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// yet accepted by the app, they are restored without running the main goroutine
// here.
func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.LockUser()
+ defer e.UnlockUser()
connectingAddr := addr.Addr
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -2055,13 +2087,17 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- e.mu.Lock()
+ e.LockUser()
+ defer e.UnlockUser()
+ return e.shutdownLocked(flags)
+}
+
+func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
e.shutdownFlags |= flags
- finQueued := false
switch {
case e.EndpointState().connected():
// Close for read.
- if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+ if e.shutdownFlags&tcpip.ShutdownRead != 0 {
// Mark read side as closed.
e.rcvListMu.Lock()
e.rcvClosed = true
@@ -2070,69 +2106,58 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
- if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
- e.mu.Unlock()
- // Try to send an active reset immediately if the
- // work mutex is available.
- if e.workMu.TryLock() {
- e.mu.Lock()
- // We need to double check here to make
- // sure worker has not transitioned the
- // endpoint out of a connected state
- // before trying to send a reset.
- if e.EndpointState().connected() {
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
- e.notifyProtocolGoroutine(notifyTickleWorker)
- }
- e.mu.Unlock()
- e.workMu.Unlock()
- } else {
- e.notifyProtocolGoroutine(notifyReset)
- }
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 {
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ // Wake up worker to terminate loop.
+ e.notifyProtocolGoroutine(notifyTickleWorker)
return nil
}
}
// Close for write.
- if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
e.sndBufMu.Lock()
-
if e.sndClosed {
// Already closed.
e.sndBufMu.Unlock()
- break
+ if e.EndpointState() == StateTimeWait {
+ return tcpip.ErrNotConnected
+ }
+ return nil
}
// Queue fin segment.
s := newSegmentFromView(&e.route, e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
- finQueued = true
// Mark endpoint as closed.
e.sndClosed = true
e.sndBufMu.Unlock()
+ e.handleClose()
}
+ return nil
case e.EndpointState() == StateListen:
- // Tell protocolListenLoop to stop.
- if flags&tcpip.ShutdownRead != 0 {
- e.notifyProtocolGoroutine(notifyClose)
+ if e.shutdownFlags&tcpip.ShutdownRead != 0 {
+ // Reset all connections from the accept queue and keep the
+ // worker running so that it can continue handling incoming
+ // segments by replying with RST.
+ //
+ // By not removing this endpoint from the demuxer mapping, we
+ // ensure that any other bind to the same port fails, as on Linux.
+ // TODO(gvisor.dev/issue/2468): We need to enable applications to
+ // start listening on this endpoint again similar to Linux.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ e.rcvListMu.Unlock()
+ e.closePendingAcceptableConnectionsLocked()
+ // Notify waiters that the endpoint is shutdown.
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}
+ return nil
default:
- e.mu.Unlock()
return tcpip.ErrNotConnected
}
- e.mu.Unlock()
- if finQueued {
- if e.workMu.TryLock() {
- e.handleClose()
- e.workMu.Unlock()
- } else {
- // Tell protocol goroutine to close.
- e.sndCloseWaker.Assert()
- }
- }
- return nil
}
// Listen puts the endpoint in "listen" mode, which allows it to accept
@@ -2147,8 +2172,8 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error {
}
func (e *endpoint) listen(backlog int) *tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.LockUser()
+ defer e.UnlockUser()
// Allow the backlog to be adjusted if the endpoint is not shutting down.
// When the endpoint shuts down, it sets workerCleanup to true, and from
@@ -2157,6 +2182,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
if e.EndpointState() == StateListen && !e.workerCleanup {
// Adjust the size of the channel iff we can fix existing
// pending connections into the new one.
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
if len(e.acceptedChan) > backlog {
return tcpip.ErrInvalidEndpointState
}
@@ -2169,6 +2196,11 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
for ep := range origChan {
e.acceptedChan <- ep
}
+
+ // Notify any blocked goroutines that they can attempt to
+ // deliver endpoints again.
+ e.acceptCond.Broadcast()
+
return nil
}
@@ -2198,9 +2230,12 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// The channel may be non-nil when we're restoring the endpoint, and it
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
+ e.acceptMu.Lock()
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
+ e.acceptMu.Unlock()
+
e.workerRunning = true
go e.protocolListenLoop( // S/R-SAFE: drained on save.
seqnum.Size(e.receiveBufferAvailable()))
@@ -2210,7 +2245,6 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
func (e *endpoint) startAcceptedLoop() {
- e.mu.Lock()
e.workerRunning = true
e.mu.Unlock()
wakerInitDone := make(chan struct{})
@@ -2221,18 +2255,24 @@ func (e *endpoint) startAcceptedLoop() {
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode.
func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
+ e.rcvListMu.Lock()
+ rcvClosed := e.rcvClosed
+ e.rcvListMu.Unlock()
// Endpoint must be in listen state before it can accept connections.
- if e.EndpointState() != StateListen {
+ if rcvClosed || e.EndpointState() != StateListen {
return nil, nil, tcpip.ErrInvalidEndpointState
}
// Get the new accepted endpoint.
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
var n *endpoint
select {
case n = <-e.acceptedChan:
+ e.acceptCond.Signal()
default:
return nil, nil, tcpip.ErrWouldBlock
}
@@ -2241,8 +2281,8 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
// Bind binds the endpoint to a specific local port and optionally address.
func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.LockUser()
+ defer e.UnlockUser()
return e.bindLocked(addr)
}
@@ -2256,7 +2296,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
e.BindAddr = addr.Addr
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -2320,8 +2360,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// GetLocalAddress returns the address to which the endpoint is bound.
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
return tcpip.FullAddress{
Addr: e.ID.LocalAddress,
@@ -2332,8 +2372,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
// GetRemoteAddress returns the address to which the endpoint is connected.
func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
if !e.EndpointState().connected() {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
@@ -2346,7 +2386,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}, nil
}
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
// land at the Dispatcher which then can either delivery using the
// worker go routine or directly do the invoke the tcp processing inline
@@ -2365,7 +2405,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
switch typ {
case stack.ControlPacketTooBig:
e.sndBufMu.Lock()
@@ -2406,7 +2446,7 @@ func (e *endpoint) readyToRead(s *segment) {
e.rcvBufUsed += s.data.Size()
// Increase counter if the receive window falls down below MSS
// or half receive buffer size, whichever smaller.
- if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
}
e.rcvList.PushBack(s)
@@ -2414,7 +2454,6 @@ func (e *endpoint) readyToRead(s *segment) {
e.rcvClosed = true
}
e.rcvListMu.Unlock()
-
e.waiterQueue.Notify(waiter.EventIn)
}
@@ -2558,9 +2597,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
s.SegTime = time.Now()
// Copy EndpointID.
- e.mu.Lock()
s.ID = stack.TCPEndpointID(e.ID)
- e.mu.Unlock()
// Copy endpoint rcv state.
e.rcvListMu.Lock()
@@ -2690,10 +2727,10 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
- e.mu.RLock()
+ e.LockUser()
// Make a copy of the endpoint info.
ret := e.EndpointInfo
- e.mu.RUnlock()
+ e.UnlockUser()
return &ret
}
@@ -2708,9 +2745,9 @@ func (e *endpoint) Wait() {
e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp)
defer e.waiterQueue.EventUnregister(&waitEntry)
for {
- e.mu.Lock()
+ e.LockUser()
running := e.workerRunning
- e.mu.Unlock()
+ e.UnlockUser()
if !running {
break
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 4a46f0ec5..c3c692555 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -162,8 +162,8 @@ func (e *endpoint) loadState(state EndpointState) {
connectingLoading.Add(1)
}
// Directly update the state here rather than using e.setEndpointState
- // as the endpoint is still being loaded and the stack reference to increment
- // metrics is not yet initialized.
+ // as the endpoint is still being loaded and the stack reference is not
+ // yet initialized.
atomic.StoreUint32((*uint32)(&e.state), uint32(state))
}
@@ -173,6 +173,9 @@ func (e *endpoint) afterLoad() {
// Restore the endpoint to InitialState as it will be moved to
// its origEndpointState during Resume.
e.state = StateInitial
+ // Condition variables and mutexs are not S/R'ed so reinitialize
+ // acceptCond with e.acceptMu.
+ e.acceptCond = sync.NewCond(&e.acceptMu)
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
@@ -180,7 +183,6 @@ func (e *endpoint) afterLoad() {
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- e.workMu.Init()
state := e.origEndpointState
switch state {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index c9ee5bf06..704d01c64 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
s := newSegment(r, id, pkt)
defer s.decRef()
@@ -130,7 +130,7 @@ func (r *ForwarderRequest) Complete(sendReset bool) {
// If the caller requested, send a reset.
if sendReset {
- replyWithReset(r.segment)
+ replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL())
}
// Release all resources.
@@ -157,7 +157,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
TSVal: r.synOptions.TSVal,
TSEcr: r.synOptions.TSEcr,
SACKPermitted: r.synOptions.SACKPermitted,
- }, queue)
+ }, queue, nil)
if err != nil {
return nil, err
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 73098d904..cfd9a4e8e 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -94,8 +94,65 @@ const (
ccCubic = "cubic"
)
+// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
+// value is protected by a mutex so that we can increment only when it's
+// guaranteed not to go above a threshold.
+type synRcvdCounter struct {
+ sync.Mutex
+ value uint64
+ pending sync.WaitGroup
+ threshold uint64
+}
+
+// inc tries to increment the global number of endpoints in SYN-RCVD state. It
+// succeeds if the increment doesn't make the count go beyond the threshold, and
+// fails otherwise.
+func (s *synRcvdCounter) inc() bool {
+ s.Lock()
+ defer s.Unlock()
+ if s.value >= s.threshold {
+ return false
+ }
+
+ s.pending.Add(1)
+ s.value++
+
+ return true
+}
+
+// dec atomically decrements the global number of endpoints in SYN-RCVD
+// state. It must only be called if a previous call to inc succeeded.
+func (s *synRcvdCounter) dec() {
+ s.Lock()
+ defer s.Unlock()
+ s.value--
+ s.pending.Done()
+}
+
+// synCookiesInUse returns true if the synRcvdCount is greater than
+// SynRcvdCountThreshold.
+func (s *synRcvdCounter) synCookiesInUse() bool {
+ s.Lock()
+ defer s.Unlock()
+ return s.value >= s.threshold
+}
+
+// SetThreshold sets synRcvdCounter.Threshold to ths new threshold.
+func (s *synRcvdCounter) SetThreshold(threshold uint64) {
+ s.Lock()
+ defer s.Unlock()
+ s.threshold = threshold
+}
+
+// Threshold returns the current value of synRcvdCounter.Threhsold.
+func (s *synRcvdCounter) Threshold() uint64 {
+ s.Lock()
+ defer s.Unlock()
+ return s.threshold
+}
+
type protocol struct {
- mu sync.Mutex
+ mu sync.RWMutex
sackEnabled bool
delayEnabled bool
sendBufferSize SendBufferSizeOption
@@ -105,6 +162,8 @@ type protocol struct {
moderateReceiveBuffer bool
tcpLingerTimeout time.Duration
tcpTimeWaitTimeout time.Duration
+ minRTO time.Duration
+ synRcvdCount synRcvdCounter
dispatcher *dispatcher
}
@@ -140,7 +199,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// to a specific processing queue. Each queue is serviced by its own processor
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
-func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
p.dispatcher.queuePacket(r, ep, id, pkt)
}
@@ -151,7 +210,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
s := newSegment(r, id, pkt)
defer s.decRef()
@@ -164,12 +223,12 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo
return true
}
- replyWithReset(s)
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
return true
}
// replyWithReset replies to the given segment with a reset segment.
-func replyWithReset(s *segment) {
+func replyWithReset(s *segment, tos, ttl uint8) {
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
ack := seqnum.Value(0)
@@ -191,7 +250,15 @@ func replyWithReset(s *segment) {
flags |= header.TCPFlagAck
ack = s.sequenceNumber.Add(s.logicalLen())
}
- sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
+ sendTCP(&s.route, tcpFields{
+ id: s.id,
+ ttl: ttl,
+ tos: tos,
+ flags: flags,
+ seq: seq,
+ ack: ack,
+ rcvWnd: 0,
+ }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */)
}
// SetOption implements stack.TransportProtocol.SetOption.
@@ -264,6 +331,21 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case tcpip.TCPMinRTOOption:
+ if v < 0 {
+ v = tcpip.TCPMinRTOOption(MinRTO)
+ }
+ p.mu.Lock()
+ p.minRTO = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.Lock()
+ p.synRcvdCount.SetThreshold(uint64(v))
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -273,57 +355,69 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *SACKEnabled:
- p.mu.Lock()
+ p.mu.RLock()
*v = SACKEnabled(p.sackEnabled)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *DelayEnabled:
- p.mu.Lock()
+ p.mu.RLock()
*v = DelayEnabled(p.delayEnabled)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *SendBufferSizeOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = p.sendBufferSize
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *ReceiveBufferSizeOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = p.recvBufferSize
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.CongestionControlOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.CongestionControlOption(p.congestionControl)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.AvailableCongestionControlOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.ModerateReceiveBufferOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.TCPLingerTimeoutOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.TCPTimeWaitTimeoutOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
- p.mu.Unlock()
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPMinRTOOption:
+ p.mu.RLock()
+ *v = tcpip.TCPMinRTOOption(p.minRTO)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.RLock()
+ *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold())
+ p.mu.RUnlock()
return nil
default:
@@ -341,6 +435,12 @@ func (p *protocol) Wait() {
p.dispatcher.wait()
}
+// SynRcvdCounter returns a reference to the synRcvdCount for this protocol
+// instance.
+func (p *protocol) SynRcvdCounter() *synRcvdCounter {
+ return &p.synRcvdCount
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{
@@ -350,6 +450,8 @@ func NewProtocol() stack.TransportProtocol {
availableCongestionControl: []string{ccReno, ccCubic},
tcpLingerTimeout: DefaultTCPLingerTimeout,
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
+ minRTO: MinRTO,
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 958f03ac1..caf8977b3 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -168,7 +168,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// We just received a FIN, our next state depends on whether we sent a
// FIN already or not.
- r.ep.mu.Lock()
switch r.ep.EndpointState() {
case StateEstablished:
r.ep.setEndpointState(StateCloseWait)
@@ -183,7 +182,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateFinWait2:
r.ep.setEndpointState(StateTimeWait)
}
- r.ep.mu.Unlock()
// Flush out any pending segments, except the very first one if
// it happens to be the one we're handling now because the
@@ -195,6 +193,10 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
for i := first; i < len(r.pendingRcvdSegments); i++ {
r.pendingRcvdSegments[i].decRef()
+ // Note that slice truncation does not allow garbage collection of
+ // truncated items, thus truncated items must be set to nil to avoid
+ // memory leaks.
+ r.pendingRcvdSegments[i] = nil
}
r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
@@ -204,7 +206,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
- r.ep.mu.Lock()
switch r.ep.EndpointState() {
case StateFinWait1:
r.ep.setEndpointState(StateFinWait2)
@@ -218,7 +219,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateLastAck:
r.ep.transitionToStateCloseLocked()
}
- r.ep.mu.Unlock()
}
return true
@@ -332,10 +332,8 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// handleRcvdSegment handles TCP segments directed at the connection managed by
// r as they arrive. It is called by the protocol main loop.
func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
- r.ep.mu.RLock()
state := r.ep.EndpointState()
closed := r.ep.closed
- r.ep.mu.RUnlock()
if state != StateEstablished {
drop, err := r.handleRcvdSegmentClosing(s, state, closed)
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 1c10da5ca..40461fd31 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -18,7 +18,6 @@ import (
"sync/atomic"
"time"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -56,12 +55,12 @@ type segment struct {
options []byte `state:".([]byte)"`
hasNewSACKInfo bool
rcvdTime time.Time `state:".(unixTime)"`
- // xmitTime is the last transmit time of this segment. A zero value
- // indicates that the segment has yet to be transmitted.
- xmitTime time.Time `state:".(unixTime)"`
+ // xmitTime is the last transmit time of this segment.
+ xmitTime time.Time `state:".(unixTime)"`
+ xmitCount uint32
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment {
+func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment {
s := &segment{
refCnt: 1,
id: id,
@@ -78,9 +77,11 @@ func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.V
id: id,
route: r.Clone(),
}
- s.views[0] = v
- s.data = buffer.NewVectorisedView(len(v), s.views[:1])
s.rcvdTime = time.Now()
+ if len(v) != 0 {
+ s.views[0] = v
+ s.data = buffer.NewVectorisedView(len(v), s.views[:1])
+ }
return s
}
diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go
index 9fd061d7d..8d3ddce4b 100644
--- a/pkg/tcpip/transport/tcp/segment_heap.go
+++ b/pkg/tcpip/transport/tcp/segment_heap.go
@@ -14,21 +14,25 @@
package tcp
+import "container/heap"
+
type segmentHeap []*segment
+var _ heap.Interface = (*segmentHeap)(nil)
+
// Len returns the length of h.
-func (h segmentHeap) Len() int {
- return len(h)
+func (h *segmentHeap) Len() int {
+ return len(*h)
}
// Less determines whether the i-th element of h is less than the j-th element.
-func (h segmentHeap) Less(i, j int) bool {
- return h[i].sequenceNumber.LessThan(h[j].sequenceNumber)
+func (h *segmentHeap) Less(i, j int) bool {
+ return (*h)[i].sequenceNumber.LessThan((*h)[j].sequenceNumber)
}
// Swap swaps the i-th and j-th elements of h.
-func (h segmentHeap) Swap(i, j int) {
- h[i], h[j] = h[j], h[i]
+func (h *segmentHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
}
// Push adds x as the last element of h.
@@ -41,6 +45,7 @@ func (h *segmentHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
+ old[n-1] = nil
*h = old[:n-1]
return x
}
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index bd20a7ee9..48a257137 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -28,10 +28,16 @@ type segmentQueue struct {
used int
}
+// emptyLocked determines if the queue is empty.
+// Preconditions: q.mu must be held.
+func (q *segmentQueue) emptyLocked() bool {
+ return q.used == 0
+}
+
// empty determines if the queue is empty.
func (q *segmentQueue) empty() bool {
q.mu.Lock()
- r := q.used == 0
+ r := q.emptyLocked()
q.mu.Unlock()
return r
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index b74b61e7d..d8cfe3115 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "fmt"
"math"
"sync/atomic"
"time"
@@ -126,10 +127,6 @@ type sender struct {
// sndNxt is the sequence number of the next segment to be sent.
sndNxt seqnum.Value
- // sndNxtList is the sequence number of the next segment to be added to
- // the send list.
- sndNxtList seqnum.Value
-
// rttMeasureSeqNum is the sequence number being used for the latest RTT
// measurement.
rttMeasureSeqNum seqnum.Value
@@ -153,6 +150,9 @@ type sender struct {
rtt rtt
rto time.Duration
+ // minRTO is the minimum permitted value for sender.rto.
+ minRTO time.Duration
+
// maxPayloadSize is the maximum size of the payload of a given segment.
// It is initialized on demand.
maxPayloadSize int
@@ -229,7 +229,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
sndWnd: sndWnd,
sndUna: iss + 1,
sndNxt: iss + 1,
- sndNxtList: iss + 1,
rto: 1 * time.Second,
rttMeasureSeqNum: iss + 1,
lastSendTime: time.Now(),
@@ -265,6 +264,13 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
// etc.
s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
+ // Get Stack wide minRTO.
+ var v tcpip.TCPMinRTOOption
+ if err := ep.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
+ panic(fmt.Sprintf("unable to get minRTO from stack: %s", err))
+ }
+ s.minRTO = time.Duration(v)
+
return s
}
@@ -399,8 +405,8 @@ func (s *sender) updateRTO(rtt time.Duration) {
s.rto = s.rtt.srtt + 4*s.rtt.rttvar
s.rtt.Unlock()
- if s.rto < MinRTO {
- s.rto = MinRTO
+ if s.rto < s.minRTO {
+ s.rto = s.minRTO
}
}
@@ -455,9 +461,7 @@ func (s *sender) retransmitTimerExpired() bool {
// Give up if we've waited more than a minute since the last resend or
// if a user time out is set and we have exceeded the user specified
// timeout since the first retransmission.
- s.ep.mu.RLock()
uto := s.ep.userTimeout
- s.ep.mu.RUnlock()
if s.firstRetransmittedSegXmitTime.IsZero() {
// We store the original xmitTime of the segment that we are
@@ -713,7 +717,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
default:
s.ep.setEndpointState(StateFinWait1)
}
-
} else {
// We're sending a non-FIN segment.
if seg.flags&header.TCPFlagFin != 0 {
@@ -1229,7 +1232,7 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// sendSegment sends the specified segment.
func (s *sender) sendSegment(seg *segment) *tcpip.Error {
- if !seg.xmitTime.IsZero() {
+ if seg.xmitCount > 0 {
s.ep.stack.Stats().TCP.Retransmits.Increment()
s.ep.stats.SendErrors.Retransmits.Increment()
if s.sndCwnd < s.sndSsthresh {
@@ -1237,6 +1240,7 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error {
}
}
seg.xmitTime = time.Now()
+ seg.xmitCount++
return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 782d7b42c..359a75e73 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func TestFastRecovery(t *testing.T) {
@@ -40,7 +41,7 @@ func TestFastRecovery(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -86,16 +87,23 @@ func TestFastRecovery(t *testing.T) {
// Receive the retransmitted packet.
c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
- }
+ // Wait before checking metrics.
+ metricPollFn := func() error {
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
+ }
+ return nil
}
- if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
// Now send 7 mode duplicate acks. Each of these should cause a window
@@ -117,12 +125,18 @@ func TestFastRecovery(t *testing.T) {
// Receive the retransmit due to partial ack.
c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ // Wait before checking metrics.
+ metricPollFn = func() error {
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ }
+ return nil
}
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
- t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
// Receive the 10 extra packets that should have been released due to
@@ -192,7 +206,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -234,7 +248,7 @@ func TestCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -338,7 +352,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
@@ -447,7 +461,7 @@ func TestRetransmit(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -492,24 +506,33 @@ func TestRetransmit(t *testing.T) {
rtxOffset := bytesRead - maxPayload*expected
c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
- if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
- }
+ metricPollFn := func() error {
+ if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
+ }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
- }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ }
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
- t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
- }
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
+ return fmt.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
+ }
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ return nil
}
- if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
+ // Poll when checking metrics.
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
// Acknowledge half of the pending data.
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index afea124ec..1dd63dd61 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -149,21 +149,22 @@ func TestSackPermittedAccept(t *testing.T) {
{true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
{false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
+
for _, tc := range testCases {
t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- if tc.cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- } else {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }
for _, sackEnabled := range []bool{false, true} {
t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+
+ if tc.cookieEnabled {
+ // Set the SynRcvd threshold to
+ // zero to force a syn cookie
+ // based accept to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
setStackSACKPermitted(t, c, sackEnabled)
rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
@@ -222,21 +223,23 @@ func TestSackDisabledAccept(t *testing.T) {
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
{false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
+
for _, tc := range testCases {
t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- if tc.cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- } else {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }
for _, sackEnabled := range []bool{false, true} {
t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+
+ if tc.cookieEnabled {
+ // Set the SynRcvd threshold to
+ // zero to force a syn cookie
+ // based accept to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
+
setStackSACKPermitted(t, c, sackEnabled)
rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
@@ -387,7 +390,7 @@ func TestSACKRecovery(t *testing.T) {
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 5b2b16afa..ab1014c7f 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -284,7 +284,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// are released instantly on Close.
tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
- t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err)
+ t.Fatalf("e.stack.SetTransportProtocolOption(%d, %v) = %v", tcp.ProtocolNumber, tcpTW, err)
}
c.EP.Close()
@@ -590,6 +590,10 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
),
)
+ // Give the stack a few ms to transition the endpoint out of ESTABLISHED
+ // state.
+ time.Sleep(10 * time.Millisecond)
+
if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
@@ -728,7 +732,7 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
tests := []struct {
name string
- setMSS uint16
+ setMSS int
expMSS uint16
}{
{
@@ -756,15 +760,14 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
c.Create(-1)
// Set the MSS socket option.
- opt := tcpip.MaxSegOption(test.setMSS)
- if err := c.EP.SetSockOpt(opt); err != nil {
- t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
}
// Get expected window size.
rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
}
ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
@@ -818,15 +821,14 @@ func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
c.CreateV6Endpoint(true)
// Set the MSS socket option.
- opt := tcpip.MaxSegOption(test.setMSS)
- if err := c.EP.SetSockOpt(opt); err != nil {
- t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
}
// Get expected window size.
rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
}
ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
@@ -1032,8 +1034,8 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) {
checker.SeqNum(200)))
}
-// TestListenShutdown tests for the listening endpoint not processing
-// any receive when it is on read shutdown.
+// TestListenShutdown tests for the listening endpoint replying with RST
+// on read shutdown.
func TestListenShutdown(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -1044,7 +1046,7 @@ func TestListenShutdown(t *testing.T) {
t.Fatal("Bind failed:", err)
}
- if err := c.EP.Listen(10 /* backlog */); err != nil {
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
t.Fatal("Listen failed:", err)
}
@@ -1052,9 +1054,6 @@ func TestListenShutdown(t *testing.T) {
t.Fatal("Shutdown failed:", err)
}
- // Wait for the endpoint state to be propagated.
- time.Sleep(10 * time.Millisecond)
-
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -1063,7 +1062,49 @@ func TestListenShutdown(t *testing.T) {
AckNum: 200,
})
- c.CheckNoPacket("Packet received when listening socket was shutdown")
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
+}
+
+// TestListenCloseWhileConnect tests for the listening endpoint to
+// drain the accept-queue when closed. This should reset all of the
+// pending connections that are waiting to be accepted.
+func TestListenCloseWhileConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventIn)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+ // Wait for the new endpoint created because of handshake to be delivered
+ // to the listening endpoint's accept queue.
+ <-notifyCh
+
+ // Close the listening endpoint.
+ c.EP.Close()
+
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
}
func TestTOSV4(t *testing.T) {
@@ -1077,17 +1118,17 @@ func TestTOSV4(t *testing.T) {
c.EP = ep
const tos = 0xC0
- if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
+ t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err)
}
- var v tcpip.IPv4TOSOption
- if err := c.EP.GetSockOpt(&v); err != nil {
- t.Errorf("GetSockopt failed: %s", err)
+ v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err)
}
- if want := tcpip.IPv4TOSOption(tos); v != want {
- t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if v != tos {
+ t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos)
}
testV4Connect(t, c, checker.TOS(tos, 0))
@@ -1125,17 +1166,17 @@ func TestTrafficClassV6(t *testing.T) {
c.CreateV6Endpoint(false)
const tos = 0xC0
- if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
- t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err)
+ if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil {
+ t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err)
}
- var v tcpip.IPv6TrafficClassOption
- if err := c.EP.GetSockOpt(&v); err != nil {
- t.Fatalf("GetSockopt failed: %s", err)
+ v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err)
}
- if want := tcpip.IPv6TrafficClassOption(tos); v != want {
- t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if v != tos {
+ t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos)
}
// Test the connection request.
@@ -1711,7 +1752,7 @@ func TestNoWindowShrinking(t *testing.T) {
c.CreateConnected(789, 30000, 10)
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
- t.Fatalf("SetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %v", err)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -1984,7 +2025,7 @@ func TestScaledWindowAccept(t *testing.T) {
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
@@ -2057,7 +2098,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
@@ -2221,10 +2262,10 @@ func TestSegmentMerging(t *testing.T) {
{
"cork",
func(ep tcpip.Endpoint) {
- ep.SetSockOpt(tcpip.CorkOption(1))
+ ep.SetSockOptBool(tcpip.CorkOption, true)
},
func(ep tcpip.Endpoint) {
- ep.SetSockOpt(tcpip.CorkOption(0))
+ ep.SetSockOptBool(tcpip.CorkOption, false)
},
},
}
@@ -2236,9 +2277,18 @@ func TestSegmentMerging(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- // Prevent the endpoint from processing packets.
- test.stop(c.EP)
+ // Send tcp.InitialCwnd number of segments to fill up
+ // InitialWindow but don't ACK. That should prevent
+ // anymore packets from going out.
+ for i := 0; i < tcp.InitialCwnd; i++ {
+ view := buffer.NewViewFromBytes([]byte{0})
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %s", i+1, err)
+ }
+ }
+ // Now send the segments that should get merged as the congestion
+ // window is full and we won't be able to send any more packets.
var allData []byte
for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
@@ -2248,8 +2298,29 @@ func TestSegmentMerging(t *testing.T) {
}
}
- // Let the endpoint process the segments that we just sent.
- test.resume(c.EP)
+ // Check that we get tcp.InitialCwnd packets.
+ for i := 0; i < tcp.InitialCwnd; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(header.TCPMinimumSize+1),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+uint32(i)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload.
+ RcvWnd: 30000,
+ })
// Check that data is received.
b := c.GetPacket()
@@ -2257,7 +2328,7 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(len(allData)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.SeqNum(uint32(c.IRS)+11),
checker.AckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
@@ -2273,7 +2344,7 @@ func TestSegmentMerging(t *testing.T) {
DstPort: c.Port,
Flags: header.TCPFlagAck,
SeqNum: 790,
- AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))),
+ AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))),
RcvWnd: 30000,
})
})
@@ -2286,7 +2357,7 @@ func TestDelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptInt(tcpip.DelayOption, 1)
+ c.EP.SetSockOptBool(tcpip.DelayOption, true)
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
@@ -2334,7 +2405,7 @@ func TestUndelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptInt(tcpip.DelayOption, 1)
+ c.EP.SetSockOptBool(tcpip.DelayOption, true)
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
@@ -2367,7 +2438,7 @@ func TestUndelay(t *testing.T) {
// Check that we don't get the second packet yet.
c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
- c.EP.SetSockOptInt(tcpip.DelayOption, 0)
+ c.EP.SetSockOptBool(tcpip.DelayOption, false)
// Check that data is received.
second := c.GetPacket()
@@ -2404,8 +2475,8 @@ func TestMSSNotDelayed(t *testing.T) {
fn func(tcpip.Endpoint)
}{
{"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }},
- {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }},
+ {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }},
+ {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }},
}
for _, test := range tests {
@@ -2546,12 +2617,12 @@ func TestSetTTL(t *testing.T) {
t.Fatalf("NewEndpoint failed: %v", err)
}
- if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
- t.Fatalf("SetSockOpt failed: %v", err)
+ if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
+ t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("Unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -2591,7 +2662,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
@@ -2637,26 +2708,24 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 0
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
// Create EP and start listening.
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -2674,7 +2743,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2735,7 +2804,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
const rcvBufferSize = 0x20000
const wndScale = 2
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
}
// Start connection attempt.
@@ -3852,26 +3921,26 @@ func TestMinMaxBufferSizes(t *testing.T) {
// Set values below the min.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err)
}
checkRecvBufferSize(t, ep, 200)
if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err)
}
checkSendBufferSize(t, ep, 300)
// Set values above the max.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
}
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
@@ -4117,11 +4186,11 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
case "ipv4":
case "ipv6":
if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err)
+ t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err)
}
case "dual":
if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err)
+ t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err)
}
default:
t.Fatalf("unknown network: '%s'", network)
@@ -4444,11 +4513,11 @@ func TestKeepalive(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const keepAliveInterval = 10 * time.Millisecond
- c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
+ const keepAliveInterval = 3 * time.Second
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
- c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5))
- c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
+ c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5)
+ c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
// 5 unacked keepalives are sent. ACK each one, and check that the
// connection stays alive after 5.
@@ -4539,7 +4608,7 @@ func TestKeepalive(t *testing.T) {
// Sleep for a litte over the KeepAlive interval to make sure
// the timer has time to fire after the last ACK and close the
// close the socket.
- time.Sleep(keepAliveInterval + 5*time.Millisecond)
+ time.Sleep(keepAliveInterval + keepAliveInterval/2)
// The connection should be terminated after 5 unacked keepalives.
// Send an ACK to trigger a RST from the stack as the endpoint should
@@ -5074,25 +5143,23 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
}
func TestListenBacklogFullSynCookieInUse(t *testing.T) {
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 1
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err)
+ }
+
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -5100,7 +5167,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
listenBacklog := 1
portOffset := uint16(0)
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
executeHandshake(t, c, context.TestPort+portOffset, false)
@@ -5579,7 +5646,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
return
}
if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
- t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w, wantRcvWnd)
+ t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w)
}
},
))
@@ -5740,14 +5807,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
func TestDelayEnabled(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- checkDelayOption(t, c, false, 0) // Delay is disabled by default.
+ checkDelayOption(t, c, false, false) // Delay is disabled by default.
for _, v := range []struct {
delayEnabled tcp.DelayEnabled
- wantDelayOption int
+ wantDelayOption bool
}{
- {delayEnabled: false, wantDelayOption: 0},
- {delayEnabled: true, wantDelayOption: 1},
+ {delayEnabled: false, wantDelayOption: false},
+ {delayEnabled: true, wantDelayOption: true},
} {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5758,7 +5825,7 @@ func TestDelayEnabled(t *testing.T) {
}
}
-func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) {
+func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) {
t.Helper()
var gotDelayEnabled tcp.DelayEnabled
@@ -5773,12 +5840,12 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del
if err != nil {
t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err)
}
- gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption)
+ gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
if err != nil {
- t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err)
+ t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err)
}
if gotDelayOption != wantDelayOption {
- t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption)
+ t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption)
}
}
@@ -6587,14 +6654,17 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
- const keepAliveInterval = 10 * time.Millisecond
- c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
+ const keepAliveInterval = 3 * time.Second
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
- c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10))
- c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
-
- // Set userTimeout to be the duration for 3 keepalive probes.
- userTimeout := 30 * time.Millisecond
+ c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10)
+ c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
+
+ // Set userTimeout to be the duration to be 1 keepalive
+ // probes. Which means that after the first probe is sent
+ // the second one should cause the connection to be
+ // closed due to userTimeout being hit.
+ userTimeout := 1 * keepAliveInterval
c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
// Check that the connection is still alive.
@@ -6602,28 +6672,23 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
}
- // Now receive 2 keepalives, but don't ACK them. The connection should
- // be reset when the 3rd one should be sent due to userTimeout being
- // 30ms and each keepalive probe should be sent 10ms apart as set above after
- // the connection has been idle for 10ms.
- for i := 0; i < 2; i++ {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
+ // Now receive 1 keepalives, but don't ACK it.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
// Sleep for a litte over the KeepAlive interval to make sure
// the timer has time to fire after the last ACK and close the
// close the socket.
- time.Sleep(keepAliveInterval + 5*time.Millisecond)
+ time.Sleep(keepAliveInterval + keepAliveInterval/2)
- // The connection should be terminated after 30ms.
+ // The connection should be closed with a timeout.
// Send an ACK to trigger a RST from the stack as the endpoint should
// be dead.
c.SendPacket(nil, &context.Headers{
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index a641e953d..8edbff964 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -127,16 +127,14 @@ func TestTimeStampDisabledConnect(t *testing.T) {
}
func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
if cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
}
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
tsVal := rand.Uint32()
@@ -148,7 +146,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ t.Fatalf("Unexpected error from Write: %s", err)
}
// Check that data is received and that the timestamp option TSEcr field
@@ -190,17 +188,15 @@ func TestTimeStampEnabledAccept(t *testing.T) {
}
func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
- if cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- }
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ if cookieEnabled {
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
+
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
@@ -211,7 +207,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ t.Fatalf("Unexpected error from Write: %s", err)
}
// Check that data is received and that the timestamp option is disabled
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 1e9a0dea3..7b1d72cf4 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -152,6 +152,13 @@ func New(t *testing.T, mtu uint32) *Context {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
+ // Increase minimum RTO in tests to avoid test flakes due to early
+ // retransmit in case the test executors are overloaded and cause timers
+ // to fire earlier than expected.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil {
+ t.Fatalf("failed to set stack-wide minRTO: %s", err)
+ }
+
// Some of the congestion control tests send up to 640 packets, we so
// set the channel size to 1000.
ep := channel.New(1000, mtu, "")
@@ -204,6 +211,7 @@ func (c *Context) Cleanup() {
if c.EP != nil {
c.EP.Close()
}
+ c.Stack().Close()
}
// Stack returns a reference to the stack in the Context.
@@ -216,7 +224,8 @@ func (c *Context) Stack() *stack.Stack {
func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
c.t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), wait)
+ ctx, cancel := context.WithTimeout(context.Background(), wait)
+ defer cancel()
if _, ok := c.linkEP.ReadContext(ctx); ok {
c.t.Fatal(errMsg)
}
@@ -234,7 +243,8 @@ func (c *Context) CheckNoPacket(errMsg string) {
func (c *Context) GetPacket() []byte {
c.t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
c.t.Fatalf("Packet wasn't written out")
@@ -306,7 +316,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -362,7 +372,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
Data: s,
})
}
@@ -370,7 +380,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) {
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
Data: c.BuildSegment(payload, h),
})
}
@@ -379,7 +389,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) {
// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
// provided source and destination IPv4 addresses.
func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
})
}
@@ -414,6 +424,8 @@ func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlock
// verifies that the packet packet payload of packet matches the slice
// of data indicated by offset & size.
func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
+ c.t.Helper()
+
c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0)
}
@@ -422,6 +434,8 @@ func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
// data indicated by offset & size and skips optlen bytes in addition to the IP
// TCP headers when comparing the data.
func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) {
+ c.t.Helper()
+
b := c.GetPacket()
checker.IPv4(c.t, b,
checker.PayloadLen(size+header.TCPMinimumSize+optlen),
@@ -444,6 +458,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
// data indicated by offset & size. It returns true if a packet was received and
// processed.
func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
+ c.t.Helper()
+
b := c.GetPacketNonBlocking()
if b == nil {
return false
@@ -485,7 +501,8 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
func (c *Context) GetV6Packet() []byte {
c.t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
c.t.Fatalf("Packet wasn't written out")
@@ -547,7 +564,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -566,6 +583,8 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf
//
// PreCondition: c.EP must already be created.
func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
+ c.t.Helper()
+
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index adc908e24..b5d2d0ba6 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -32,7 +32,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1c6a600b8..edb54f0be 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
@@ -144,6 +143,9 @@ type endpoint struct {
// TODO(b/142022063): Add ability to save and restore per endpoint stats.
stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
}
// +stateify savable
@@ -234,7 +236,7 @@ func (e *endpoint) Close() {
func (e *endpoint) ModerateRecvBuf(copied int) {}
// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
+func (e *endpoint) IPTables() (stack.IPTables, error) {
return e.stack.IPTables(), nil
}
@@ -443,19 +445,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrBroadcastDisabled
}
- netProto, err := e.checkV4Mapped(to)
+ dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
}
- r, _, err := e.connectRoute(nicID, *to, netProto)
+ r, _, err := e.connectRoute(nicID, dst, netProto)
if err != nil {
return 0, nil, err
}
defer r.Release()
route = &r
- dstPort = to.Port
+ dstPort = dst.Port
}
if route.IsResolutionRequired() {
@@ -485,7 +487,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -499,11 +501,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
+ case tcpip.BroadcastOption:
+ e.mu.Lock()
+ e.broadcast = v
+ e.mu.Unlock()
+
+ case tcpip.MulticastLoopOption:
+ e.mu.Lock()
+ e.multicastLoop = v
+ e.mu.Unlock()
+
case tcpip.ReceiveTOSOption:
e.mu.Lock()
e.receiveTOS = v
e.mu.Unlock()
- return nil
case tcpip.ReceiveTClassOption:
// We only support this option on v6 endpoints.
@@ -514,7 +525,18 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.mu.Lock()
e.receiveTClass = v
e.mu.Unlock()
- return nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.Lock()
+ e.receiveIPPacketInfo = v
+ e.mu.Unlock()
+
+ case tcpip.ReuseAddressOption:
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v
+ e.mu.Unlock()
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
@@ -531,13 +553,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
}
e.v6only = v
- return nil
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.Lock()
- e.receiveIPPacketInfo = v
- e.mu.Unlock()
- return nil
}
return nil
@@ -545,28 +560,44 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return nil
-}
+ switch opt {
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
-// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
case tcpip.TTLOption:
e.mu.Lock()
e.ttl = uint8(v)
e.mu.Unlock()
- case tcpip.MulticastTTLOption:
+ case tcpip.IPv4TOSOption:
e.mu.Lock()
- e.multicastTTL = uint8(v)
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
e.mu.Unlock()
+ case tcpip.ReceiveBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
+
+ }
+
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
case tcpip.MulticastInterfaceOption:
e.mu.Lock()
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- netProto, err := e.checkV4Mapped(&fa)
+ fa, netProto, err := e.checkV4MappedLocked(fa)
if err != nil {
return err
}
@@ -684,16 +715,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
- case tcpip.MulticastLoopOption:
- e.mu.Lock()
- e.multicastLoop = bool(v)
- e.mu.Unlock()
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
-
case tcpip.BindToDeviceOption:
id := tcpip.NICID(v)
if id != 0 && !e.stack.HasNIC(id) {
@@ -702,26 +723,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Lock()
e.bindToDevice = id
e.mu.Unlock()
- return nil
-
- case tcpip.BroadcastOption:
- e.mu.Lock()
- e.broadcast = v != 0
- e.mu.Unlock()
-
- return nil
-
- case tcpip.IPv4TOSOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
- return nil
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
- return nil
}
return nil
}
@@ -729,6 +730,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
+ case tcpip.BroadcastOption:
+ e.mu.RLock()
+ v := e.broadcast
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.MulticastLoopOption:
+ e.mu.RLock()
+ v := e.multicastLoop
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.ReceiveTOSOption:
e.mu.RLock()
v := e.receiveTOS
@@ -746,6 +762,22 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.RLock()
+ v := e.receiveIPPacketInfo
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.ReuseAddressOption:
+ return false, nil
+
+ case tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ return v, nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -758,19 +790,32 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.RLock()
- v := e.receiveIPPacketInfo
- e.mu.RUnlock()
- return v, nil
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
}
-
- return false, tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
+ case tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ v := int(e.multicastTTL)
+ e.mu.Unlock()
+ return v, nil
+
case tcpip.ReceiveQueueSizeOption:
v := 0
e.rcvMu.Lock()
@@ -792,29 +837,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
v := e.rcvBufSizeMax
e.rcvMu.Unlock()
return v, nil
- }
- return -1, tcpip.ErrUnknownProtocolOption
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ v := int(e.ttl)
+ e.mu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
- return nil
-
- case *tcpip.TTLOption:
- e.mu.Lock()
- *o = tcpip.TTLOption(e.ttl)
- e.mu.Unlock()
- return nil
-
- case *tcpip.MulticastTTLOption:
- e.mu.Lock()
- *o = tcpip.MulticastTTLOption(e.multicastTTL)
- e.mu.Unlock()
- return nil
-
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
*o = tcpip.MulticastInterfaceOption{
@@ -822,72 +860,21 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.multicastAddr,
}
e.mu.Unlock()
- return nil
-
- case *tcpip.MulticastLoopOption:
- e.mu.RLock()
- v := e.multicastLoop
- e.mu.RUnlock()
-
- *o = tcpip.MulticastLoopOption(v)
- return nil
-
- case *tcpip.ReuseAddressOption:
- *o = 0
- return nil
-
- case *tcpip.ReusePortOption:
- e.mu.RLock()
- v := e.reusePort
- e.mu.RUnlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
case *tcpip.BindToDeviceOption:
e.mu.RLock()
*o = tcpip.BindToDeviceOption(e.bindToDevice)
e.mu.RUnlock()
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
- case *tcpip.BroadcastOption:
- e.mu.RLock()
- v := e.broadcast
- e.mu.RUnlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
- case *tcpip.IPv4TOSOption:
- e.mu.RLock()
- *o = tcpip.IPv4TOSOption(e.sendTOS)
- e.mu.RUnlock()
- return nil
-
- case *tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
- *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
- e.mu.RUnlock()
- return nil
default:
return tcpip.ErrUnknownProtocolOption
}
+ return nil
}
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -913,10 +900,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
if useDefaultTTL {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{
Header: hdr,
Data: data,
TransportHeader: buffer.View(udp),
+ Owner: owner,
}); err != nil {
r.Stats().UDP.PacketSendErrors.Increment()
return err
@@ -927,13 +915,14 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
return nil
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -981,10 +970,6 @@ func (e *endpoint) Disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- netProto, err := e.checkV4Mapped(&addr)
- if err != nil {
- return err
- }
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
@@ -1012,6 +997,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
r, nicID, err := e.connectRoute(nicID, addr, netProto)
if err != nil {
return err
@@ -1139,7 +1129,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -1258,7 +1248,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
// Get the header then trim it from the view.
hdr := header.UDP(pkt.Data.First())
if int(hdr.Length()) > pkt.Data.Size() {
@@ -1325,7 +1315,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
}
// State implements tcpip.Endpoint.State.
@@ -1355,3 +1345,7 @@ func (*endpoint) Wait() {}
func isBroadcastOrMulticast(a tcpip.Address) bool {
return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
}
+
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 43fb047ed..466bd9381 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -69,6 +69,9 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
e.stack = s
for _, m := range e.multicastMemberships {
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index fc706ede2..a674ceb68 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
route: r,
@@ -61,7 +61,7 @@ type ForwarderRequest struct {
stack *stack.Stack
route *stack.Route
id stack.TransportEndpointID
- pkt tcpip.PacketBuffer
+ pkt stack.PacketBuffer
}
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 8df089d22..6e31a9bac 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -66,7 +66,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
// Get the header then trim it from the view.
hdr := header.UDP(pkt.Data.First())
if int(hdr.Length()) > pkt.Data.Size() {
@@ -135,7 +135,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv4DstUnreachable)
pkt.SetCode(header.ICMPv4PortUnreachable)
pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload,
})
@@ -172,7 +172,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv6DstUnreachable)
pkt.SetCode(header.ICMPv6PortUnreachable)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload,
})
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 34b7c2360..8acaa607a 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -343,11 +343,11 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ c.t.Fatalf("SetSockOptBool failed: %s", err)
}
} else if flow.isBroadcast() {
- if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
- c.t.Fatal("SetSockOpt failed:", err)
+ if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ c.t.Fatalf("SetSockOptBool failed: %s", err)
}
}
}
@@ -358,7 +358,8 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
c.t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
c.t.Fatalf("Packet wasn't written out")
@@ -439,7 +440,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
@@ -486,7 +487,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
// Inject packet.
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
@@ -607,7 +608,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
// Check the peer address.
h := flow.header4Tuple(incoming)
if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr)
+ c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr)
}
// Check the payload.
@@ -1271,8 +1272,8 @@ func TestTTL(t *testing.T) {
c.createEndpointForFlow(flow)
const multicastTTL = 42
- if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil {
+ c.t.Fatalf("SetSockOptInt failed: %s", err)
}
var wantTTL uint8
@@ -1311,8 +1312,8 @@ func TestSetTTL(t *testing.T) {
c.createEndpointForFlow(flow)
- if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
+ c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
}
var p stack.NetworkProtocol
@@ -1346,25 +1347,26 @@ func TestSetTOS(t *testing.T) {
c.createEndpointForFlow(flow)
const tos = testTOS
- var v tcpip.IPv4TOSOption
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
+ v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
+ c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
}
- if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err)
+ if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
+ c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err)
}
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
+ v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
}
- if want := tcpip.IPv4TOSOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
+ if v != tos {
+ c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos)
}
testWrite(c, flow, checker.TOS(tos, 0))
@@ -1381,25 +1383,26 @@ func TestSetTClass(t *testing.T) {
c.createEndpointForFlow(flow)
const tClass = testTOS
- var v tcpip.IPv6TrafficClassOption
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
+ v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
+ c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0)
}
- if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil {
- c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err)
+ if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil {
+ c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err)
}
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
+ v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
}
- if want := tcpip.IPv6TrafficClassOption(tClass); v != want {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
+ if v != tClass {
+ c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass)
}
// The header getter for TClass is called TOS, so use that checker.
@@ -1430,7 +1433,7 @@ func TestReceiveTosTClass(t *testing.T) {
// Verify that setting and reading the option works.
v, err := c.ep.GetSockOptBool(option)
if err != nil {
- c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
}
// Test for expected default value.
if v != false {
@@ -1444,7 +1447,7 @@ func TestReceiveTosTClass(t *testing.T) {
got, err := c.ep.GetSockOptBool(option)
if err != nil {
- c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
}
if got != want {
@@ -1563,7 +1566,8 @@ func TestV4UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
}
@@ -1571,7 +1575,8 @@ func TestV4UnknownDestination(t *testing.T) {
}
// ICMP required.
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
t.Fatalf("packet wasn't written out")
@@ -1639,7 +1644,8 @@ func TestV6UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
}
@@ -1647,7 +1653,8 @@ func TestV6UnknownDestination(t *testing.T) {
}
// ICMP required.
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
t.Fatalf("packet wasn't written out")
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index d2f4403b0..cd6a0ea6b 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -29,9 +29,6 @@ import (
)
// IO provides access to the contents of a virtual memory space.
-//
-// FIXME(b/38173783): Implementations of IO cannot expect ctx to contain any
-// meaningful data.
type IO interface {
// CopyOut copies len(src) bytes from src to the memory mapped at addr. It
// returns the number of bytes copied. If the number of bytes copied is <
diff --git a/pkg/usermem/usermem_x86.go b/pkg/usermem/usermem_x86.go
index 8059b72d2..d96f829fb 100644
--- a/pkg/usermem/usermem_x86.go
+++ b/pkg/usermem/usermem_x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 i386
+// +build amd64 386
package usermem
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 26f68fe3d..5451f1eba 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -21,6 +21,7 @@ go_library(
"network.go",
"strace.go",
"user.go",
+ "vfs.go",
],
visibility = [
"//runsc:__subpackages__",
@@ -33,6 +34,7 @@ go_library(
"//pkg/control/server",
"//pkg/cpuid",
"//pkg/eventchannel",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/memutil",
"//pkg/rand",
@@ -40,6 +42,7 @@ go_library(
"//pkg/sentry/arch",
"//pkg/sentry/arch:registers_go_proto",
"//pkg/sentry/control",
+ "//pkg/sentry/devices/memdev",
"//pkg/sentry/fs",
"//pkg/sentry/fs/dev",
"//pkg/sentry/fs/gofer",
@@ -49,6 +52,12 @@ go_library(
"//pkg/sentry/fs/sys",
"//pkg/sentry/fs/tmpfs",
"//pkg/sentry/fs/tty",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/fsimpl/gofer",
+ "//pkg/sentry/fsimpl/host",
+ "//pkg/sentry/fsimpl/proc",
+ "//pkg/sentry/fsimpl/sys",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel:uncaught_signal_go_proto",
@@ -71,6 +80,7 @@ go_library(
"//pkg/sentry/time",
"//pkg/sentry/unimpl:unimplemented_syscall_go_proto",
"//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
"//pkg/sync",
"//pkg/syserror",
@@ -114,10 +124,12 @@ go_test(
"//pkg/p9",
"//pkg/sentry/contexttest",
"//pkg/sentry/fs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sync",
"//pkg/unet",
"//runsc/fsgofer",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go
index 8995d678e..b7cfb35bf 100644
--- a/runsc/boot/compat.go
+++ b/runsc/boot/compat.go
@@ -65,7 +65,7 @@ func newCompatEmitter(logFD int) (*compatEmitter, error) {
if logFD > 0 {
f := os.NewFile(uintptr(logFD), "user log file")
- target := &log.MultiEmitter{c.sink, &log.K8sJSONEmitter{log.Writer{Next: f}}}
+ target := &log.MultiEmitter{c.sink, log.K8sJSONEmitter{&log.Writer{Next: f}}}
c.sink = &log.BasicLogger{Level: log.Info, Emitter: target}
}
return c, nil
diff --git a/runsc/boot/config.go b/runsc/boot/config.go
index 35391030f..715a19112 100644
--- a/runsc/boot/config.go
+++ b/runsc/boot/config.go
@@ -158,6 +158,9 @@ type Config struct {
// DebugLog is the path to log debug information to, if not empty.
DebugLog string
+ // PanicLog is the path to log GO's runtime messages, if not empty.
+ PanicLog string
+
// DebugLogFormat is the log format for debug.
DebugLogFormat string
@@ -269,6 +272,7 @@ func (c *Config) ToFlags() []string {
"--log=" + c.LogFilename,
"--log-format=" + c.LogFormat,
"--debug-log=" + c.DebugLog,
+ "--panic-log=" + c.PanicLog,
"--debug-log-format=" + c.DebugLogFormat,
"--file-access=" + c.FileAccess.String(),
"--overlay=" + strconv.FormatBool(c.Overlay),
@@ -301,5 +305,10 @@ func (c *Config) ToFlags() []string {
if len(c.TestOnlyTestNameEnv) != 0 {
f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv)
}
+
+ if c.VFS2 {
+ f = append(f, "--vfs2=true")
+ }
+
return f
}
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 17e774e0c..8125d5061 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -101,11 +101,14 @@ const (
// Profiling related commands (see pprof.go for more details).
const (
- StartCPUProfile = "Profile.StartCPUProfile"
- StopCPUProfile = "Profile.StopCPUProfile"
- HeapProfile = "Profile.HeapProfile"
- StartTrace = "Profile.StartTrace"
- StopTrace = "Profile.StopTrace"
+ StartCPUProfile = "Profile.StartCPUProfile"
+ StopCPUProfile = "Profile.StopCPUProfile"
+ HeapProfile = "Profile.HeapProfile"
+ GoroutineProfile = "Profile.GoroutineProfile"
+ BlockProfile = "Profile.BlockProfile"
+ MutexProfile = "Profile.MutexProfile"
+ StartTrace = "Profile.StartTrace"
+ StopTrace = "Profile.StopTrace"
)
// Logging related commands (see logging.go for more details).
diff --git a/runsc/boot/fds.go b/runsc/boot/fds.go
index 417d2d5fb..7e49f6f9f 100644
--- a/runsc/boot/fds.go
+++ b/runsc/boot/fds.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
+ vfshost "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
@@ -31,10 +32,13 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F
return nil, fmt.Errorf("stdioFDs should contain exactly 3 FDs (stdin, stdout, and stderr), but %d FDs received", len(stdioFDs))
}
+ if kernel.VFS2Enabled {
+ return createFDTableVFS2(ctx, console, stdioFDs)
+ }
+
k := kernel.KernelFromContext(ctx)
fdTable := k.NewFDTable()
defer fdTable.DecRef()
- mounter := fs.FileOwnerFromContext(ctx)
var ttyFile *fs.File
for appFD, hostFD := range stdioFDs {
@@ -44,7 +48,7 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F
// Import the file as a host TTY file.
if ttyFile == nil {
var err error
- appFile, err = host.ImportFile(ctx, hostFD, mounter, true /* isTTY */)
+ appFile, err = host.ImportFile(ctx, hostFD, true /* isTTY */)
if err != nil {
return nil, err
}
@@ -63,7 +67,7 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F
} else {
// Import the file as a regular host file.
var err error
- appFile, err = host.ImportFile(ctx, hostFD, mounter, false /* isTTY */)
+ appFile, err = host.ImportFile(ctx, hostFD, false /* isTTY */)
if err != nil {
return nil, err
}
@@ -79,3 +83,31 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F
fdTable.IncRef()
return fdTable, nil
}
+
+func createFDTableVFS2(ctx context.Context, console bool, stdioFDs []int) (*kernel.FDTable, error) {
+ k := kernel.KernelFromContext(ctx)
+ fdTable := k.NewFDTable()
+ defer fdTable.DecRef()
+
+ hostMount, err := vfshost.NewMount(k.VFS())
+ if err != nil {
+ return nil, fmt.Errorf("creating host mount: %w", err)
+ }
+
+ for appFD, hostFD := range stdioFDs {
+ // TODO(gvisor.dev/issue/1482): Add TTY support.
+ appFile, err := vfshost.ImportFD(hostMount, hostFD, false)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := fdTable.NewFDAtVFS2(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil {
+ appFile.DecRef()
+ return nil, err
+ }
+ appFile.DecRef()
+ }
+
+ fdTable.IncRef()
+ return fdTable, nil
+}
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index a4627905e..1828d116a 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -44,7 +44,7 @@ var allowedSyscalls = seccomp.SyscallRules{
{
seccomp.AllowAny{},
seccomp.AllowAny{},
- seccomp.AllowValue(0),
+ seccomp.AllowValue(syscall.O_CLOEXEC),
},
},
syscall.SYS_EPOLL_CREATE1: {},
@@ -284,12 +284,21 @@ var allowedSyscalls = seccomp.SyscallRules{
{seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)},
},
syscall.SYS_SIGALTSTACK: {},
+ unix.SYS_STATX: {},
syscall.SYS_SYNC_FILE_RANGE: {},
syscall.SYS_TGKILL: []seccomp.Rule{
{
seccomp.AllowValue(uint64(os.Getpid())),
},
},
+ syscall.SYS_UTIMENSAT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0), /* null pathname */
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0), /* flags */
+ },
+ },
syscall.SYS_WRITE: {},
// The only user in rawfile.NonBlockingWrite3 always passes iovcnt with
// values 2 or 3. Three iovec-s are passed, when the PACKET_VNET_HDR
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 0f62842ea..98cce60af 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -278,6 +278,9 @@ func subtargets(root string, mnts []specs.Mount) []string {
}
func setupContainerFS(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error {
+ if conf.VFS2 {
+ return setupContainerVFS2(ctx, conf, mntr, procArgs)
+ }
mns, err := mntr.setupFS(conf, procArgs)
if err != nil {
return err
@@ -573,6 +576,9 @@ func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hin
// should be mounted (e.g. a volume shared between containers). It must be
// called for the root container only.
func (c *containerMounter) processHints(conf *Config) error {
+ if conf.VFS2 {
+ return nil
+ }
ctx := c.k.SupervisorContext()
for _, hint := range c.hints.mounts {
// TODO(b/142076984): Only support tmpfs for now. Bind mounts require a
@@ -781,9 +787,6 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (
useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
default:
- // TODO(nlacasse): Support all the mount types and make this a fatal error.
- // Most applications will "just work" without them, so this is a warning
- // for now.
log.Warningf("ignoring unknown filesystem type %q", m.Type)
}
return fsName, opts, useOverlay, nil
@@ -824,7 +827,20 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns
inode, err := filesystem.Mount(ctx, mountDevice(m), mf, strings.Join(opts, ","), nil)
if err != nil {
- return fmt.Errorf("creating mount with source %q: %v", m.Source, err)
+ err := fmt.Errorf("creating mount with source %q: %v", m.Source, err)
+ // Check to see if this is a common error due to a Linux bug.
+ // This error is generated here in order to cause it to be
+ // printed to the user using Docker via 'runsc create' etc. rather
+ // than simply printed to the logs for the 'runsc boot' command.
+ //
+ // We check the error message string rather than type because the
+ // actual error types (syscall.EIO, syscall.EPIPE) are lost by file system
+ // implementation (e.g. p9).
+ // TODO(gvisor.dev/issue/1765): Remove message when bug is resolved.
+ if strings.Contains(err.Error(), syscall.EIO.Error()) || strings.Contains(err.Error(), syscall.EPIPE.Error()) {
+ return fmt.Errorf("%v: %s", err, specutils.FaqErrorMsg("memlock", "you may be encountering a Linux kernel bug"))
+ }
+ return err
}
// If there are submounts, we need to overlay the mount on top of a ramfs
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index e7ca98134..cf1f47bc7 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -26,7 +26,6 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
@@ -73,6 +72,8 @@ import (
_ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
)
+var syscallTable *kernel.SyscallTable
+
// Loader keeps state needed to start the kernel and run the container..
type Loader struct {
// k is the kernel.
@@ -156,13 +157,17 @@ type Args struct {
Spec *specs.Spec
// Conf is the system configuration.
Conf *Config
- // ControllerFD is the FD to the URPC controller.
+ // ControllerFD is the FD to the URPC controller. The Loader takes ownership
+ // of this FD and may close it at any time.
ControllerFD int
- // Device is an optional argument that is passed to the platform.
+ // Device is an optional argument that is passed to the platform. The Loader
+ // takes ownership of this file and may close it at any time.
Device *os.File
- // GoferFDs is an array of FDs used to connect with the Gofer.
+ // GoferFDs is an array of FDs used to connect with the Gofer. The Loader
+ // takes ownership of these FDs and may close them at any time.
GoferFDs []int
- // StdioFDs is the stdio for the application.
+ // StdioFDs is the stdio for the application. The Loader takes ownership of
+ // these FDs and may close them at any time.
StdioFDs []int
// Console is set to true if using TTY.
Console bool
@@ -175,6 +180,9 @@ type Args struct {
UserLogFD int
}
+// make sure stdioFDs are always the same on initial start and on restore
+const startingStdioFD = 64
+
// New initializes a new kernel loader configured by spec.
// New also handles setting up a kernel for restoring a container.
func New(args Args) (*Loader, error) {
@@ -188,13 +196,14 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("setting up memory usage: %v", err)
}
- if args.Conf.VFS2 {
- st, ok := kernel.LookupSyscallTable(abi.Linux, arch.Host)
- if ok {
- vfs2.Override(st.Table)
- }
+ // Patch the syscall table.
+ kernel.VFS2Enabled = args.Conf.VFS2
+ if kernel.VFS2Enabled {
+ vfs2.Override(syscallTable.Table)
}
+ kernel.RegisterSyscallTable(syscallTable)
+
// Create kernel and platform.
p, err := createPlatform(args.Conf, args.Device)
if err != nil {
@@ -319,6 +328,24 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("creating pod mount hints: %v", err)
}
+ // Make host FDs stable between invocations. Host FDs must map to the exact
+ // same number when the sandbox is restored. Otherwise the wrong FD will be
+ // used.
+ var stdioFDs []int
+ newfd := startingStdioFD
+ for _, fd := range args.StdioFDs {
+ err := syscall.Dup3(fd, newfd, syscall.O_CLOEXEC)
+ if err != nil {
+ return nil, fmt.Errorf("dup3 of stdioFDs failed: %v", err)
+ }
+ stdioFDs = append(stdioFDs, newfd)
+ err = syscall.Close(fd)
+ if err != nil {
+ return nil, fmt.Errorf("close original stdioFDs failed: %v", err)
+ }
+ newfd++
+ }
+
eid := execID{cid: args.ID}
l := &Loader{
k: k,
@@ -327,7 +354,7 @@ func New(args Args) (*Loader, error) {
watchdog: dog,
spec: args.Spec,
goferFDs: args.GoferFDs,
- stdioFDs: args.StdioFDs,
+ stdioFDs: stdioFDs,
rootProcArgs: procArgs,
sandboxID: args.ID,
processes: map[execID]*execProcess{eid: {}},
@@ -367,11 +394,16 @@ func newProcess(id string, spec *specs.Spec, creds *auth.Credentials, k *kernel.
return kernel.CreateProcessArgs{}, fmt.Errorf("creating limits: %v", err)
}
+ wd := spec.Process.Cwd
+ if wd == "" {
+ wd = "/"
+ }
+
// Create the process arguments.
procArgs := kernel.CreateProcessArgs{
Argv: spec.Process.Args,
Envv: spec.Process.Env,
- WorkingDirectory: spec.Process.Cwd, // Defaults to '/' if empty.
+ WorkingDirectory: wd,
Credentials: creds,
Umask: 0022,
Limits: ls,
@@ -516,7 +548,15 @@ func (l *Loader) run() error {
}
// Add the HOME enviroment variable if it is not already set.
- envv, err := maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+ var envv []string
+ if kernel.VFS2Enabled {
+ envv, err = maybeAddExecUserHomeVFS2(ctx, l.rootProcArgs.MountNamespaceVFS2,
+ l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+
+ } else {
+ envv, err = maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace,
+ l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+ }
if err != nil {
return err
}
@@ -569,6 +609,16 @@ func (l *Loader) run() error {
}
})
+ // l.stdioFDs are derived from dup() in boot.New() and they are now dup()ed again
+ // either in createFDTable() during initial start or in descriptor.initAfterLoad()
+ // during restore, we can release l.stdioFDs now.
+ for _, fd := range l.stdioFDs {
+ err := syscall.Close(fd)
+ if err != nil {
+ return fmt.Errorf("close dup()ed stdioFDs: %v", err)
+ }
+ }
+
log.Infof("Process should have started...")
l.watchdog.Start()
return l.k.Start()
diff --git a/runsc/boot/loader_amd64.go b/runsc/boot/loader_amd64.go
index b9669f2ac..78df86611 100644
--- a/runsc/boot/loader_amd64.go
+++ b/runsc/boot/loader_amd64.go
@@ -17,11 +17,10 @@
package boot
import (
- "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
)
func init() {
- // Register the global syscall table.
- kernel.RegisterSyscallTable(linux.AMD64)
+ // Set the global syscall table.
+ syscallTable = linux.AMD64
}
diff --git a/runsc/boot/loader_arm64.go b/runsc/boot/loader_arm64.go
index cf64d28c8..250785010 100644
--- a/runsc/boot/loader_arm64.go
+++ b/runsc/boot/loader_arm64.go
@@ -17,11 +17,10 @@
package boot
import (
- "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
)
func init() {
- // Register the global syscall table.
- kernel.RegisterSyscallTable(linux.ARM64)
+ // Set the global syscall table.
+ syscallTable = linux.ARM64
}
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index 44aa63196..e7c71734f 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -24,11 +24,13 @@ import (
"time"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/control/server"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/runsc/fsgofer"
@@ -65,6 +67,11 @@ func testSpec() *specs.Spec {
}
}
+func resetSyscallTable() {
+ kernel.VFS2Enabled = false
+ kernel.FlushSyscallTablesTestOnly()
+}
+
// startGofer starts a new gofer routine serving 'root' path. It returns the
// sandbox side of the connection, and a function that when called will stop the
// gofer.
@@ -100,7 +107,7 @@ func startGofer(root string) (int, func(), error) {
return sandboxEnd, cleanup, nil
}
-func createLoader() (*Loader, func(), error) {
+func createLoader(vfsEnabled bool) (*Loader, func(), error) {
fd, err := server.CreateSocket(ControlSocketAddr(fmt.Sprintf("%010d", rand.Int())[:10]))
if err != nil {
return nil, nil, err
@@ -108,12 +115,23 @@ func createLoader() (*Loader, func(), error) {
conf := testConfig()
spec := testSpec()
+ conf.VFS2 = vfsEnabled
+
sandEnd, cleanup, err := startGofer(spec.Root.Path)
if err != nil {
return nil, nil, err
}
- stdio := []int{int(os.Stdin.Fd()), int(os.Stdout.Fd()), int(os.Stderr.Fd())}
+ // Loader takes ownership of stdio.
+ var stdio []int
+ for _, f := range []*os.File{os.Stdin, os.Stdout, os.Stderr} {
+ newFd, err := unix.Dup(int(f.Fd()))
+ if err != nil {
+ return nil, nil, err
+ }
+ stdio = append(stdio, newFd)
+ }
+
args := Args{
ID: "foo",
Spec: spec,
@@ -132,10 +150,22 @@ func createLoader() (*Loader, func(), error) {
// TestRun runs a simple application in a sandbox and checks that it succeeds.
func TestRun(t *testing.T) {
- l, cleanup, err := createLoader()
+ defer resetSyscallTable()
+ doRun(t, false)
+}
+
+// TestRunVFS2 runs TestRun in VFSv2.
+func TestRunVFS2(t *testing.T) {
+ defer resetSyscallTable()
+ doRun(t, true)
+}
+
+func doRun(t *testing.T, vfsEnabled bool) {
+ l, cleanup, err := createLoader(vfsEnabled)
if err != nil {
t.Fatalf("error creating loader: %v", err)
}
+
defer l.Destroy()
defer cleanup()
@@ -169,7 +199,18 @@ func TestRun(t *testing.T) {
// TestStartSignal tests that the controller Start message will cause
// WaitForStartSignal to return.
func TestStartSignal(t *testing.T) {
- l, cleanup, err := createLoader()
+ defer resetSyscallTable()
+ doStartSignal(t, false)
+}
+
+// TestStartSignalVFS2 does TestStartSignal with VFS2.
+func TestStartSignalVFS2(t *testing.T) {
+ defer resetSyscallTable()
+ doStartSignal(t, true)
+}
+
+func doStartSignal(t *testing.T, vfsEnabled bool) {
+ l, cleanup, err := createLoader(vfsEnabled)
if err != nil {
t.Fatalf("error creating loader: %v", err)
}
diff --git a/runsc/boot/user.go b/runsc/boot/user.go
index f0aa52135..332e4fce5 100644
--- a/runsc/boot/user.go
+++ b/runsc/boot/user.go
@@ -23,8 +23,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -84,6 +86,48 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K
File: f,
}
+ return findHomeInPasswd(uint32(uid), r, defaultHome)
+}
+
+type fileReaderVFS2 struct {
+ ctx context.Context
+ fd *vfs.FileDescription
+}
+
+func (r *fileReaderVFS2) Read(buf []byte) (int, error) {
+ n, err := r.fd.Read(r.ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ return int(n), err
+}
+
+func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth.KUID) (string, error) {
+ const defaultHome = "/"
+
+ root := mns.Root()
+ defer root.DecRef()
+
+ creds := auth.CredentialsFromContext(ctx)
+
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse("/etc/passwd"),
+ }
+
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }
+
+ fd, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, target, opts)
+ if err != nil {
+ return defaultHome, nil
+ }
+ defer fd.DecRef()
+
+ r := &fileReaderVFS2{
+ ctx: ctx,
+ fd: fd,
+ }
+
homeDir, err := findHomeInPasswd(uint32(uid), r, defaultHome)
if err != nil {
return "", err
@@ -111,6 +155,26 @@ func maybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.
if err != nil {
return nil, fmt.Errorf("error reading exec user: %v", err)
}
+
+ return append(envv, "HOME="+homeDir), nil
+}
+
+func maybeAddExecUserHomeVFS2(ctx context.Context, vmns *vfs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+ // Check if the envv already contains HOME.
+ for _, env := range envv {
+ if strings.HasPrefix(env, "HOME=") {
+ // We have it. Return the original slice unmodified.
+ return envv, nil
+ }
+ }
+
+ // Read /etc/passwd for the user's HOME directory and set the HOME
+ // environment variable as required by POSIX if it is not overridden by
+ // the user.
+ homeDir, err := getExecUserHomeVFS2(ctx, vmns, uid)
+ if err != nil {
+ return nil, fmt.Errorf("error reading exec user: %v", err)
+ }
return append(envv, "HOME="+homeDir), nil
}
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
new file mode 100644
index 000000000..82083c57d
--- /dev/null
+++ b/runsc/boot/vfs.go
@@ -0,0 +1,310 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+ "path"
+ "strconv"
+ "strings"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/devices/memdev"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ devtmpfsimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ goferimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer"
+ procimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc"
+ sysimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys"
+ tmpfsimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func registerFilesystems(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) error {
+
+ vfsObj.MustRegisterFilesystemType(rootFsName, &goferimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ })
+
+ vfsObj.MustRegisterFilesystemType(bind, &goferimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ })
+
+ vfsObj.MustRegisterFilesystemType(devpts, &devtmpfsimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+
+ vfsObj.MustRegisterFilesystemType(devtmpfs, &devtmpfsimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(proc, &procimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(sysfs, &sysimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(tmpfs, &tmpfsimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(nonefs, &sysimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+
+ // Setup files in devtmpfs.
+ if err := memdev.Register(vfsObj); err != nil {
+ return fmt.Errorf("registering memdev: %w", err)
+ }
+ a, err := devtmpfsimpl.NewAccessor(ctx, vfsObj, creds, devtmpfsimpl.Name)
+ if err != nil {
+ return fmt.Errorf("creating devtmpfs accessor: %w", err)
+ }
+ defer a.Release()
+
+ if err := a.UserspaceInit(ctx); err != nil {
+ return fmt.Errorf("initializing userspace: %w", err)
+ }
+ if err := memdev.CreateDevtmpfsFiles(ctx, a); err != nil {
+ return fmt.Errorf("creating devtmpfs files: %w", err)
+ }
+ return nil
+}
+
+func setupContainerVFS2(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error {
+ if err := mntr.k.VFS().Init(); err != nil {
+ return fmt.Errorf("failed to initialize VFS: %w", err)
+ }
+ mns, err := mntr.setupVFS2(ctx, conf, procArgs)
+ if err != nil {
+ return fmt.Errorf("failed to setupFS: %w", err)
+ }
+ procArgs.MountNamespaceVFS2 = mns
+ return setExecutablePathVFS2(ctx, procArgs)
+}
+
+func setExecutablePathVFS2(ctx context.Context, procArgs *kernel.CreateProcessArgs) error {
+
+ exe := procArgs.Argv[0]
+
+ // Absolute paths can be used directly.
+ if path.IsAbs(exe) {
+ procArgs.Filename = exe
+ return nil
+ }
+
+ // Paths with '/' in them should be joined to the working directory, or
+ // to the root if working directory is not set.
+ if strings.IndexByte(exe, '/') > 0 {
+
+ if !path.IsAbs(procArgs.WorkingDirectory) {
+ return fmt.Errorf("working directory %q must be absolute", procArgs.WorkingDirectory)
+ }
+
+ procArgs.Filename = path.Join(procArgs.WorkingDirectory, exe)
+ return nil
+ }
+
+ // Paths with a '/' are relative to the CWD.
+ if strings.IndexByte(exe, '/') > 0 {
+ procArgs.Filename = path.Join(procArgs.WorkingDirectory, exe)
+ return nil
+ }
+
+ // Otherwise, We must lookup the name in the paths, starting from the
+ // root directory.
+ root := procArgs.MountNamespaceVFS2.Root()
+ defer root.DecRef()
+
+ paths := fs.GetPath(procArgs.Envv)
+ creds := procArgs.Credentials
+
+ for _, p := range paths {
+
+ binPath := path.Join(p, exe)
+
+ pop := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(binPath),
+ FollowFinalSymlink: true,
+ }
+
+ opts := &vfs.OpenOptions{
+ FileExec: true,
+ Flags: linux.O_RDONLY,
+ }
+
+ dentry, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, pop, opts)
+ if err == syserror.ENOENT || err == syserror.EACCES {
+ // Didn't find it here.
+ continue
+ }
+ if err != nil {
+ return err
+ }
+ dentry.DecRef()
+
+ procArgs.Filename = binPath
+ return nil
+ }
+
+ return fmt.Errorf("executable %q not found in $PATH=%q", exe, strings.Join(paths, ":"))
+}
+
+func (c *containerMounter) setupVFS2(ctx context.Context, conf *Config, procArgs *kernel.CreateProcessArgs) (*vfs.MountNamespace, error) {
+ log.Infof("Configuring container's file system with VFS2")
+
+ // Create context with root credentials to mount the filesystem (the current
+ // user may not be privileged enough).
+ rootProcArgs := *procArgs
+ rootProcArgs.WorkingDirectory = "/"
+ rootProcArgs.Credentials = auth.NewRootCredentials(procArgs.Credentials.UserNamespace)
+ rootProcArgs.Umask = 0022
+ rootProcArgs.MaxSymlinkTraversals = linux.MaxSymlinkTraversals
+ rootCtx := procArgs.NewContext(c.k)
+
+ creds := procArgs.Credentials
+ if err := registerFilesystems(rootCtx, c.k.VFS(), creds); err != nil {
+ return nil, fmt.Errorf("register filesystems: %w", err)
+ }
+
+ fd := c.fds.remove()
+
+ opts := strings.Join(p9MountOptionsVFS2(fd, conf.FileAccess), ",")
+
+ log.Infof("Mounting root over 9P, ioFD: %d", fd)
+ mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", rootFsName, &vfs.GetFilesystemOptions{Data: opts})
+ if err != nil {
+ return nil, fmt.Errorf("setting up mountnamespace: %w", err)
+ }
+
+ rootProcArgs.MountNamespaceVFS2 = mns
+
+ // Mount submounts.
+ if err := c.mountSubmountsVFS2(rootCtx, conf, mns, creds); err != nil {
+ return nil, fmt.Errorf("mounting submounts vfs2: %w", err)
+ }
+
+ return mns, nil
+}
+
+func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials) error {
+
+ for _, submount := range c.mounts {
+ log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.Source, submount.Destination, submount.Type, submount.Options)
+ if err := c.mountSubmountVFS2(ctx, conf, mns, creds, &submount); err != nil {
+ return err
+ }
+ }
+
+ // TODO(gvisor.dev/issue/1487): implement mountTmp from fs.go.
+
+ return c.checkDispenser()
+}
+
+// TODO(gvisor.dev/issue/1487): Implement submount options similar to the VFS1 version.
+func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *specs.Mount) error {
+ root := mns.Root()
+ defer root.DecRef()
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(submount.Destination),
+ }
+
+ _, options, useOverlay, err := c.getMountNameAndOptionsVFS2(conf, *submount)
+ if err != nil {
+ return fmt.Errorf("mountOptions failed: %w", err)
+ }
+
+ opts := &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ Data: strings.Join(options, ","),
+ },
+ InternalMount: true,
+ }
+
+ // All writes go to upper, be paranoid and make lower readonly.
+ opts.ReadOnly = useOverlay
+
+ if err := c.k.VFS().MountAt(ctx, creds, "", target, submount.Type, opts); err != nil {
+ return fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.Destination, submount.Type, err, opts)
+ }
+ log.Infof("Mounted %q to %q type: %s, internal-options: %q", submount.Source, submount.Destination, submount.Type, opts)
+ return nil
+}
+
+// getMountNameAndOptionsVFS2 retrieves the fsName, opts, and useOverlay values
+// used for mounts.
+func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m specs.Mount) (string, []string, bool, error) {
+ var (
+ fsName string
+ opts []string
+ useOverlay bool
+ )
+
+ switch m.Type {
+ case devpts, devtmpfs, proc, sysfs:
+ fsName = m.Type
+ case nonefs:
+ fsName = sysfs
+ case tmpfs:
+ fsName = m.Type
+
+ var err error
+ opts, err = parseAndFilterOptions(m.Options, tmpfsAllowedOptions...)
+ if err != nil {
+ return "", nil, false, err
+ }
+
+ case bind:
+ fd := c.fds.remove()
+ fsName = "9p"
+ opts = p9MountOptionsVFS2(fd, c.getMountAccessType(m))
+ // If configured, add overlay to all writable mounts.
+ useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
+
+ default:
+ log.Warningf("ignoring unknown filesystem type %q", m.Type)
+ }
+ return fsName, opts, useOverlay, nil
+}
+
+// p9MountOptions creates a slice of options for a p9 mount.
+// TODO(gvisor.dev/issue/1200): Remove this version in favor of the one in
+// fs.go when privateunixsocket lands.
+func p9MountOptionsVFS2(fd int, fa FileAccessType) []string {
+ opts := []string{
+ "trans=fd",
+ "rfdno=" + strconv.Itoa(fd),
+ "wfdno=" + strconv.Itoa(fd),
+ }
+ if fa == FileAccessShared {
+ opts = append(opts, "cache=remote_revalidating")
+ }
+ return opts
+}
diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go
index 0f3da69a0..0938944a6 100644
--- a/runsc/cmd/boot.go
+++ b/runsc/cmd/boot.go
@@ -23,6 +23,7 @@ import (
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/boot/platforms"
@@ -82,8 +83,13 @@ type Boot struct {
// sandbox (e.g. gofer) and sent through this FD.
mountsFD int
- // pidns is set if the sanadbox is in its own pid namespace.
+ // pidns is set if the sandbox is in its own pid namespace.
pidns bool
+
+ // attached is set to true to kill the sandbox process when the parent process
+ // terminates. This flag is set when the command execve's itself because
+ // parent death signal doesn't propagate through execve when uid/gid changes.
+ attached bool
}
// Name implements subcommands.Command.Name.
@@ -118,6 +124,7 @@ func (b *Boot) SetFlags(f *flag.FlagSet) {
f.IntVar(&b.userLogFD, "user-log-fd", 0, "file descriptor to write user logs to. 0 means no logging.")
f.IntVar(&b.startSyncFD, "start-sync-fd", -1, "required FD to used to synchronize sandbox startup")
f.IntVar(&b.mountsFD, "mounts-fd", -1, "mountsFD is the file descriptor to read list of mounts after they have been resolved (direct paths, no symlinks).")
+ f.BoolVar(&b.attached, "attached", false, "if attached is true, kills the sandbox process when the parent process terminates")
}
// Execute implements subcommands.Command.Execute. It starts a sandbox in a
@@ -133,29 +140,32 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
conf := args[0].(*boot.Config)
+ if b.attached {
+ // Ensure this process is killed after parent process terminates when
+ // attached mode is enabled. In the unfortunate event that the parent
+ // terminates before this point, this process leaks.
+ if err := unix.Prctl(unix.PR_SET_PDEATHSIG, uintptr(unix.SIGKILL), 0, 0, 0); err != nil {
+ Fatalf("error setting parent death signal: %v", err)
+ }
+ }
+
if b.setUpRoot {
if err := setUpChroot(b.pidns); err != nil {
Fatalf("error setting up chroot: %v", err)
}
- if !b.applyCaps {
- // Remove --setup-root arg to call myself.
- var args []string
- for _, arg := range os.Args {
- if !strings.Contains(arg, "setup-root") {
- args = append(args, arg)
- }
- }
- if !conf.Rootless {
- // Note that we've already read the spec from the spec FD, and
- // we will read it again after the exec call. This works
- // because the ReadSpecFromFile function seeks to the beginning
- // of the file before reading.
- if err := callSelfAsNobody(args); err != nil {
- Fatalf("%v", err)
- }
- panic("callSelfAsNobody must never return success")
+ if !b.applyCaps && !conf.Rootless {
+ // Remove --apply-caps arg to call myself. It has already been done.
+ args := prepareArgs(b.attached, "setup-root")
+
+ // Note that we've already read the spec from the spec FD, and
+ // we will read it again after the exec call. This works
+ // because the ReadSpecFromFile function seeks to the beginning
+ // of the file before reading.
+ if err := callSelfAsNobody(args); err != nil {
+ Fatalf("%v", err)
}
+ panic("callSelfAsNobody must never return success")
}
}
@@ -181,13 +191,9 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
caps.Permitted = append(caps.Permitted, c)
}
- // Remove --apply-caps arg to call myself.
- var args []string
- for _, arg := range os.Args {
- if !strings.Contains(arg, "setup-root") && !strings.Contains(arg, "apply-caps") {
- args = append(args, arg)
- }
- }
+ // Remove --apply-caps and --setup-root arg to call myself. Both have
+ // already been done.
+ args := prepareArgs(b.attached, "setup-root", "apply-caps")
// Note that we've already read the spec from the spec FD, and
// we will read it again after the exec call. This works
@@ -258,3 +264,22 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
l.Destroy()
return subcommands.ExitSuccess
}
+
+func prepareArgs(attached bool, exclude ...string) []string {
+ var args []string
+ for _, arg := range os.Args {
+ for _, excl := range exclude {
+ if strings.Contains(arg, excl) {
+ goto skip
+ }
+ }
+ args = append(args, arg)
+ if attached && arg == "boot" {
+ // Strategicaly place "--attached" after the command. This is needed
+ // to ensure the new process is killed when the parent process terminates.
+ args = append(args, "--attached")
+ }
+ skip:
+ }
+ return args
+}
diff --git a/runsc/cmd/chroot.go b/runsc/cmd/chroot.go
index b5a0ce17d..189244765 100644
--- a/runsc/cmd/chroot.go
+++ b/runsc/cmd/chroot.go
@@ -50,7 +50,7 @@ func pivotRoot(root string) error {
// new_root, so after umounting the old_root, we will see only
// the new_root in "/".
if err := syscall.PivotRoot(".", "."); err != nil {
- return fmt.Errorf("error changing root filesystem: %v", err)
+ return fmt.Errorf("pivot_root failed, make sure that the root mount has a parent: %v", err)
}
if err := syscall.Unmount(".", syscall.MNT_DETACH); err != nil {
diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go
index 79965460e..b5de2588b 100644
--- a/runsc/cmd/debug.go
+++ b/runsc/cmd/debug.go
@@ -32,17 +32,20 @@ import (
// Debug implements subcommands.Command for the "debug" command.
type Debug struct {
- pid int
- stacks bool
- signal int
- profileHeap string
- profileCPU string
- trace string
- strace string
- logLevel string
- logPackets string
- duration time.Duration
- ps bool
+ pid int
+ stacks bool
+ signal int
+ profileHeap string
+ profileCPU string
+ profileGoroutine string
+ profileBlock string
+ profileMutex string
+ trace string
+ strace string
+ logLevel string
+ logPackets string
+ duration time.Duration
+ ps bool
}
// Name implements subcommands.Command.
@@ -66,6 +69,9 @@ func (d *Debug) SetFlags(f *flag.FlagSet) {
f.BoolVar(&d.stacks, "stacks", false, "if true, dumps all sandbox stacks to the log")
f.StringVar(&d.profileHeap, "profile-heap", "", "writes heap profile to the given file.")
f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.")
+ f.StringVar(&d.profileGoroutine, "profile-goroutine", "", "writes goroutine profile to the given file.")
+ f.StringVar(&d.profileBlock, "profile-block", "", "writes block profile to the given file.")
+ f.StringVar(&d.profileMutex, "profile-mutex", "", "writes mutex profile to the given file.")
f.DurationVar(&d.duration, "duration", time.Second, "amount of time to wait for CPU and trace profiles")
f.StringVar(&d.trace, "trace", "", "writes an execution trace to the given file.")
f.IntVar(&d.signal, "signal", -1, "sends signal to the sandbox")
@@ -147,6 +153,42 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
log.Infof("Heap profile written to %q", d.profileHeap)
}
+ if d.profileGoroutine != "" {
+ f, err := os.Create(d.profileGoroutine)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.GoroutineProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Goroutine profile written to %q", d.profileGoroutine)
+ }
+ if d.profileBlock != "" {
+ f, err := os.Create(d.profileBlock)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.BlockProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Block profile written to %q", d.profileBlock)
+ }
+ if d.profileMutex != "" {
+ f, err := os.Create(d.profileMutex)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.MutexProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Mutex profile written to %q", d.profileMutex)
+ }
delay := false
if d.profileCPU != "" {
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 6e06f3c0f..28f0d54b9 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -272,9 +272,8 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
root := spec.Root.Path
if !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
- // FIXME: runsc can't be re-executed without
- // /proc, so we create a tmpfs mount, mount ./proc and ./root
- // there, then move this mount to the root and after
+ // runsc can't be re-executed without /proc, so we create a tmpfs mount,
+ // mount ./proc and ./root there, then move this mount to the root and after
// setCapsAndCallSelf, runsc will chroot into /root.
//
// We need a directory to construct a new root and we know that
@@ -335,7 +334,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
if !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
if err := pivotRoot("/proc"); err != nil {
- Fatalf("faild to change the root file system: %v", err)
+ Fatalf("failed to change the root file system: %v", err)
}
if err := os.Chdir("/"); err != nil {
Fatalf("failed to change working directory")
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index c2518d52b..651615d4c 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -333,13 +333,13 @@ func TestJobControlSignalRootContainer(t *testing.T) {
// file. Writes after a certain point will block unless we drain the
// PTY, so we must continually copy from it.
//
- // We log the output to stdout for debugabilitly, and also to a buffer,
+ // We log the output to stderr for debugabilitly, and also to a buffer,
// since we wait on particular output from bash below. We use a custom
// blockingBuffer which is thread-safe and also blocks on Read calls,
// which makes this a suitable Reader for WaitUntilRead.
ptyBuf := newBlockingBuffer()
tee := io.TeeReader(ptyMaster, ptyBuf)
- go io.Copy(os.Stdout, tee)
+ go io.Copy(os.Stderr, tee)
// Start the container.
if err := c.Start(conf); err != nil {
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 68782c4be..7233659b1 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -17,6 +17,7 @@ package container
import (
"context"
+ "errors"
"fmt"
"io/ioutil"
"os"
@@ -1066,27 +1067,19 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error {
// adjustGoferOOMScoreAdj sets the oom_store_adj for the container's gofer.
func (c *Container) adjustGoferOOMScoreAdj() error {
- if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil {
- if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil {
- // Ignore NotExist error because it can be returned when the sandbox
- // exited while OOM score was being adjusted.
- if !os.IsNotExist(err) {
- return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err)
- }
- log.Warningf("Gofer process (%d) not found setting oom_score_adj", c.GoferPid)
- }
+ if c.GoferPid == 0 || c.Spec.Process.OOMScoreAdj == nil {
+ return nil
}
-
- return nil
+ return setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj)
}
// adjustSandboxOOMScoreAdj sets the oom_score_adj for the sandbox.
// oom_score_adj is set to the lowest oom_score_adj among the containers
// running in the sandbox.
//
-// TODO(gvisor.dev/issue/512): This call could race with other containers being
+// TODO(gvisor.dev/issue/238): This call could race with other containers being
// created at the same time and end up setting the wrong oom_score_adj to the
-// sandbox.
+// sandbox. Use rpc client to synchronize.
func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) error {
containers, err := loadSandbox(rootDir, s.ID)
if err != nil {
@@ -1154,29 +1147,29 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool)
}
// Set the lowest of all containers oom_score_adj to the sandbox.
- if err := setOOMScoreAdj(s.Pid, lowScore); err != nil {
- // Ignore NotExist error because it can be returned when the sandbox
- // exited while OOM score was being adjusted.
- if !os.IsNotExist(err) {
- return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err)
- }
- log.Warningf("Sandbox process (%d) not found setting oom_score_adj", s.Pid)
- }
-
- return nil
+ return setOOMScoreAdj(s.Pid, lowScore)
}
// setOOMScoreAdj sets oom_score_adj to the given value for the given PID.
// /proc must be available and mounted read-write. scoreAdj should be between
-// -1000 and 1000.
+// -1000 and 1000. It's a noop if the process has already exited.
func setOOMScoreAdj(pid int, scoreAdj int) error {
f, err := os.OpenFile(fmt.Sprintf("/proc/%d/oom_score_adj", pid), os.O_WRONLY, 0644)
if err != nil {
+ // Ignore NotExist errors because it can race with process exit.
+ if os.IsNotExist(err) {
+ log.Warningf("Process (%d) not found setting oom_score_adj", pid)
+ return nil
+ }
return err
}
defer f.Close()
if _, err := f.WriteString(strconv.Itoa(scoreAdj)); err != nil {
- return err
+ if errors.Is(err, syscall.ESRCH) {
+ log.Warningf("Process (%d) exited while setting oom_score_adj", pid)
+ return nil
+ }
+ return fmt.Errorf("setting oom_score_adj to %q: %v", scoreAdj, err)
}
return nil
}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index bdd65b498..24f9ecc35 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -124,23 +124,6 @@ func procListsEqual(got, want []*control.Process) (bool, error) {
return true, nil
}
-// getAndCheckProcLists is similar to waitForProcessList, but does not wait and retry the
-// test for equality. This is because we already confirmed that exec occurred.
-func getAndCheckProcLists(cont *Container, want []*control.Process) error {
- got, err := cont.Processes()
- if err != nil {
- return fmt.Errorf("error getting process data from container: %v", err)
- }
- equal, err := procListsEqual(got, want)
- if err != nil {
- return err
- }
- if equal {
- return nil
- }
- return fmt.Errorf("container got process list: %s, want: %s", procListToString(got), procListToString(want))
-}
-
func procListToString(pl []*control.Process) string {
strs := make([]string, 0, len(pl))
for _, p := range pl {
@@ -538,9 +521,21 @@ func TestExePath(t *testing.T) {
// Test the we can retrieve the application exit status from the container.
func TestAppExitStatus(t *testing.T) {
+ conf := testutil.TestConfig()
+ conf.VFS2 = false
+ doAppExitStatus(t, conf)
+}
+
+// This is TestAppExitStatus for VFSv2.
+func TestAppExitStatusVFS2(t *testing.T) {
+ conf := testutil.TestConfig()
+ conf.VFS2 = true
+ doAppExitStatus(t, conf)
+}
+
+func doAppExitStatus(t *testing.T, conf *boot.Config) {
// First container will succeed.
succSpec := testutil.NewSpecWithArgs("true")
- conf := testutil.TestConfig()
rootDir, bundleDir, err := testutil.SetupContainer(succSpec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -2092,7 +2087,7 @@ func TestOverlayfsStaleRead(t *testing.T) {
defer out.Close()
const want = "foobar"
- cmd := fmt.Sprintf("cat %q && echo %q> %q && cp %q %q", in.Name(), want, in.Name(), in.Name(), out.Name())
+ cmd := fmt.Sprintf("cat %q >&2 && echo %q> %q && cp %q %q", in.Name(), want, in.Name(), in.Name(), out.Name())
spec := testutil.NewSpecWithArgs("/bin/bash", "-c", cmd)
if err := run(spec, conf); err != nil {
t.Fatalf("Error running container: %v", err)
diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go
index 01c47c79f..5f1c4b7d6 100644
--- a/runsc/container/test_app/test_app.go
+++ b/runsc/container/test_app/test_app.go
@@ -96,7 +96,7 @@ func (c *uds) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{})
listener, err := net.Listen("unix", c.socketPath)
if err != nil {
- log.Fatal("error listening on socket %q:", c.socketPath, err)
+ log.Fatalf("error listening on socket %q: %v", c.socketPath, err)
}
go server(listener, outputFile)
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index cadd83273..1942f50d7 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -767,22 +767,18 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
return err
}
-// TODO(b/127675828): support getxattr.
func (*localFile) GetXattr(string, uint64) (string, error) {
return "", syscall.EOPNOTSUPP
}
-// TODO(b/127675828): support setxattr.
func (*localFile) SetXattr(string, string, uint32) error {
return syscall.EOPNOTSUPP
}
-// TODO(b/148303075): support listxattr.
func (*localFile) ListXattr(uint64) (map[string]struct{}, error) {
return nil, syscall.EOPNOTSUPP
}
-// TODO(b/148303075): support removexattr.
func (*localFile) RemoveXattr(string) error {
return syscall.EOPNOTSUPP
}
diff --git a/runsc/main.go b/runsc/main.go
index af73bed97..9d52f3006 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -54,9 +54,11 @@ var (
// Debugging flags.
debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.")
+ panicLog = flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.")
logPackets = flag.Bool("log-packets", false, "enable network packet logging.")
logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.")
debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.")
+ panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.")
debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.")
alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr.")
@@ -82,6 +84,7 @@ var (
rootless = flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.")
referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.")
cpuNumFromQuota = flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)")
+ vfs2Enabled = flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.")
// Test flags, not to be used outside tests, ever.
testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.")
@@ -206,6 +209,7 @@ func main() {
LogFilename: *logFilename,
LogFormat: *logFormat,
DebugLog: *debugLog,
+ PanicLog: *panicLog,
DebugLogFormat: *debugLogFormat,
FileAccess: fsAccess,
FSGoferHostUDS: *fsGoferHostUDS,
@@ -227,6 +231,7 @@ func main() {
ReferenceLeakMode: refsLeakMode,
OverlayfsStaleRead: *overlayfsStaleRead,
CPUNumFromQuota: *cpuNumFromQuota,
+ VFS2: *vfs2Enabled,
TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot,
TestOnlyTestNameEnv: *testOnlyTestNameEnv,
@@ -258,20 +263,6 @@ func main() {
if *debugLogFD > -1 {
f := os.NewFile(uintptr(*debugLogFD), "debug log file")
- // Quick sanity check to make sure no other commands get passed
- // a log fd (they should use log dir instead).
- if subcommand != "boot" && subcommand != "gofer" {
- cmd.Fatalf("flag --debug-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand)
- }
-
- // If we are the boot process, then we own our stdio FDs and can do what we
- // want with them. Since Docker and Containerd both eat boot's stderr, we
- // dup our stderr to the provided log FD so that panics will appear in the
- // logs, rather than just disappear.
- if err := syscall.Dup3(int(f.Fd()), int(os.Stderr.Fd()), 0); err != nil {
- cmd.Fatalf("error dup'ing fd %d to stderr: %v", f.Fd(), err)
- }
-
e = newEmitter(*debugLogFormat, f)
} else if *debugLog != "" {
@@ -287,6 +278,26 @@ func main() {
e = newEmitter("text", ioutil.Discard)
}
+ if *panicLogFD > -1 || *debugLogFD > -1 {
+ fd := *panicLogFD
+ if fd < 0 {
+ fd = *debugLogFD
+ }
+ // Quick sanity check to make sure no other commands get passed
+ // a log fd (they should use log dir instead).
+ if subcommand != "boot" && subcommand != "gofer" {
+ cmd.Fatalf("flags --debug-log-fd and --panic-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand)
+ }
+
+ // If we are the boot process, then we own our stdio FDs and can do what we
+ // want with them. Since Docker and Containerd both eat boot's stderr, we
+ // dup our stderr to the provided log FD so that panics will appear in the
+ // logs, rather than just disappear.
+ if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil {
+ cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err)
+ }
+ }
+
if *alsoLogToStderr {
e = &log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)}
}
@@ -304,6 +315,7 @@ func main() {
log.Infof("\t\tFileAccess: %v, overlay: %t", conf.FileAccess, conf.Overlay)
log.Infof("\t\tNetwork: %v, logging: %t", conf.Network, conf.LogPackets)
log.Infof("\t\tStrace: %t, max size: %d, syscalls: %s", conf.Strace, conf.StraceLogSize, conf.StraceSyscalls)
+ log.Infof("\t\tVFS2 enabled: %v", conf.VFS2)
log.Infof("***************************")
if *testOnlyAllowRunAsCurrentUserWithoutChroot {
@@ -333,11 +345,11 @@ func main() {
func newEmitter(format string, logFile io.Writer) log.Emitter {
switch format {
case "text":
- return &log.GoogleEmitter{log.Writer{Next: logFile}}
+ return log.GoogleEmitter{&log.Writer{Next: logFile}}
case "json":
- return &log.JSONEmitter{log.Writer{Next: logFile}}
+ return log.JSONEmitter{&log.Writer{Next: logFile}}
case "json-k8s":
- return &log.K8sJSONEmitter{log.Writer{Next: logFile}}
+ return log.K8sJSONEmitter{&log.Writer{Next: logFile}}
}
cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format)
panic("unreachable")
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index ec72bdbfd..e82bcef6f 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -18,10 +18,12 @@ package sandbox
import (
"context"
"fmt"
+ "io"
"math"
"os"
"os/exec"
"strconv"
+ "strings"
"syscall"
"time"
@@ -142,7 +144,19 @@ func New(conf *boot.Config, args *Args) (*Sandbox, error) {
// Wait until the sandbox has booted.
b := make([]byte, 1)
if l, err := clientSyncFile.Read(b); err != nil || l != 1 {
- return nil, fmt.Errorf("waiting for sandbox to start: %v", err)
+ err := fmt.Errorf("waiting for sandbox to start: %v", err)
+ // If the sandbox failed to start, it may be because the binary
+ // permissions were incorrect. Check the bits and return a more helpful
+ // error message.
+ //
+ // NOTE: The error message is checked because error types are lost over
+ // rpc calls.
+ if strings.Contains(err.Error(), io.EOF.Error()) {
+ if permsErr := checkBinaryPermissions(conf); permsErr != nil {
+ return nil, fmt.Errorf("%v: %v", err, permsErr)
+ }
+ }
+ return nil, err
}
c.Release()
@@ -369,8 +383,24 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
cmd.Args = append(cmd.Args, "--debug-log-fd="+strconv.Itoa(nextFD))
nextFD++
}
+ if conf.PanicLog != "" {
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
- cmd.Args = append(cmd.Args, "--panic-signal="+strconv.Itoa(int(syscall.SIGTERM)))
+ panicLogFile, err := specutils.DebugLogFile(conf.PanicLog, "panic", test)
+ if err != nil {
+ return fmt.Errorf("opening debug log file in %q: %v", conf.PanicLog, err)
+ }
+ defer panicLogFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, panicLogFile)
+ cmd.Args = append(cmd.Args, "--panic-log-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
// Add the "boot" command to the args.
//
@@ -426,6 +456,12 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
+ // TODO(b/151157106): syscall tests fail by timeout if asyncpreemptoff
+ // isn't set.
+ if conf.Platform == "kvm" {
+ cmd.Env = append(cmd.Env, "GODEBUG=asyncpreemptoff=1")
+ }
+
// The current process' stdio must be passed to the application via the
// --stdio-fds flag. The stdio of the sandbox process itself must not
// be connected to the same FDs, otherwise we risk leaking sandbox
@@ -564,45 +600,32 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nss = append(nss, specs.LinuxNamespace{Type: specs.UserNamespace})
cmd.Args = append(cmd.Args, "--setup-root")
+ const nobody = 65534
if conf.Rootless {
- log.Infof("Rootless mode: sandbox will run as root inside user namespace, mapped to the current user, uid: %d, gid: %d", os.Getuid(), os.Getgid())
+ log.Infof("Rootless mode: sandbox will run as nobody inside user namespace, mapped to the current user, uid: %d, gid: %d", os.Getuid(), os.Getgid())
cmd.SysProcAttr.UidMappings = []syscall.SysProcIDMap{
{
- ContainerID: 0,
+ ContainerID: nobody,
HostID: os.Getuid(),
Size: 1,
},
}
cmd.SysProcAttr.GidMappings = []syscall.SysProcIDMap{
{
- ContainerID: 0,
+ ContainerID: nobody,
HostID: os.Getgid(),
Size: 1,
},
}
- cmd.SysProcAttr.Credential = &syscall.Credential{Uid: 0, Gid: 0}
} else {
// Map nobody in the new namespace to nobody in the parent namespace.
//
// A sandbox process will construct an empty
- // root for itself, so it has to have the CAP_SYS_ADMIN
- // capability.
- //
- // FIXME(b/122554829): The current implementations of
- // os/exec doesn't allow to set ambient capabilities if
- // a process is started in a new user namespace. As a
- // workaround, we start the sandbox process with the 0
- // UID and then it constructs a chroot and sets UID to
- // nobody. https://github.com/golang/go/issues/2315
- const nobody = 65534
+ // root for itself, so it has to have
+ // CAP_SYS_ADMIN and CAP_SYS_CHROOT capabilities.
cmd.SysProcAttr.UidMappings = []syscall.SysProcIDMap{
{
- ContainerID: 0,
- HostID: nobody - 1,
- Size: 1,
- },
- {
ContainerID: nobody,
HostID: nobody,
Size: 1,
@@ -615,11 +638,11 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
Size: 1,
},
}
-
- // Set credentials to run as user and group nobody.
- cmd.SysProcAttr.Credential = &syscall.Credential{Uid: 0, Gid: nobody}
}
+ // Set credentials to run as user and group nobody.
+ cmd.SysProcAttr.Credential = &syscall.Credential{Uid: nobody, Gid: nobody}
+ cmd.SysProcAttr.AmbientCaps = append(cmd.SysProcAttr.AmbientCaps, uintptr(capability.CAP_SYS_ADMIN), uintptr(capability.CAP_SYS_CHROOT))
} else {
return fmt.Errorf("can't run sandbox process as user nobody since we don't have CAP_SETUID or CAP_SETGID")
}
@@ -677,6 +700,13 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
+ if args.Attached {
+ // Kill sandbox if parent process exits in attached mode.
+ cmd.SysProcAttr.Pdeathsig = syscall.SIGKILL
+ // Tells boot that any process it creates must have pdeathsig set.
+ cmd.Args = append(cmd.Args, "--attached")
+ }
+
// Add container as the last argument.
cmd.Args = append(cmd.Args, s.ID)
@@ -685,15 +715,22 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
log.Debugf("Donating FD %d: %q", i+3, f.Name())
}
- if args.Attached {
- // Kill sandbox if parent process exits in attached mode.
- cmd.SysProcAttr.Pdeathsig = syscall.SIGKILL
- }
-
log.Debugf("Starting sandbox: %s %v", binPath, cmd.Args)
log.Debugf("SysProcAttr: %+v", cmd.SysProcAttr)
if err := specutils.StartInNS(cmd, nss); err != nil {
- return fmt.Errorf("Sandbox: %v", err)
+ err := fmt.Errorf("starting sandbox: %v", err)
+ // If the sandbox failed to start, it may be because the binary
+ // permissions were incorrect. Check the bits and return a more helpful
+ // error message.
+ //
+ // NOTE: The error message is checked because error types are lost over
+ // rpc calls.
+ if strings.Contains(err.Error(), syscall.EACCES.Error()) {
+ if permsErr := checkBinaryPermissions(conf); permsErr != nil {
+ return fmt.Errorf("%v: %v", err, permsErr)
+ }
+ }
+ return err
}
s.child = true
s.Pid = cmd.Process.Pid
@@ -972,6 +1009,66 @@ func (s *Sandbox) StopCPUProfile() error {
return nil
}
+// GoroutineProfile writes a goroutine profile to the given file.
+func (s *Sandbox) GoroutineProfile(f *os.File) error {
+ log.Debugf("Goroutine profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.GoroutineProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q goroutine profile: %v", s.ID, err)
+ }
+ return nil
+}
+
+// BlockProfile writes a block profile to the given file.
+func (s *Sandbox) BlockProfile(f *os.File) error {
+ log.Debugf("Block profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.BlockProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q block profile: %v", s.ID, err)
+ }
+ return nil
+}
+
+// MutexProfile writes a mutex profile to the given file.
+func (s *Sandbox) MutexProfile(f *os.File) error {
+ log.Debugf("Mutex profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.MutexProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q mutex profile: %v", s.ID, err)
+ }
+ return nil
+}
+
// StartTrace start trace writing to the given file.
func (s *Sandbox) StartTrace(f *os.File) error {
log.Debugf("Trace start %q", s.ID)
@@ -1096,3 +1193,31 @@ func deviceFileForPlatform(name string) (*os.File, error) {
}
return f, nil
}
+
+// checkBinaryPermissions verifies that the required binary bits are set on
+// the runsc executable.
+func checkBinaryPermissions(conf *boot.Config) error {
+ // All platforms need the other exe bit
+ neededBits := os.FileMode(0001)
+ if conf.Platform == platforms.Ptrace {
+ // Ptrace needs the other read bit
+ neededBits |= os.FileMode(0004)
+ }
+
+ exePath, err := os.Executable()
+ if err != nil {
+ return fmt.Errorf("getting exe path: %v", err)
+ }
+
+ // Check the permissions of the runsc binary and print an error if it
+ // doesn't match expectations.
+ info, err := os.Stat(exePath)
+ if err != nil {
+ return fmt.Errorf("stat file: %v", err)
+ }
+
+ if info.Mode().Perm()&neededBits != neededBits {
+ return fmt.Errorf(specutils.FaqErrorMsg("runsc-perms", fmt.Sprintf("%s does not have the correct permissions", exePath)))
+ }
+ return nil
+}
diff --git a/runsc/specutils/namespace.go b/runsc/specutils/namespace.go
index c7dd3051c..60bb7b7ee 100644
--- a/runsc/specutils/namespace.go
+++ b/runsc/specutils/namespace.go
@@ -252,6 +252,9 @@ func MaybeRunAsRoot() error {
},
Credential: &syscall.Credential{Uid: 0, Gid: 0},
GidMappingsEnableSetgroups: false,
+
+ // Make sure child is killed when the parent terminates.
+ Pdeathsig: syscall.SIGKILL,
}
cmd.Env = os.Environ()
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index d3c2e4e78..837d5e238 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -92,6 +92,12 @@ func ValidateSpec(spec *specs.Spec) error {
log.Warningf("AppArmor profile %q is being ignored", spec.Process.ApparmorProfile)
}
+ // PR_SET_NO_NEW_PRIVS is assumed to always be set.
+ // See kernel.Task.updateCredsForExecLocked.
+ if !spec.Process.NoNewPrivileges {
+ log.Warningf("noNewPrivileges ignored. PR_SET_NO_NEW_PRIVS is assumed to always be set.")
+ }
+
// TODO(gvisor.dev/issue/510): Apply seccomp to application inside sandbox.
if spec.Linux != nil && spec.Linux.Seccomp != nil {
log.Warningf("Seccomp spec is being ignored")
@@ -528,3 +534,8 @@ func EnvVar(env []string, name string) (string, bool) {
}
return "", false
}
+
+// FaqErrorMsg returns an error message pointing to the FAQ.
+func FaqErrorMsg(anchor, msg string) string {
+ return fmt.Sprintf("%s; see https://gvisor.dev/faq#%s for more details", msg, anchor)
+}
diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go
index 92d677e71..51e487715 100644
--- a/runsc/testutil/testutil.go
+++ b/runsc/testutil/testutil.go
@@ -87,18 +87,19 @@ func TestConfig() *boot.Config {
logDir = dir + "/"
}
return &boot.Config{
- Debug: true,
- DebugLog: logDir,
- LogFormat: "text",
- DebugLogFormat: "text",
- AlsoLogToStderr: true,
- LogPackets: true,
- Network: boot.NetworkNone,
- Strace: true,
- Platform: "ptrace",
- FileAccess: boot.FileAccessExclusive,
+ Debug: true,
+ DebugLog: logDir,
+ LogFormat: "text",
+ DebugLogFormat: "text",
+ AlsoLogToStderr: true,
+ LogPackets: true,
+ Network: boot.NetworkNone,
+ Strace: true,
+ Platform: "ptrace",
+ FileAccess: boot.FileAccessExclusive,
+ NumNetworkChannels: 1,
+
TestOnlyAllowRunAsCurrentUserWithoutChroot: true,
- NumNetworkChannels: 1,
}
}
diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh
index a0317db02..e0f6df438 100644..100755
--- a/scripts/benchmark.sh
+++ b/scripts/benchmark.sh
@@ -14,12 +14,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Run in the root of the repo.
-cd "$(dirname "$0")"
+source $(dirname $0)/common.sh
-KEY_PATH=${KEY_PATH:-"${KOKORO_KEYSTORE_DIR}/${KOKORO_SERVICE_ACCOUNT}"}
+# gcloud may be installed as a "snap". If it is, include it in PATH.
+declare -r snap="/snap/bin"
+if [[ -d "${snap}" ]]; then
+ export PATH="${PATH}:${snap}"
+fi
-gcloud auth activate-service-account --key-file "${KEY_PATH}"
+# Make sure we can find gcloud and exit if not.
+which gcloud
-gcloud compute instances list
+# Exporting for subprocesses as GCP APIs and tools check this environmental
+# variable for authentication.
+export GOOGLE_APPLICATION_CREDENTIALS="${KOKORO_KEYSTORE_DIR}/${GCLOUD_CREDENTIALS}"
+gcloud auth activate-service-account \
+ --key-file "${GOOGLE_APPLICATION_CREDENTIALS}"
+
+gcloud config set project ${PROJECT}
+gcloud config set compute/zone ${ZONE}
+
+bazel run //benchmarks:benchmarks -- \
+ --verbose \
+ run-gcp \
+ "(startup|absl)" \
+ --internal \
+ --runtime=runc \
+ --runtime=runsc \
+ --installers=head
diff --git a/scripts/build.sh b/scripts/build.sh
index 4c042af6c..7c9c99800 100755
--- a/scripts/build.sh
+++ b/scripts/build.sh
@@ -17,7 +17,7 @@
source $(dirname $0)/common.sh
# Install required packages for make_repository.sh et al.
-sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils xz-utils
+apt_install dpkg-sig coreutils apt-utils xz-utils
# Build runsc.
runsc=$(build -c opt //runsc)
diff --git a/scripts/common.sh b/scripts/common.sh
index 3ca699e4a..bc6ba71e8 100755
--- a/scripts/common.sh
+++ b/scripts/common.sh
@@ -84,3 +84,25 @@ function install_runsc() {
# Restart docker to pick up the new runtime configuration.
sudo systemctl restart docker
}
+
+# Installs the given packages. Note that the package names should be verified to
+# be correct, otherwise this may result in a loop that spins until time out.
+function apt_install() {
+ while true; do
+ sudo apt-get update &&
+ sudo apt-get install -y "$@" &&
+ true
+ result="${?}"
+ case $result in
+ 0)
+ break
+ ;;
+ 100)
+ # 100 is the error code that apt-get returns.
+ ;;
+ *)
+ exit $result
+ ;;
+ esac
+ done
+}
diff --git a/scripts/dev.sh b/scripts/dev.sh
index 6238b4d0b..a9107f33e 100755
--- a/scripts/dev.sh
+++ b/scripts/dev.sh
@@ -66,6 +66,7 @@ if [[ ${REFRESH} -eq 0 ]]; then
else
mkdir -p "$(dirname ${RUNSC_BIN})"
cp -f ${OUTPUT} "${RUNSC_BIN}"
+ chmod a+rx "${RUNSC_BIN}"
echo
echo "Runtime ${RUNTIME} refreshed."
diff --git a/scripts/iptables_tests.sh b/scripts/iptables_tests.sh
index 3069d8628..0f46909ac 100755
--- a/scripts/iptables_tests.sh
+++ b/scripts/iptables_tests.sh
@@ -16,12 +16,15 @@
source $(dirname $0)/common.sh
-install_runsc_for_test iptables
+install_runsc_for_test iptables --net-raw
# Build the docker image for the test.
-run //test/iptables/runner-image --norun
+run //test/iptables/runner:runner-image --norun
-# TODO(gvisor.dev/issue/170): Also test this on runsc once iptables are better
-# supported
-test //test/iptables:iptables_test "--test_arg=--runtime=runc" \
+test //test/iptables:iptables_test \
+ "--test_arg=--runtime=runc" \
+ "--test_arg=--image=bazel/test/iptables/runner:runner-image"
+
+test //test/iptables:iptables_test \
+ "--test_arg=--runtime=${RUNTIME}" \
"--test_arg=--image=bazel/test/iptables/runner:runner-image"
diff --git a/scripts/packetimpact_tests.sh b/scripts/packetimpact_tests.sh
new file mode 100755
index 000000000..027d11e64
--- /dev/null
+++ b/scripts/packetimpact_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+install_runsc_for_test runsc-d
+test_runsc $(bazel query "attr(tags, packetimpact, tests(//test/packetimpact/...))")
diff --git a/scripts/release.sh b/scripts/release.sh
index 091abf87f..e14ba04a7 100755
--- a/scripts/release.sh
+++ b/scripts/release.sh
@@ -25,6 +25,14 @@ if ! [[ -v KOKORO_RELEASE_TAG ]]; then
echo "No KOKORO_RELEASE_TAG provided." >&2
exit 1
fi
+if ! [[ -v KOKORO_RELNOTES ]]; then
+ echo "No KOKORO_RELNOTES provided." >&2
+ exit 1
+fi
+if ! [[ -r "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}" ]]; then
+ echo "The file '${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}' is not readable." >&2
+ exit 1
+fi
# Unless an explicit releaser is provided, use the bot e-mail.
declare -r KOKORO_RELEASE_AUTHOR=${KOKORO_RELEASE_AUTHOR:-gvisor-bot}
@@ -46,4 +54,7 @@ EOF
fi
# Run the release tool, which pushes to the origin repository.
-tools/tag_release.sh "${KOKORO_RELEASE_COMMIT}" "${KOKORO_RELEASE_TAG}"
+tools/tag_release.sh \
+ "${KOKORO_RELEASE_COMMIT}" \
+ "${KOKORO_RELEASE_TAG}" \
+ "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}"
diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go
index 4074d2285..594c8e752 100644
--- a/test/e2e/exec_test.go
+++ b/test/e2e/exec_test.go
@@ -240,17 +240,7 @@ func TestExecEnvHasHome(t *testing.T) {
}
d := dockerutil.MakeDocker("exec-env-home-test")
- // We will check that HOME is set for root user, and also for a new
- // non-root user we will create.
- newUID := 1234
- newHome := "/foo/bar"
-
- // Create a new user with a home directory, and then sleep.
- script := fmt.Sprintf(`
- mkdir -p -m 777 %s && \
- adduser foo -D -u %d -h %s && \
- sleep 1000`, newHome, newUID, newHome)
- if err := d.Run("alpine", "/bin/sh", "-c", script); err != nil {
+ if err := d.Run("alpine", "sleep", "1000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
defer d.CleanUp()
@@ -264,7 +254,15 @@ func TestExecEnvHasHome(t *testing.T) {
t.Errorf("wanted exec output to contain %q, got %q", want, got)
}
- // Execute the same as uid 123 and expect newHome.
+ // Create a new user with a home directory.
+ newUID := 1234
+ newHome := "/foo/bar"
+ cmd := fmt.Sprintf("mkdir -p -m 777 %q && adduser foo -D -u %d -h %q", newHome, newUID, newHome)
+ if _, err := d.Exec("/bin/sh", "-c", cmd); err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+
+ // Execute the same as the new user and expect newHome.
got, err = d.ExecAsUser(strconv.Itoa(newUID), "/bin/sh", "-c", "echo $HOME")
if err != nil {
t.Fatalf("docker exec failed: %v", err)
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index 28064e557..cc4fbbaed 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -175,10 +175,8 @@ func TestCheckpointRestore(t *testing.T) {
t.Fatal(err)
}
- // TODO(b/143498576): Remove after github.com/moby/moby/issues/38963 is fixed.
- time.Sleep(1 * time.Second)
-
- if err := d.Restore("test"); err != nil {
+ // TODO(b/143498576): Remove Poll after github.com/moby/moby/issues/38963 is fixed.
+ if err := testutil.Poll(func() error { return d.Restore("test") }, 15*time.Second); err != nil {
t.Fatal("docker restore failed:", err)
}
diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go
index b2fb6401a..41e0cfa8d 100644
--- a/test/iptables/filter_input.go
+++ b/test/iptables/filter_input.go
@@ -47,6 +47,8 @@ func init() {
RegisterTestCase(FilterInputJumpReturnDrop{})
RegisterTestCase(FilterInputJumpBuiltin{})
RegisterTestCase(FilterInputJumpTwice{})
+ RegisterTestCase(FilterInputDestination{})
+ RegisterTestCase(FilterInputInvertDestination{})
}
// FilterInputDropUDP tests that we can drop UDP traffic.
@@ -106,7 +108,7 @@ func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP) error {
func (FilterInputDropOnlyUDP) LocalAction(ip net.IP) error {
// Try to establish a TCP connection with the container, which should
// succeed.
- return connectTCP(ip, acceptPort, dropPort, sendloopDuration)
+ return connectTCP(ip, acceptPort, sendloopDuration)
}
// FilterInputDropUDPPort tests that we can drop UDP traffic by port.
@@ -192,8 +194,11 @@ func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropTCPDestPort) LocalAction(ip net.IP) error {
- if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err == nil {
- return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ // Ensure we cannot connect to the container.
+ for start := time.Now(); time.Since(start) < sendloopDuration; {
+ if err := connectTCP(ip, dropPort, sendloopDuration-time.Since(start)); err == nil {
+ return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort)
+ }
}
return nil
@@ -209,13 +214,14 @@ func (FilterInputDropTCPSrcPort) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error {
- if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ // Drop anything from an ephemeral port.
+ if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil {
return err
}
// Listen for TCP packets on accept port.
if err := listenTCP(acceptPort, sendloopDuration); err == nil {
- return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ return fmt.Errorf("connection destined to port %d should not be accepted, but was", dropPort)
}
return nil
@@ -223,8 +229,11 @@ func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP) error {
- if err := connectTCP(ip, acceptPort, dropPort, sendloopDuration); err == nil {
- return fmt.Errorf("connection on port %d should not be acceptedi, but got accepted", dropPort)
+ // Ensure we cannot connect to the container.
+ for start := time.Now(); time.Since(start) < sendloopDuration; {
+ if err := connectTCP(ip, acceptPort, sendloopDuration-time.Since(start)); err == nil {
+ return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort)
+ }
}
return nil
@@ -595,3 +604,66 @@ func (FilterInputJumpTwice) ContainerAction(ip net.IP) error {
func (FilterInputJumpTwice) LocalAction(ip net.IP) error {
return sendUDPLoop(ip, acceptPort, sendloopDuration)
}
+
+// FilterInputDestination verifies that we can filter packets via `-d
+// <ipaddr>`.
+type FilterInputDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDestination) Name() string {
+ return "FilterInputDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDestination) ContainerAction(ip net.IP) error {
+ addrs, err := localAddrs()
+ if err != nil {
+ return err
+ }
+
+ // Make INPUT's default action DROP, then ACCEPT all packets bound for
+ // this machine.
+ rules := [][]string{{"-P", "INPUT", "DROP"}}
+ for _, addr := range addrs {
+ rules = append(rules, []string{"-A", "INPUT", "-d", addr, "-j", "ACCEPT"})
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDestination) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputInvertDestination verifies that we can filter packets via `! -d
+// <ipaddr>`.
+type FilterInputInvertDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputInvertDestination) Name() string {
+ return "FilterInputInvertDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInvertDestination) ContainerAction(ip net.IP) error {
+ // Make INPUT's default action DROP, then ACCEPT all packets not bound
+ // for 127.0.0.1.
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-A", "INPUT", "!", "-d", localIP, "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInvertDestination) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
index ee2c49f9a..f6d974b85 100644
--- a/test/iptables/filter_output.go
+++ b/test/iptables/filter_output.go
@@ -22,9 +22,17 @@ import (
func init() {
RegisterTestCase(FilterOutputDropTCPDestPort{})
RegisterTestCase(FilterOutputDropTCPSrcPort{})
+ RegisterTestCase(FilterOutputDestination{})
+ RegisterTestCase(FilterOutputInvertDestination{})
+ RegisterTestCase(FilterOutputAcceptTCPOwner{})
+ RegisterTestCase(FilterOutputDropTCPOwner{})
+ RegisterTestCase(FilterOutputAcceptUDPOwner{})
+ RegisterTestCase(FilterOutputDropUDPOwner{})
+ RegisterTestCase(FilterOutputOwnerFail{})
}
-// FilterOutputDropTCPDestPort tests that connections are not accepted on specified source ports.
+// FilterOutputDropTCPDestPort tests that connections are not accepted on
+// specified source ports.
type FilterOutputDropTCPDestPort struct{}
// Name implements TestCase.Name.
@@ -48,14 +56,15 @@ func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP) error {
- if err := connectTCP(ip, acceptPort, dropPort, sendloopDuration); err == nil {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil {
return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
}
return nil
}
-// FilterOutputDropTCPSrcPort tests that connections are not accepted on specified source ports.
+// FilterOutputDropTCPSrcPort tests that connections are not accepted on
+// specified source ports.
type FilterOutputDropTCPSrcPort struct{}
// Name implements TestCase.Name.
@@ -79,9 +88,201 @@ func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error {
- if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err == nil {
+ if err := connectTCP(ip, dropPort, sendloopDuration); err == nil {
return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
}
return nil
}
+
+// FilterOutputAcceptTCPOwner tests that TCP connections from uid owner are accepted.
+type FilterOutputAcceptTCPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputAcceptTCPOwner) Name() string {
+ return "FilterOutputAcceptTCPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err != nil {
+ return fmt.Errorf("connection on port %d should be accepted, but got dropped", acceptPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err != nil {
+ return fmt.Errorf("connection destined to port %d should be accepted, but got dropped", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputDropTCPOwner tests that TCP connections from uid owner are dropped.
+type FilterOutputDropTCPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDropTCPOwner) Name() string {
+ return "FilterOutputDropTCPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropTCPOwner) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputAcceptUDPOwner tests that UDP packets from uid owner are accepted.
+type FilterOutputAcceptUDPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputAcceptUDPOwner) Name() string {
+ return "FilterOutputAcceptUDPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Send UDP packets on acceptPort.
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP) error {
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// FilterOutputDropUDPOwner tests that UDP packets from uid owner are dropped.
+type FilterOutputDropUDPOwner struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDropUDPOwner) Name() string {
+ return "FilterOutputDropUDPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Send UDP packets on dropPort.
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropUDPOwner) LocalAction(ip net.IP) error {
+ // Listen for UDP packets on dropPort.
+ if err := listenUDP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets should not be received")
+ }
+
+ return nil
+}
+
+// FilterOutputOwnerFail tests that without uid/gid option, owner rule
+// will fail.
+type FilterOutputOwnerFail struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputOwnerFail) Name() string {
+ return "FilterOutputOwnerFail"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputOwnerFail) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil {
+ return fmt.Errorf("Invalid argument")
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputOwnerFail) LocalAction(ip net.IP) error {
+ // no-op.
+ return nil
+}
+
+// FilterOutputDestination tests that we can selectively allow packets to
+// certain destinations.
+type FilterOutputDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputDestination) Name() string {
+ return "FilterOutputDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDestination) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDestination) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// FilterOutputInvertDestination tests that we can selectively allow packets
+// not headed for a particular destination.
+type FilterOutputInvertDestination struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInvertDestination) Name() string {
+ return "FilterOutputInvertDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInvertDestination) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "!", "-d", localIP, "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInvertDestination) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 0621861eb..493d69052 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -191,17 +191,37 @@ func TestFilterInputDropOnlyUDP(t *testing.T) {
}
func TestNATRedirectUDPPort(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
if err := singleTest(NATRedirectUDPPort{}); err != nil {
t.Fatal(err)
}
}
+func TestNATRedirectTCPPort(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATRedirectTCPPort{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
func TestNATDropUDP(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
if err := singleTest(NATDropUDP{}); err != nil {
t.Fatal(err)
}
}
+func TestNATAcceptAll(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATAcceptAll{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
func TestFilterInputDropTCPDestPort(t *testing.T) {
if err := singleTest(FilterInputDropTCPDestPort{}); err != nil {
t.Fatal(err)
@@ -239,17 +259,51 @@ func TestFilterInputReturnUnderflow(t *testing.T) {
}
func TestFilterOutputDropTCPDestPort(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).")
if err := singleTest(FilterOutputDropTCPDestPort{}); err != nil {
t.Fatal(err)
}
}
func TestFilterOutputDropTCPSrcPort(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).")
if err := singleTest(FilterOutputDropTCPSrcPort{}); err != nil {
t.Fatal(err)
}
}
+func TestFilterOutputAcceptTCPOwner(t *testing.T) {
+ if err := singleTest(FilterOutputAcceptTCPOwner{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterOutputDropTCPOwner(t *testing.T) {
+ if err := singleTest(FilterOutputDropTCPOwner{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterOutputAcceptUDPOwner(t *testing.T) {
+ if err := singleTest(FilterOutputAcceptUDPOwner{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterOutputDropUDPOwner(t *testing.T) {
+ if err := singleTest(FilterOutputDropUDPOwner{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterOutputOwnerFail(t *testing.T) {
+ if err := singleTest(FilterOutputOwnerFail{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
func TestJumpSerialize(t *testing.T) {
if err := singleTest(FilterInputSerializeJump{}); err != nil {
t.Fatal(err)
@@ -285,3 +339,83 @@ func TestJumpTwice(t *testing.T) {
t.Fatal(err)
}
}
+
+func TestInputDestination(t *testing.T) {
+ if err := singleTest(FilterInputDestination{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestInputInvertDestination(t *testing.T) {
+ if err := singleTest(FilterInputInvertDestination{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestOutputDestination(t *testing.T) {
+ if err := singleTest(FilterOutputDestination{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestOutputInvertDestination(t *testing.T) {
+ if err := singleTest(FilterOutputInvertDestination{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATOutRedirectIP(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATOutRedirectIP{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATOutDontRedirectIP(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATOutDontRedirectIP{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATOutRedirectInvert(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATOutRedirectInvert{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATPreRedirectIP(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATPreRedirectIP{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATPreDontRedirectIP(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATPreDontRedirectIP{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATPreRedirectInvert(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATPreRedirectInvert{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNATRedirectRequiresProtocol(t *testing.T) {
+ // TODO(gvisor.dev/issue/170): Enable when supported.
+ t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
+ if err := singleTest(NATRedirectRequiresProtocol{}); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index 32cf5a417..134391e8d 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -24,6 +24,7 @@ import (
)
const iptablesBinary = "iptables"
+const localIP = "127.0.0.1"
// filterTable calls `iptables -t filter` with the given args.
func filterTable(args ...string) error {
@@ -46,8 +47,17 @@ func tableCmd(table string, args []string) error {
// filterTableRules is like filterTable, but runs multiple iptables commands.
func filterTableRules(argsList [][]string) error {
+ return tableRules("filter", argsList)
+}
+
+// natTableRules is like natTable, but runs multiple iptables commands.
+func natTableRules(argsList [][]string) error {
+ return tableRules("nat", argsList)
+}
+
+func tableRules(table string, argsList [][]string) error {
for _, args := range argsList {
- if err := filterTable(args...); err != nil {
+ if err := tableCmd(table, args); err != nil {
return err
}
}
@@ -125,27 +135,37 @@ func listenTCP(port int, timeout time.Duration) error {
return nil
}
-// connectTCP connects the TCP server over specified local port, server IP and remote/server port.
-func connectTCP(ip net.IP, remotePort, localPort int, timeout time.Duration) error {
+// connectTCP connects to the given IP and port from an ephemeral local address.
+func connectTCP(ip net.IP, port int, timeout time.Duration) error {
contAddr := net.TCPAddr{
IP: ip,
- Port: remotePort,
+ Port: port,
}
// The container may not be listening when we first connect, so retry
// upon error.
callback := func() error {
- localAddr := net.TCPAddr{
- Port: localPort,
- }
- conn, err := net.DialTCP("tcp4", &localAddr, &contAddr)
+ conn, err := net.DialTimeout("tcp", contAddr.String(), timeout)
if conn != nil {
conn.Close()
}
return err
}
if err := testutil.Poll(callback, timeout); err != nil {
- return fmt.Errorf("timed out waiting to send IP, most recent error: %v", err)
+ return fmt.Errorf("timed out waiting to connect IP, most recent error: %v", err)
}
return nil
}
+
+// localAddrs returns a list of local network interface addresses.
+func localAddrs() ([]string, error) {
+ addrs, err := net.InterfaceAddrs()
+ if err != nil {
+ return nil, err
+ }
+ addrStrs := make([]string, 0, len(addrs))
+ for _, addr := range addrs {
+ addrStrs = append(addrStrs, addr.String())
+ }
+ return addrStrs, nil
+}
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index a01117ec8..40096901c 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -15,8 +15,10 @@
package iptables
import (
+ "errors"
"fmt"
"net"
+ "time"
)
const (
@@ -25,7 +27,16 @@ const (
func init() {
RegisterTestCase(NATRedirectUDPPort{})
+ RegisterTestCase(NATRedirectTCPPort{})
RegisterTestCase(NATDropUDP{})
+ RegisterTestCase(NATAcceptAll{})
+ RegisterTestCase(NATPreRedirectIP{})
+ RegisterTestCase(NATPreDontRedirectIP{})
+ RegisterTestCase(NATPreRedirectInvert{})
+ RegisterTestCase(NATOutRedirectIP{})
+ RegisterTestCase(NATOutDontRedirectIP{})
+ RegisterTestCase(NATOutRedirectInvert{})
+ RegisterTestCase(NATRedirectRequiresProtocol{})
}
// NATRedirectUDPPort tests that packets are redirected to different port.
@@ -45,6 +56,7 @@ func (NATRedirectUDPPort) ContainerAction(ip net.IP) error {
if err := listenUDP(redirectPort, sendloopDuration); err != nil {
return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err)
}
+
return nil
}
@@ -53,7 +65,31 @@ func (NATRedirectUDPPort) LocalAction(ip net.IP) error {
return sendUDPLoop(ip, acceptPort, sendloopDuration)
}
-// NATDropUDP tests that packets are not received in ports other than redirect port.
+// NATRedirectTCPPort tests that connections are redirected on specified ports.
+type NATRedirectTCPPort struct{}
+
+// Name implements TestCase.Name.
+func (NATRedirectTCPPort) Name() string {
+ return "NATRedirectTCPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATRedirectTCPPort) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on redirect port.
+ return listenTCP(redirectPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATRedirectTCPPort) LocalAction(ip net.IP) error {
+ return connectTCP(ip, dropPort, sendloopDuration)
+}
+
+// NATDropUDP tests that packets are not received in ports other than redirect
+// port.
type NATDropUDP struct{}
// Name implements TestCase.Name.
@@ -78,3 +114,218 @@ func (NATDropUDP) ContainerAction(ip net.IP) error {
func (NATDropUDP) LocalAction(ip net.IP) error {
return sendUDPLoop(ip, acceptPort, sendloopDuration)
}
+
+// NATAcceptAll tests that all UDP packets are accepted.
+type NATAcceptAll struct{}
+
+// Name implements TestCase.Name.
+func (NATAcceptAll) Name() string {
+ return "NATAcceptAll"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATAcceptAll) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ if err := listenUDP(acceptPort, sendloopDuration); err != nil {
+ return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATAcceptAll) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// NATOutRedirectIP uses iptables to select packets based on destination IP and
+// redirects them.
+type NATOutRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATOutRedirectIP) Name() string {
+ return "NATOutRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectIP) ContainerAction(ip net.IP) error {
+ // Redirect OUTPUT packets to a listening localhost port.
+ dest := net.IP([]byte{200, 0, 0, 2})
+ return loopbackTest(dest, "-A", "OUTPUT", "-d", dest.String(), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort))
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectIP) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// NATOutDontRedirectIP tests that iptables matching with "-d" does not match
+// packets it shouldn't.
+type NATOutDontRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATOutDontRedirectIP) Name() string {
+ return "NATOutDontRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutDontRedirectIP) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "OUTPUT", "-d", localIP, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutDontRedirectIP) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// NATOutRedirectInvert tests that iptables can match with "! -d".
+type NATOutRedirectInvert struct{}
+
+// Name implements TestCase.Name.
+func (NATOutRedirectInvert) Name() string {
+ return "NATOutRedirectInvert"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectInvert) ContainerAction(ip net.IP) error {
+ // Redirect OUTPUT packets to a listening localhost port.
+ dest := []byte{200, 0, 0, 3}
+ destStr := "200.0.0.2"
+ return loopbackTest(dest, "-A", "OUTPUT", "!", "-d", destStr, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort))
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectInvert) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// NATPreRedirectIP tests that we can use iptables to select packets based on
+// destination IP and redirect them.
+type NATPreRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATPreRedirectIP) Name() string {
+ return "NATPreRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectIP) ContainerAction(ip net.IP) error {
+ addrs, err := localAddrs()
+ if err != nil {
+ return err
+ }
+
+ var rules [][]string
+ for _, addr := range addrs {
+ rules = append(rules, []string{"-A", "PREROUTING", "-p", "udp", "-d", addr, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)})
+ }
+ if err := natTableRules(rules); err != nil {
+ return err
+ }
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectIP) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// NATPreDontRedirectIP tests that iptables matching with "-d" does not match
+// packets it shouldn't.
+type NATPreDontRedirectIP struct{}
+
+// Name implements TestCase.Name.
+func (NATPreDontRedirectIP) Name() string {
+ return "NATPreDontRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreDontRedirectIP) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreDontRedirectIP) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// NATPreRedirectInvert tests that iptables can match with "! -d".
+type NATPreRedirectInvert struct{}
+
+// Name implements TestCase.Name.
+func (NATPreRedirectInvert) Name() string {
+ return "NATPreRedirectInvert"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectInvert) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "!", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectInvert) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a
+// protocol to be specified with -p.
+type NATRedirectRequiresProtocol struct{}
+
+// Name implements TestCase.Name.
+func (NATRedirectRequiresProtocol) Name() string {
+ return "NATRedirectRequiresProtocol"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil {
+ return errors.New("expected an error using REDIRECT --to-ports without a protocol")
+ }
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATRedirectRequiresProtocol) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// loopbackTests runs an iptables rule and ensures that packets sent to
+// dest:dropPort are received by localhost:acceptPort.
+func loopbackTest(dest net.IP, args ...string) error {
+ if err := natTable(args...); err != nil {
+ return err
+ }
+ sendCh := make(chan error)
+ listenCh := make(chan error)
+ go func() {
+ sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration)
+ }()
+ go func() {
+ listenCh <- listenUDP(acceptPort, sendloopDuration)
+ }()
+ select {
+ case err := <-listenCh:
+ if err != nil {
+ return err
+ }
+ case <-time.After(sendloopDuration):
+ return errors.New("timed out")
+ }
+ // sendCh will always take the full sendloop time.
+ return <-sendCh
+}
diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD
index d113555b1..fb0b2db41 100644
--- a/test/packetdrill/BUILD
+++ b/test/packetdrill/BUILD
@@ -1,8 +1,48 @@
-load("defs.bzl", "packetdrill_test")
+load("defs.bzl", "packetdrill_linux_test", "packetdrill_netstack_test", "packetdrill_test")
package(licenses = ["notice"])
packetdrill_test(
- name = "fin_wait2_timeout",
+ name = "packetdrill_sanity_test",
+ scripts = ["sanity_test.pkt"],
+)
+
+packetdrill_test(
+ name = "accept_ack_drop_test",
+ scripts = ["accept_ack_drop.pkt"],
+)
+
+packetdrill_test(
+ name = "fin_wait2_timeout_test",
scripts = ["fin_wait2_timeout.pkt"],
)
+
+packetdrill_linux_test(
+ name = "tcp_user_timeout_test_linux_test",
+ scripts = ["linux/tcp_user_timeout.pkt"],
+)
+
+packetdrill_netstack_test(
+ name = "tcp_user_timeout_test_netstack_test",
+ scripts = ["netstack/tcp_user_timeout.pkt"],
+)
+
+packetdrill_test(
+ name = "listen_close_before_handshake_complete_test",
+ scripts = ["listen_close_before_handshake_complete.pkt"],
+)
+
+packetdrill_test(
+ name = "no_rst_to_rst_test",
+ scripts = ["no_rst_to_rst.pkt"],
+)
+
+packetdrill_test(
+ name = "tcp_defer_accept_test",
+ scripts = ["tcp_defer_accept.pkt"],
+)
+
+packetdrill_test(
+ name = "tcp_defer_accept_timeout_test",
+ scripts = ["tcp_defer_accept_timeout.pkt"],
+)
diff --git a/test/packetdrill/Dockerfile b/test/packetdrill/Dockerfile
index bd4451355..4b75e9527 100644
--- a/test/packetdrill/Dockerfile
+++ b/test/packetdrill/Dockerfile
@@ -1,9 +1,9 @@
FROM ubuntu:bionic
-RUN apt-get update
-RUN apt-get install -y net-tools git iptables iputils-ping netcat tcpdump jq tar
+RUN apt-get update && apt-get install -y net-tools git iptables iputils-ping \
+ netcat tcpdump jq tar bison flex make
RUN hash -r
RUN git clone --branch packetdrill-v2.0 \
https://github.com/google/packetdrill.git
-RUN cd packetdrill/gtests/net/packetdrill && ./configure && \
- apt-get install -y bison flex make && make
+RUN cd packetdrill/gtests/net/packetdrill && ./configure && make
+CMD /bin/bash
diff --git a/test/packetdrill/accept_ack_drop.pkt b/test/packetdrill/accept_ack_drop.pkt
new file mode 100644
index 000000000..76e638fd4
--- /dev/null
+++ b/test/packetdrill/accept_ack_drop.pkt
@@ -0,0 +1,27 @@
+// Test that the accept works if the final ACK is dropped and an ack with data
+// follows the dropped ack.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
++0.0 < . 1:5(4) ack 1 win 257
++0.0 > . 1:1(0) ack 5 <...>
+
+// This should cause connection to transition to connected state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
+// Now read the data and we should get 4 bytes.
++0.000 read(4,..., 4) = 4
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 5 <...>
++0.0 < F. 5:5(0) ack 2 win 257
++0.01 > . 2:2(0) ack 6 <...> \ No newline at end of file
diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl
index 8623ce7b1..f499c177b 100644
--- a/test/packetdrill/defs.bzl
+++ b/test/packetdrill/defs.bzl
@@ -66,7 +66,7 @@ def packetdrill_linux_test(name, **kwargs):
if "tags" not in kwargs:
kwargs["tags"] = _PACKETDRILL_TAGS
_packetdrill_test(
- name = name + "_linux_test",
+ name = name,
flags = ["--dut_platform", "linux"],
**kwargs
)
@@ -75,13 +75,13 @@ def packetdrill_netstack_test(name, **kwargs):
if "tags" not in kwargs:
kwargs["tags"] = _PACKETDRILL_TAGS
_packetdrill_test(
- name = name + "_netstack_test",
+ name = name,
# This is the default runtime unless
# "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value.
flags = ["--dut_platform", "netstack", "--runtime", "runsc-d"],
**kwargs
)
-def packetdrill_test(**kwargs):
- packetdrill_linux_test(**kwargs)
- packetdrill_netstack_test(**kwargs)
+def packetdrill_test(name, **kwargs):
+ packetdrill_linux_test(name + "_linux_test", **kwargs)
+ packetdrill_netstack_test(name + "_netstack_test", **kwargs)
diff --git a/test/packetdrill/fin_wait2_timeout.pkt b/test/packetdrill/fin_wait2_timeout.pkt
index 613f0bec9..93ab08575 100644
--- a/test/packetdrill/fin_wait2_timeout.pkt
+++ b/test/packetdrill/fin_wait2_timeout.pkt
@@ -19,5 +19,5 @@
+0 > F. 1:1(0) ack 1 <...>
+0 < . 1:1(0) ack 2 win 257
-+1.1 < . 1:1(0) ack 2 win 257
++2 < . 1:1(0) ack 2 win 257
+0 > R 2:2(0) win 0
diff --git a/test/packetdrill/linux/tcp_user_timeout.pkt b/test/packetdrill/linux/tcp_user_timeout.pkt
new file mode 100644
index 000000000..38018cb42
--- /dev/null
+++ b/test/packetdrill/linux/tcp_user_timeout.pkt
@@ -0,0 +1,39 @@
+// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection
+// if there is pending unacked data after the user specified timeout.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
++0.1 < . 1:1(0) ack 1 win 32792
+
++0.100 accept(3, ..., ...) = 4
+
+// Okay, we received nothing, and decide to close this idle socket.
+// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth
+// trying hard to cleanly close this flow, at the price of keeping
+// a TCP structure in kernel for about 1 minute!
++2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0
+
+// The write/ack is required mainly for netstack as netstack does
+// not update its RTO during the handshake.
++0 write(4, ..., 100) = 100
++0 > P. 1:101(100) ack 1 <...>
++0 < . 1:1(0) ack 101 win 32792
+
++0 close(4) = 0
+
++0 > F. 101:101(0) ack 1 <...>
++.3~+.400 > F. 101:101(0) ack 1 <...>
++.3~+.400 > F. 101:101(0) ack 1 <...>
++.6~+.800 > F. 101:101(0) ack 1 <...>
++1.2~+1.300 > F. 101:101(0) ack 1 <...>
+
+// We finally receive something from the peer, but it is way too late
+// Our socket vanished because TCP_USER_TIMEOUT was really small.
++.1 < . 1:2(1) ack 102 win 32792
++0 > R 102:102(0) win 0
diff --git a/test/packetdrill/listen_close_before_handshake_complete.pkt b/test/packetdrill/listen_close_before_handshake_complete.pkt
new file mode 100644
index 000000000..51c3f1a32
--- /dev/null
+++ b/test/packetdrill/listen_close_before_handshake_complete.pkt
@@ -0,0 +1,31 @@
+// Test that closing a listening socket closes any connections in SYN-RCVD
+// state and any packets bound for these connections generate a RESET.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
+
++0.100 close(3) = 0
++0.1 < P. 1:1(0) ack 1 win 257
+
+// Linux generates a reset with no ack number/bit set. This is contradictory to
+// what is specified in Rule 1 under Reset Generation in
+// https://tools.ietf.org/html/rfc793#section-3.4.
+// "1. If the connection does not exist (CLOSED) then a reset is sent
+// in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected
+// by this means.
+//
+// If the incoming segment has an ACK field, the reset takes its
+// sequence number from the ACK field of the segment, otherwise the
+// reset has sequence number zero and the ACK field is set to the sum
+// of the sequence number and segment length of the incoming segment.
+// The connection remains in the CLOSED state."
+
++0.0 > R 1:1(0) win 0 \ No newline at end of file
diff --git a/test/packetdrill/netstack/tcp_user_timeout.pkt b/test/packetdrill/netstack/tcp_user_timeout.pkt
new file mode 100644
index 000000000..60103adba
--- /dev/null
+++ b/test/packetdrill/netstack/tcp_user_timeout.pkt
@@ -0,0 +1,38 @@
+// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection
+// if there is pending unacked data after the user specified timeout.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
++0.1 < . 1:1(0) ack 1 win 32792
+
++0.100 accept(3, ..., ...) = 4
+
+// Okay, we received nothing, and decide to close this idle socket.
+// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth
+// trying hard to cleanly close this flow, at the price of keeping
+// a TCP structure in kernel for about 1 minute!
++2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0
+
+// The write/ack is required mainly for netstack as netstack does
+// not update its RTO during the handshake.
++0 write(4, ..., 100) = 100
++0 > P. 1:101(100) ack 1 <...>
++0 < . 1:1(0) ack 101 win 32792
+
++0 close(4) = 0
+
++0 > F. 101:101(0) ack 1 <...>
++.2~+.300 > F. 101:101(0) ack 1 <...>
++.4~+.500 > F. 101:101(0) ack 1 <...>
++.8~+.900 > F. 101:101(0) ack 1 <...>
+
+// We finally receive something from the peer, but it is way too late
+// Our socket vanished because TCP_USER_TIMEOUT was really small.
++1.61 < . 1:2(1) ack 102 win 32792
++0 > R 102:102(0) win 0
diff --git a/test/packetdrill/no_rst_to_rst.pkt b/test/packetdrill/no_rst_to_rst.pkt
new file mode 100644
index 000000000..612747827
--- /dev/null
+++ b/test/packetdrill/no_rst_to_rst.pkt
@@ -0,0 +1,36 @@
+// Test a RST is not generated in response to a RST and a RST is correctly
+// generated when an accepted endpoint is RST due to an incoming RST.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
++0 < P. 1:1(0) ack 1 win 257
+
++0.100 accept(3, ..., ...) = 4
+
++0.200 < R 1:1(0) win 0
+
++0.300 read(4,..., 4) = -1 ECONNRESET (Connection Reset by Peer)
+
++0.00 < . 1:1(0) ack 1 win 257
+
+// Linux generates a reset with no ack number/bit set. This is contradictory to
+// what is specified in Rule 1 under Reset Generation in
+// https://tools.ietf.org/html/rfc793#section-3.4.
+// "1. If the connection does not exist (CLOSED) then a reset is sent
+// in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected
+// by this means.
+//
+// If the incoming segment has an ACK field, the reset takes its
+// sequence number from the ACK field of the segment, otherwise the
+// reset has sequence number zero and the ACK field is set to the sum
+// of the sequence number and segment length of the incoming segment.
+// The connection remains in the CLOSED state."
+
++0.00 > R 1:1(0) win 0 \ No newline at end of file
diff --git a/test/packetdrill/packetdrill_test.sh b/test/packetdrill/packetdrill_test.sh
index 0b22dfd5c..c8268170f 100755
--- a/test/packetdrill/packetdrill_test.sh
+++ b/test/packetdrill/packetdrill_test.sh
@@ -91,8 +91,8 @@ fi
# Variables specific to the test runner start with TEST_RUNNER_.
declare -r PACKETDRILL="/packetdrill/gtests/net/packetdrill/packetdrill"
# Use random numbers so that test networks don't collide.
-declare -r CTRL_NET="ctrl_net-${RANDOM}${RANDOM}"
-declare -r TEST_NET="test_net-${RANDOM}${RANDOM}"
+declare -r CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)"
+declare -r TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)"
declare -r tolerance_usecs=100000
# On both DUT and test runner, testing packets are on the eth2 interface.
declare -r TEST_DEVICE="eth2"
diff --git a/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt
new file mode 100644
index 000000000..a86b90ce6
--- /dev/null
+++ b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt
@@ -0,0 +1,9 @@
+// Test that a listening socket generates a RST when it receives an
+// ACK and syn cookies are not in use.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
++0.1 < . 1:1(0) ack 1 win 32792
++0 > R 1:1(0) ack 0 win 0 \ No newline at end of file
diff --git a/test/packetdrill/sanity_test.pkt b/test/packetdrill/sanity_test.pkt
new file mode 100644
index 000000000..b3b58c366
--- /dev/null
+++ b/test/packetdrill/sanity_test.pkt
@@ -0,0 +1,7 @@
+// Basic sanity test. One system call.
+//
+// All of the plumbing has to be working however, and the packetdrill wire
+// client needs to be able to connect to the wire server and send the script,
+// probe local interfaces, run through the test w/ timings, etc.
+
+0.000 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
diff --git a/test/packetdrill/tcp_defer_accept.pkt b/test/packetdrill/tcp_defer_accept.pkt
new file mode 100644
index 000000000..a17f946db
--- /dev/null
+++ b/test/packetdrill/tcp_defer_accept.pkt
@@ -0,0 +1,48 @@
+// Test that a bare ACK does not complete a connection when TCP_DEFER_ACCEPT
+// timeout is not hit but an ACK w/ data does complete and deliver the
+// connection to the accept queue.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [5], 4) = 0
++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
+// Send a bare ACK this should not complete the connection as we
+// set the TCP_DEFER_ACCEPT above.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The bare ACK should be dropped and no connection should be delivered
+// to the accept queue.
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT
+// to 5 seconds above.
++2.5 < . 1:1(0) ack 1 win 257
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// set accept socket back to blocking.
++0.000 fcntl(3, F_SETFL, O_RDWR) = 0
+
+// Now send an ACK w/ data. This should complete the connection
+// and deliver the socket to the accept queue.
++0.1 < . 1:5(4) ack 1 win 257
++0.0 > . 1:1(0) ack 5 <...>
+
+// This should cause connection to transition to connected state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
+// Now read the data and we should get 4 bytes.
++0.000 read(4,..., 4) = 4
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 5 <...>
++0.0 < F. 5:5(0) ack 2 win 257
++0.01 > . 2:2(0) ack 6 <...> \ No newline at end of file
diff --git a/test/packetdrill/tcp_defer_accept_timeout.pkt b/test/packetdrill/tcp_defer_accept_timeout.pkt
new file mode 100644
index 000000000..201fdeb14
--- /dev/null
+++ b/test/packetdrill/tcp_defer_accept_timeout.pkt
@@ -0,0 +1,48 @@
+// Test that a bare ACK is accepted after TCP_DEFER_ACCEPT timeout
+// is hit and a connection is delivered.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [3], 4) = 0
++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
+// Send a bare ACK this should not complete the connection as we
+// set the TCP_DEFER_ACCEPT above.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The bare ACK should be dropped and no connection should be delivered
+// to the accept queue.
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT
+// to 5 seconds above.
++2.5 < . 1:1(0) ack 1 win 257
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// set accept socket back to blocking.
++0.000 fcntl(3, F_SETFL, O_RDWR) = 0
+
+// We should see one more retransmit of the SYN-ACK as a last ditch
+// attempt when TCP_DEFER_ACCEPT timeout is hit to trigger another
+// ACK or a packet with data.
++.35~+2.35 > S. 0:0(0) ack 1 <...>
+
+// Now send another bare ACK after TCP_DEFER_ACCEPT time has been passed.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The ACK above should cause connection to transition to connected state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 1 <...>
++0.0 < F. 1:1(0) ack 2 win 257
++0.01 > . 2:2(0) ack 2 <...>
diff --git a/test/packetimpact/README.md b/test/packetimpact/README.md
new file mode 100644
index 000000000..ece4dedc6
--- /dev/null
+++ b/test/packetimpact/README.md
@@ -0,0 +1,531 @@
+# Packetimpact
+
+## What is packetimpact?
+
+Packetimpact is a tool for platform-independent network testing. It is heavily
+inspired by [packetdrill](https://github.com/google/packetdrill). It creates two
+docker containers connected by a network. One is for the test bench, which
+operates the test. The other is for the device-under-test (DUT), which is the
+software being tested. The test bench communicates over the network with the DUT
+to check correctness of the network.
+
+### Goals
+
+Packetimpact aims to provide:
+
+* A **multi-platform** solution that can test both Linux and gVisor.
+* **Conciseness** on par with packetdrill scripts.
+* **Control-flow** like for loops, conditionals, and variables.
+* **Flexibilty** to specify every byte in a packet or use multiple sockets.
+
+## When to use packetimpact?
+
+There are a few ways to write networking tests for gVisor currently:
+
+* [Go unit tests](https://github.com/google/gvisor/tree/master/pkg/tcpip)
+* [syscall tests](https://github.com/google/gvisor/tree/master/test/syscalls/linux)
+* [packetdrill tests](https://github.com/google/gvisor/tree/master/test/packetdrill)
+* packetimpact tests
+
+The right choice depends on the needs of the test.
+
+Feature | Go unit test | syscall test | packetdrill | packetimpact
+------------- | ------------ | ------------ | ----------- | ------------
+Multiplatform | no | **YES** | **YES** | **YES**
+Concise | no | somewhat | somewhat | **VERY**
+Control-flow | **YES** | **YES** | no | **YES**
+Flexible | **VERY** | no | somewhat | **VERY**
+
+### Go unit tests
+
+If the test depends on the internals of gVisor and doesn't need to run on Linux
+or other platforms for comparison purposes, a Go unit test can be appropriate.
+They can observe internals of gVisor networking. The downside is that they are
+**not concise** and **not multiplatform**. If you require insight on gVisor
+internals, this is the right choice.
+
+### Syscall tests
+
+Syscall tests are **multiplatform** but cannot examine the internals of gVisor
+networking. They are **concise**. They can use **control-flow** structures like
+conditionals, for loops, and variables. However, they are limited to only what
+the POSIX interface provides so they are **not flexible**. For example, you
+would have difficulty writing a syscall test that intentionally sends a bad IP
+checksum. Or if you did write that test with raw sockets, it would be very
+**verbose** to write a test that intentionally send wrong checksums, wrong
+protocols, wrong sequence numbers, etc.
+
+### Packetdrill tests
+
+Packetdrill tests are **multiplatform** and can run against both Linux and
+gVisor. They are **concise** and use a special packetdrill scripting language.
+They are **more flexible** than a syscall test in that they can send packets
+that a syscall test would have difficulty sending, like a packet with a
+calcuated ACK number. But they are also somewhat limimted in flexibiilty in that
+they can't do tests with multiple sockets. They have **no control-flow** ability
+like variables or conditionals. For example, it isn't possible to send a packet
+that depends on the window size of a previous packet because the packetdrill
+language can't express that. Nor could you branch based on whether or not the
+other side supports window scaling, for example.
+
+### Packetimpact tests
+
+Packetimpact tests are similar to Packetdrill tests except that they are written
+in Go instead of the packetdrill scripting language. That gives them all the
+**control-flow** abilities of Go (loops, functions, variables, etc). They are
+**multiplatform** in the same way as packetdrill tests but even more
+**flexible** because Go is more expressive than the scripting language of
+packetdrill. However, Go is **not as concise** as the packetdrill language. Many
+design decisions below are made to mitigate that.
+
+## How it works
+
+```
+ +--------------+ +--------------+
+ | | TEST NET | |
+ | | <===========> | Device |
+ | Test | | Under |
+ | Bench | | Test |
+ | | <===========> | (DUT) |
+ | | CONTROL NET | |
+ +--------------+ +--------------+
+```
+
+Two docker containers are created by a script, one for the test bench and the
+other for the device under test (DUT). The script connects the two containers
+with a control network and test network. It also does some other tasks like
+waiting until the DUT is ready before starting the test and disabling Linux
+networking that would interfere with the test bench.
+
+### DUT
+
+The DUT container runs a program called the "posix_server". The posix_server is
+written in c++ for maximum portability. It is compiled on the host. The script
+that starts the containers copies it into the DUT's container and runs it. It's
+job is to receive directions from the test bench on what actions to take. For
+this, the posix_server does three steps in a loop:
+
+1. Listen for a request from the test bench.
+2. Execute a command.
+3. Send the response back to the test bench.
+
+The requests and responses are
+[protobufs](https://developers.google.com/protocol-buffers) and the
+communication is done with [gRPC](https://grpc.io/). The commands run are
+[POSIX socket commands](https://en.wikipedia.org/wiki/Berkeley_sockets#Socket_API_functions),
+with the inputs and outputs converted into protobuf requests and responses. All
+communication is on the control network, so that the test network is unaffected
+by extra packets.
+
+For example, this is the request and response pair to call
+[`socket()`](http://man7.org/linux/man-pages/man2/socket.2.html):
+
+```protocol-buffer
+message SocketRequest {
+ int32 domain = 1;
+ int32 type = 2;
+ int32 protocol = 3;
+}
+
+message SocketResponse {
+ int32 fd = 1;
+ int32 errno_ = 2;
+}
+```
+
+##### Alternatives considered
+
+* We could have use JSON for communication instead. It would have been a
+ lighter-touch than protobuf but protobuf handles all the data type and has
+ strict typing to prevent a class of errors. The test bench could be written
+ in other languages, too.
+* Instead of mimicking the POSIX interfaces, arguments could have had a more
+ natural form, like the `bind()` getting a string IP address instead of bytes
+ in a `sockaddr_t`. However, conforming to the existing structures keeps more
+ of the complexity in Go and keeps the posix_server simpler and thus more
+ likely to compile everywhere.
+
+### Test Bench
+
+The test bench does most of the work in a test. It is a Go program that compiles
+on the host and is copied by the script into test bench's container. It is a
+regular [go unit test](https://golang.org/pkg/testing/) that imports the test
+bench framework. The test bench framwork is based on three basic utilities:
+
+* Commanding the DUT to run POSIX commands and return responses.
+* Sending raw packets to the DUT on the test network.
+* Listening for raw packets from the DUT on the test network.
+
+#### DUT commands
+
+To keep the interface to the DUT consistent and easy-to-use, each POSIX command
+supported by the posix_server is wrapped in functions with signatures similar to
+the ones in the [Go unix package](https://godoc.org/golang.org/x/sys/unix). This
+way all the details of endianess and (un)marshalling of go structs such as
+[unix.Timeval](https://godoc.org/golang.org/x/sys/unix#Timeval) is handled in
+one place. This also makes it straight-forward to convert tests that use `unix.`
+or `syscall.` calls to `dut.` calls.
+
+For example, creating a connection to the DUT and commanding it to make a socket
+looks like this:
+
+```go
+dut := testbench.NewDut(t)
+fd, err := dut.SocketWithErrno(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP)
+if fd < 0 {
+ t.Fatalf(...)
+}
+```
+
+Because the usual case is to fail the test when the DUT fails to create a
+socket, there is a concise version of each of the `...WithErrno` functions that
+does that:
+
+```go
+dut := testbench.NewDut(t)
+fd := dut.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP)
+```
+
+The DUT and other structs in the code store a `*testing.T` so that they can
+provide versions of functions that call `t.Fatalf(...)`. This helps keep tests
+concise.
+
+##### Alternatives considered
+
+* Instead of mimicking the `unix.` go interface, we could have invented a more
+ natural one, like using `float64` instead of `Timeval`. However, using the
+ same function signatures that `unix.` has makes it easier to convert code to
+ `dut.`. Also, using an existing interface ensures that we don't invent an
+ interface that isn't extensible. For example, if we invented a function for
+ `bind()` that didn't support IPv6 and later we had to add a second `bind6()`
+ function.
+
+#### Sending/Receiving Raw Packets
+
+The framework wraps POSIX sockets for sending and receiving raw frames. Both
+send and receive are synchronous commands.
+[SO_RCVTIMEO](http://man7.org/linux/man-pages/man7/socket.7.html) is used to set
+a timeout on the receive commands. For ease of use, these are wrapped in an
+`Injector` and a `Sniffer`. They have functions:
+
+```go
+func (s *Sniffer) Recv(timeout time.Duration) []byte {...}
+func (i *Injector) Send(b []byte) {...}
+```
+
+##### Alternatives considered
+
+* [gopacket](https://github.com/google/gopacket) pcap has raw socket support
+ but requires cgo. cgo is not guaranteed to be portable from the host to the
+ container and in practice, the container doesn't recognize binaries built on
+ the host if they use cgo.
+* Both gVisor and gopacket have the ability to read and write pcap files
+ without cgo but that is insufficient here.
+* The sniffer and injector can't share a socket because they need to be bound
+ differently.
+* Sniffing could have been done asynchronously with channels, obviating the
+ need for `SO_RCVTIMEO`. But that would introduce asynchronous complication.
+ `SO_RCVTIMEO` is well supported on the test bench.
+
+#### `Layer` struct
+
+A large part of packetimpact tests is creating packets to send and comparing
+received packets against expectations. To keep tests concise, it is useful to be
+able to specify just the important parts of packets that need to be set. For
+example, sending a packet with default values except for TCP Flags. And for
+packets received, it's useful to be able to compare just the necessary parts of
+received packets and ignore the rest.
+
+To aid in both of those, Go structs with optional fields are created for each
+encapsulation type, such as IPv4, TCP, and Ethernet. This is inspired by
+[scapy](https://scapy.readthedocs.io/en/latest/). For example, here is the
+struct for Ethernet:
+
+```go
+type Ether struct {
+ LayerBase
+ SrcAddr *tcpip.LinkAddress
+ DstAddr *tcpip.LinkAddress
+ Type *tcpip.NetworkProtocolNumber
+}
+```
+
+Each struct has the same fields as those in the
+[gVisor headers](https://github.com/google/gvisor/tree/master/pkg/tcpip/header)
+but with a pointer for each field that may be `nil`.
+
+##### Alternatives considered
+
+* Just use []byte like gVisor headers do. The drawback is that it makes the
+ tests more verbose.
+ * For example, there would be no way to call `Send(myBytes)` concisely and
+ indicate if the checksum should be calculated automatically versus
+ overridden. The only way would be to add lines to the test to calculate
+ it before each Send, which is wordy. Or make multiple versions of Send:
+ one that checksums IP, one that doesn't, one that checksums TCP, one
+ that does both, etc. That would be many combinations.
+ * Filtering inputs would become verbose. Either:
+ * large conditionals that need to be repeated many places:
+ `h[FlagOffset] == SYN && h[LengthOffset:LengthOffset+2] == ...` or
+ * Many functions, one per field, like: `filterByFlag(myBytes, SYN)`,
+ `filterByLength(myBytes, 20)`, `filterByNextProto(myBytes, 0x8000)`,
+ etc.
+ * Using pointers allows us to combine `Layer`s with a one-line call to
+ `mergo.Merge(...)`. So the default `Layers` can be overridden by a
+ `Layers` with just the TCP conection's src/dst which can be overridden
+ by one with just a test specific TCP window size. Each override is
+ specified as just one call to `mergo.Merge`.
+ * It's a proven way to separate the details of a packet from the byte
+ format as shown by scapy's success.
+* Use packetgo. It's more general than parsing packets with gVisor. However:
+ * packetgo doesn't have optional fields so many of the above problems
+ still apply.
+ * It would be yet another dependency.
+ * It's not as well known to engineers that are already writing gVisor
+ code.
+ * It might be a good candidate for replacing the parsing of packets into
+ `Layer`s if all that parsing turns out to be more work than parsing by
+ packetgo and converting *that* to `Layer`. packetgo has easier to use
+ getters for the layers. This could be done later in a way that doesn't
+ break tests.
+
+#### `Layer` methods
+
+The `Layer` structs provide a way to partially specify an encapsulation. They
+also need methods for using those partially specified encapsulation, for example
+to marshal them to bytes or compare them. For those, each encapsulation
+implements the `Layer` interface:
+
+```go
+// Layer is the interface that all encapsulations must implement.
+//
+// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A
+// Layer contains all the fields of the encapsulation. Each field is a pointer
+// and may be nil.
+type Layer interface {
+ // toBytes converts the Layer into bytes. In places where the Layer's field
+ // isn't nil, the value that is pointed to is used. When the field is nil, a
+ // reasonable default for the Layer is used. For example, "64" for IPv4 TTL
+ // and a calculated checksum for TCP or IP. Some layers require information
+ // from the previous or next layers in order to compute a default, such as
+ // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list
+ // to the layer's neighbors.
+ toBytes() ([]byte, error)
+
+ // match checks if the current Layer matches the provided Layer. If either
+ // Layer has a nil in a given field, that field is considered matching.
+ // Otherwise, the values pointed to by the fields must match.
+ match(Layer) bool
+
+ // length in bytes of the current encapsulation
+ length() int
+
+ // next gets a pointer to the encapsulated Layer.
+ next() Layer
+
+ // prev gets a pointer to the Layer encapsulating this one.
+ prev() Layer
+
+ // setNext sets the pointer to the encapsulated Layer.
+ setNext(Layer)
+
+ // setPrev sets the pointer to the Layer encapsulating this one.
+ setPrev(Layer)
+}
+```
+
+For each `Layer` there is also a parsing function. For example, this one is for
+Ethernet:
+
+```
+func ParseEther(b []byte) (Layers, error)
+```
+
+The parsing function converts bytes received on the wire into a `Layer`
+(actually `Layers`, see below) which has no `nil`s in it. By using
+`match(Layer)` to compare against another `Layer` that *does* have `nil`s in it,
+the received bytes can be partially compared. The `nil`s behave as
+"don't-cares".
+
+##### Alternatives considered
+
+* Matching against `[]byte` instead of converting to `Layer` first.
+ * The downside is that it precludes the use of a `cmp.Equal` one-liner to
+ do comparisons.
+ * It creates confusion in the code to deal with both representations at
+ different times. For example, is the checksum calculated on `[]byte` or
+ `Layer` when sending? What about when checking received packets?
+
+#### `Layers`
+
+```
+type Layers []Layer
+
+func (ls *Layers) match(other Layers) bool {...}
+func (ls *Layers) toBytes() ([]byte, error) {...}
+```
+
+`Layers` is an array of `Layer`. It represents a stack of encapsulations, such
+as `Layers{Ether{},IPv4{},TCP{},Payload{}}`. It also has `toBytes()` and
+`match(Layers)`, like `Layer`. The parse functions above actually return
+`Layers` and not `Layer` because they know about the headers below and
+sequentially call each parser on the remaining, encapsulated bytes.
+
+All this leads to the ability to write concise packet processing. For example:
+
+```go
+etherType := 0x8000
+flags = uint8(header.TCPFlagSyn|header.TCPFlagAck)
+toMatch := Layers{Ether{Type: &etherType}, IPv4{}, TCP{Flags: &flags}}
+for {
+ recvBytes := sniffer.Recv(time.Second)
+ if recvBytes == nil {
+ println("Got no packet for 1 second")
+ }
+ gotPacket, err := ParseEther(recvBytes)
+ if err == nil && toMatch.match(gotPacket) {
+ println("Got a TCP/IPv4/Eth packet with SYNACK")
+ }
+}
+```
+
+##### Alternatives considered
+
+* Don't use previous and next pointers.
+ * Each layer may need to be able to interrogate the layers aroung it, like
+ for computing the next protocol number or total length. So *some*
+ mechanism is needed for a `Layer` to see neighboring layers.
+ * We could pass the entire array `Layers` to the `toBytes()` function.
+ Passing an array to a method that includes in the array the function
+ receiver itself seems wrong.
+
+#### Connections
+
+Using `Layers` above, we can create connection structures to maintain state
+about connections. For example, here is the `TCPIPv4` struct:
+
+```
+type TCPIPv4 struct {
+ outgoing Layers
+ incoming Layers
+ localSeqNum uint32
+ remoteSeqNum uint32
+ sniffer Sniffer
+ injector Injector
+ t *testing.T
+}
+```
+
+`TCPIPv4` contains an `outgoing Layers` which holds the defaults for the
+connection, such as the source and destination MACs, IPs, and ports. When
+`outgoing.toBytes()` is called a valid packet for this TCPIPv4 flow is built.
+
+It also contains `incoming Layers` which holds filter for incoming packets that
+belong to this flow. `incoming.match(Layers)` is used on received bytes to check
+if they are part of the flow.
+
+The `sniffer` and `injector` are for receiving and sending raw packet bytes. The
+`localSeqNum` and `remoteSeqNum` are updated by `Send` and `Recv` so that
+outgoing packets will have, by default, the correct sequence number and ack
+number.
+
+TCPIPv4 provides some functions:
+
+```
+func (conn *TCPIPv4) Send(tcp TCP) {...}
+func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP {...}
+```
+
+`Send(tcp TCP)` uses [mergo](https://github.com/imdario/mergo) to merge the
+provided `TCP` (a `Layer`) into `outgoing`. This way the user can specify
+concisely just which fields of `outgoing` to modify. The packet is sent using
+the `injector`.
+
+`Recv(timeout time.Duration)` reads packets from the sniffer until either the
+timeout has elapsed or a packet that matches `incoming` arrives.
+
+Using those, we can perform a TCP 3-way handshake without too much code:
+
+```go
+func (conn *TCPIPv4) Handshake() {
+ syn := uint8(header.TCPFlagSyn)
+ synack := uint8(header.TCPFlagSyn)
+ ack := uint8(header.TCPFlagAck)
+ conn.Send(TCP{Flags: &syn}) // Send a packet with all defaults but set TCP-SYN.
+
+ // Wait for the SYN-ACK response.
+ for {
+ newTCP := conn.Recv(time.Second) // This already filters by MAC, IP, and ports.
+ if TCP{Flags: &synack}.match(newTCP) {
+ break // Only if it's a SYN-ACK proceed.
+ }
+ }
+
+ conn.Send(TCP{Flags: &ack}) // Send an ACK. The seq and ack numbers are set correctly.
+}
+```
+
+The handshake code is part of the testbench utilities so tests can share this
+common sequence, making tests even more concise.
+
+##### Alternatives considered
+
+* Instead of storing `outgoing` and `incoming`, store values.
+ * There would be many more things to store instead, like `localMac`,
+ `remoteMac`, `localIP`, `remoteIP`, `localPort`, and `remotePort`.
+ * Construction of a packet would be many lines to copy each of these
+ values into a `[]byte`. And there would be slight variations needed for
+ each encapsulation stack, like TCPIPv6 and ARP.
+ * Filtering incoming packets would be a long sequence:
+ * Compare the MACs, then
+ * Parse the next header, then
+ * Compare the IPs, then
+ * Parse the next header, then
+ * Compare the TCP ports. Instead it's all just one call to
+ `cmp.Equal(...)`, for all sequences.
+ * A TCPIPv6 connection could share most of the code. Only the type of the
+ IP addresses are different. The types of `outgoing` and `incoming` would
+ be remain `Layers`.
+ * An ARP connection could share all the Ethernet parts. The IP `Layer`
+ could be factored out of `outgoing`. After that, the IPv4 and IPv6
+ connections could implement one interface and a single TCP struct could
+ have either network protocol through composition.
+
+## Putting it all together
+
+Here's what te start of a packetimpact unit test looks like. This test creates a
+TCP connection with the DUT. There are added comments for explanation in this
+document but a real test might not include them in order to stay even more
+concise.
+
+```go
+func TestMyTcpTest(t *testing.T) {
+ // Prepare a DUT for communication.
+ dut := testbench.NewDUT(t)
+
+ // This does:
+ // dut.Socket()
+ // dut.Bind()
+ // dut.Getsockname() to learn the new port number
+ // dut.Listen()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD) // Tell the DUT to close the socket at the end of the test.
+
+ // Monitor a new TCP connection with sniffer, injector, sequence number tracking,
+ // and reasonable outgoing and incoming packet field default IPs, MACs, and port numbers.
+ conn := testbench.NewTCPIPv4(t, dut, remotePort)
+
+ // Perform a 3-way handshake: send SYN, expect SYNACK, send ACK.
+ conn.Handshake()
+
+ // Tell the DUT to accept the new connection.
+ acceptFD := dut.Accept(acceptFd)
+}
+```
+
+## Other notes
+
+* The time between receiving a SYN-ACK and replying with an ACK in `Handshake`
+ is about 3ms. This is much slower than the native unix response, which is
+ about 0.3ms. Packetdrill gets closer to 0.3ms. For tests where timing is
+ crucial, packetdrill is faster and more precise.
diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD
new file mode 100644
index 000000000..3ce63c2c6
--- /dev/null
+++ b/test/packetimpact/dut/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "cc_binary", "grpcpp")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+cc_binary(
+ name = "posix_server",
+ srcs = ["posix_server.cc"],
+ linkstatic = 1,
+ static = True, # This is needed for running in a docker container.
+ deps = [
+ grpcpp,
+ "//test/packetimpact/proto:posix_server_cc_grpc_proto",
+ "//test/packetimpact/proto:posix_server_cc_proto",
+ ],
+)
diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc
new file mode 100644
index 000000000..86e580c6f
--- /dev/null
+++ b/test/packetimpact/dut/posix_server.cc
@@ -0,0 +1,260 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <getopt.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <unordered_map>
+
+#include "arpa/inet.h"
+#include "include/grpcpp/security/server_credentials.h"
+#include "include/grpcpp/server_builder.h"
+#include "test/packetimpact/proto/posix_server.grpc.pb.h"
+#include "test/packetimpact/proto/posix_server.pb.h"
+
+// Converts a sockaddr_storage to a Sockaddr message.
+::grpc::Status sockaddr_to_proto(const sockaddr_storage &addr,
+ socklen_t addrlen,
+ posix_server::Sockaddr *sockaddr_proto) {
+ switch (addr.ss_family) {
+ case AF_INET: {
+ auto addr_in = reinterpret_cast<const sockaddr_in *>(&addr);
+ auto response_in = sockaddr_proto->mutable_in();
+ response_in->set_family(addr_in->sin_family);
+ response_in->set_port(ntohs(addr_in->sin_port));
+ response_in->mutable_addr()->assign(
+ reinterpret_cast<const char *>(&addr_in->sin_addr.s_addr), 4);
+ return ::grpc::Status::OK;
+ }
+ case AF_INET6: {
+ auto addr_in6 = reinterpret_cast<const sockaddr_in6 *>(&addr);
+ auto response_in6 = sockaddr_proto->mutable_in6();
+ response_in6->set_family(addr_in6->sin6_family);
+ response_in6->set_port(ntohs(addr_in6->sin6_port));
+ response_in6->set_flowinfo(ntohl(addr_in6->sin6_flowinfo));
+ response_in6->mutable_addr()->assign(
+ reinterpret_cast<const char *>(&addr_in6->sin6_addr.s6_addr), 16);
+ response_in6->set_scope_id(ntohl(addr_in6->sin6_scope_id));
+ return ::grpc::Status::OK;
+ }
+ }
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown Sockaddr");
+}
+
+class PosixImpl final : public posix_server::Posix::Service {
+ ::grpc::Status Accept(grpc_impl::ServerContext *context,
+ const ::posix_server::AcceptRequest *request,
+ ::posix_server::AcceptResponse *response) override {
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ response->set_fd(accept(request->sockfd(),
+ reinterpret_cast<sockaddr *>(&addr), &addrlen));
+ response->set_errno_(errno);
+ return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
+ }
+
+ ::grpc::Status Bind(grpc_impl::ServerContext *context,
+ const ::posix_server::BindRequest *request,
+ ::posix_server::BindResponse *response) override {
+ if (!request->has_addr()) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Missing address");
+ }
+ sockaddr_storage addr;
+
+ switch (request->addr().sockaddr_case()) {
+ case posix_server::Sockaddr::SockaddrCase::kIn: {
+ auto request_in = request->addr().in();
+ if (request_in.addr().size() != 4) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "IPv4 address must be 4 bytes");
+ }
+ auto addr_in = reinterpret_cast<sockaddr_in *>(&addr);
+ addr_in->sin_family = request_in.family();
+ addr_in->sin_port = htons(request_in.port());
+ request_in.addr().copy(
+ reinterpret_cast<char *>(&addr_in->sin_addr.s_addr), 4);
+ break;
+ }
+ case posix_server::Sockaddr::SockaddrCase::kIn6: {
+ auto request_in6 = request->addr().in6();
+ if (request_in6.addr().size() != 16) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "IPv6 address must be 16 bytes");
+ }
+ auto addr_in6 = reinterpret_cast<sockaddr_in6 *>(&addr);
+ addr_in6->sin6_family = request_in6.family();
+ addr_in6->sin6_port = htons(request_in6.port());
+ addr_in6->sin6_flowinfo = htonl(request_in6.flowinfo());
+ request_in6.addr().copy(
+ reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16);
+ addr_in6->sin6_scope_id = htonl(request_in6.scope_id());
+ break;
+ }
+ case posix_server::Sockaddr::SockaddrCase::SOCKADDR_NOT_SET:
+ default:
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Unknown Sockaddr");
+ }
+ response->set_ret(bind(request->sockfd(),
+ reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Close(grpc_impl::ServerContext *context,
+ const ::posix_server::CloseRequest *request,
+ ::posix_server::CloseResponse *response) override {
+ response->set_ret(close(request->fd()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status GetSockName(
+ grpc_impl::ServerContext *context,
+ const ::posix_server::GetSockNameRequest *request,
+ ::posix_server::GetSockNameResponse *response) override {
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ response->set_ret(getsockname(
+ request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen));
+ response->set_errno_(errno);
+ return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
+ }
+
+ ::grpc::Status Listen(grpc_impl::ServerContext *context,
+ const ::posix_server::ListenRequest *request,
+ ::posix_server::ListenResponse *response) override {
+ response->set_ret(listen(request->sockfd(), request->backlog()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Send(::grpc::ServerContext *context,
+ const ::posix_server::SendRequest *request,
+ ::posix_server::SendResponse *response) override {
+ response->set_ret(::send(request->sockfd(), request->buf().data(),
+ request->buf().size(), request->flags()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status SetSockOpt(
+ grpc_impl::ServerContext *context,
+ const ::posix_server::SetSockOptRequest *request,
+ ::posix_server::SetSockOptResponse *response) override {
+ response->set_ret(setsockopt(request->sockfd(), request->level(),
+ request->optname(), request->optval().c_str(),
+ request->optval().size()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status SetSockOptInt(
+ ::grpc::ServerContext *context,
+ const ::posix_server::SetSockOptIntRequest *request,
+ ::posix_server::SetSockOptIntResponse *response) override {
+ int opt = request->intval();
+ response->set_ret(::setsockopt(request->sockfd(), request->level(),
+ request->optname(), &opt, sizeof(opt)));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status SetSockOptTimeval(
+ ::grpc::ServerContext *context,
+ const ::posix_server::SetSockOptTimevalRequest *request,
+ ::posix_server::SetSockOptTimevalResponse *response) override {
+ timeval tv = {.tv_sec = static_cast<__time_t>(request->timeval().seconds()),
+ .tv_usec = static_cast<__suseconds_t>(
+ request->timeval().microseconds())};
+ response->set_ret(setsockopt(request->sockfd(), request->level(),
+ request->optname(), &tv, sizeof(tv)));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Socket(grpc_impl::ServerContext *context,
+ const ::posix_server::SocketRequest *request,
+ ::posix_server::SocketResponse *response) override {
+ response->set_fd(
+ socket(request->domain(), request->type(), request->protocol()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Recv(::grpc::ServerContext *context,
+ const ::posix_server::RecvRequest *request,
+ ::posix_server::RecvResponse *response) override {
+ std::vector<char> buf(request->len());
+ response->set_ret(
+ recv(request->sockfd(), buf.data(), buf.size(), request->flags()));
+ response->set_errno_(errno);
+ response->set_buf(buf.data(), response->ret());
+ return ::grpc::Status::OK;
+ }
+};
+
+// Parse command line options. Returns a pointer to the first argument beyond
+// the options.
+void parse_command_line_options(int argc, char *argv[], std::string *ip,
+ int *port) {
+ static struct option options[] = {{"ip", required_argument, NULL, 1},
+ {"port", required_argument, NULL, 2},
+ {0, 0, 0, 0}};
+
+ // Parse the arguments.
+ int c;
+ while ((c = getopt_long(argc, argv, "", options, NULL)) > 0) {
+ if (c == 1) {
+ *ip = optarg;
+ } else if (c == 2) {
+ *port = std::stoi(std::string(optarg));
+ }
+ }
+}
+
+void run_server(const std::string &ip, int port) {
+ PosixImpl posix_service;
+ grpc::ServerBuilder builder;
+ std::string server_address = ip + ":" + std::to_string(port);
+ // Set the authentication mechanism.
+ std::shared_ptr<grpc::ServerCredentials> creds =
+ grpc::InsecureServerCredentials();
+ builder.AddListeningPort(server_address, creds);
+ builder.RegisterService(&posix_service);
+
+ std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
+ std::cerr << "Server listening on " << server_address << std::endl;
+ server->Wait();
+ std::cerr << "posix_server is finished." << std::endl;
+}
+
+int main(int argc, char *argv[]) {
+ std::cerr << "posix_server is starting." << std::endl;
+ std::string ip;
+ int port;
+ parse_command_line_options(argc, argv, &ip, &port);
+
+ std::cerr << "Got IP " << ip << " and port " << port << "." << std::endl;
+ run_server(ip, port);
+}
diff --git a/test/packetimpact/proto/BUILD b/test/packetimpact/proto/BUILD
new file mode 100644
index 000000000..4a4370f42
--- /dev/null
+++ b/test/packetimpact/proto/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "proto_library")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+proto_library(
+ name = "posix_server",
+ srcs = ["posix_server.proto"],
+ has_services = 1,
+)
diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto
new file mode 100644
index 000000000..4035e1ee6
--- /dev/null
+++ b/test/packetimpact/proto/posix_server.proto
@@ -0,0 +1,193 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package posix_server;
+
+message SockaddrIn {
+ int32 family = 1;
+ uint32 port = 2;
+ bytes addr = 3;
+}
+
+message SockaddrIn6 {
+ uint32 family = 1;
+ uint32 port = 2;
+ uint32 flowinfo = 3;
+ bytes addr = 4;
+ uint32 scope_id = 5;
+}
+
+message Sockaddr {
+ oneof sockaddr {
+ SockaddrIn in = 1;
+ SockaddrIn6 in6 = 2;
+ }
+}
+
+message Timeval {
+ int64 seconds = 1;
+ int64 microseconds = 2;
+}
+
+// Request and Response pairs for each Posix service RPC call, sorted.
+
+message AcceptRequest {
+ int32 sockfd = 1;
+}
+
+message AcceptResponse {
+ int32 fd = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ Sockaddr addr = 3;
+}
+
+message BindRequest {
+ int32 sockfd = 1;
+ Sockaddr addr = 2;
+}
+
+message BindResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message CloseRequest {
+ int32 fd = 1;
+}
+
+message CloseResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message GetSockNameRequest {
+ int32 sockfd = 1;
+}
+
+message GetSockNameResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ Sockaddr addr = 3;
+}
+
+message ListenRequest {
+ int32 sockfd = 1;
+ int32 backlog = 2;
+}
+
+message ListenResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message SendRequest {
+ int32 sockfd = 1;
+ bytes buf = 2;
+ int32 flags = 3;
+}
+
+message SendResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
+}
+
+message SetSockOptRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ bytes optval = 4;
+}
+
+message SetSockOptResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message SetSockOptIntRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ int32 intval = 4;
+}
+
+message SetSockOptIntResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
+}
+
+message SetSockOptTimevalRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ Timeval timeval = 4;
+}
+
+message SetSockOptTimevalResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message SocketRequest {
+ int32 domain = 1;
+ int32 type = 2;
+ int32 protocol = 3;
+}
+
+message SocketResponse {
+ int32 fd = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message RecvRequest {
+ int32 sockfd = 1;
+ int32 len = 2;
+ int32 flags = 3;
+}
+
+message RecvResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ bytes buf = 3;
+}
+
+service Posix {
+ // Call accept() on the DUT.
+ rpc Accept(AcceptRequest) returns (AcceptResponse);
+ // Call bind() on the DUT.
+ rpc Bind(BindRequest) returns (BindResponse);
+ // Call close() on the DUT.
+ rpc Close(CloseRequest) returns (CloseResponse);
+ // Call getsockname() on the DUT.
+ rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse);
+ // Call listen() on the DUT.
+ rpc Listen(ListenRequest) returns (ListenResponse);
+ // Call send() on the DUT.
+ rpc Send(SendRequest) returns (SendResponse);
+ // Call setsockopt() on the DUT. You should prefer one of the other
+ // SetSockOpt* functions with a more structured optval or else you may get the
+ // encoding wrong, such as making a bad assumption about the server's word
+ // sizes or endianness.
+ rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse);
+ // Call setsockopt() on the DUT with an int optval.
+ rpc SetSockOptInt(SetSockOptIntRequest) returns (SetSockOptIntResponse);
+ // Call setsockopt() on the DUT with a Timeval optval.
+ rpc SetSockOptTimeval(SetSockOptTimevalRequest)
+ returns (SetSockOptTimevalResponse);
+ // Call socket() on the DUT.
+ rpc Socket(SocketRequest) returns (SocketResponse);
+ // Call recv() on the DUT.
+ rpc Recv(RecvRequest) returns (RecvResponse);
+}
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
new file mode 100644
index 000000000..b6a254882
--- /dev/null
+++ b/test/packetimpact/testbench/BUILD
@@ -0,0 +1,41 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "testbench",
+ srcs = [
+ "connections.go",
+ "dut.go",
+ "dut_client.go",
+ "layers.go",
+ "rawsockets.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//pkg/usermem",
+ "//test/packetimpact/proto:posix_server_go_proto",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
+ "@com_github_imdario_mergo//:go_default_library",
+ "@com_github_mohae_deepcopy//:go_default_library",
+ "@org_golang_google_grpc//:go_default_library",
+ "@org_golang_google_grpc//keepalive:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ "@org_uber_go_multierr//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "testbench_test",
+ size = "small",
+ srcs = ["layers_test.go"],
+ library = ":testbench",
+ deps = ["//pkg/tcpip"],
+)
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
new file mode 100644
index 000000000..f84fd8ba7
--- /dev/null
+++ b/test/packetimpact/testbench/connections.go
@@ -0,0 +1,670 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testbench has utilities to send and receive packets and also command
+// the DUT to run POSIX functions.
+package testbench
+
+import (
+ "flag"
+ "fmt"
+ "math/rand"
+ "net"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/mohae/deepcopy"
+ "go.uber.org/multierr"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+var localIPv4 = flag.String("local_ipv4", "", "local IPv4 address for test packets")
+var remoteIPv4 = flag.String("remote_ipv4", "", "remote IPv4 address for test packets")
+var localMAC = flag.String("local_mac", "", "local mac address for test packets")
+var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets")
+
+// pickPort makes a new socket and returns the socket FD and port. The caller
+// must close the FD when done with the port if there is no error.
+func pickPort() (int, uint16, error) {
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0)
+ if err != nil {
+ return -1, 0, err
+ }
+ var sa unix.SockaddrInet4
+ copy(sa.Addr[0:4], net.ParseIP(*localIPv4).To4())
+ if err := unix.Bind(fd, &sa); err != nil {
+ unix.Close(fd)
+ return -1, 0, err
+ }
+ newSockAddr, err := unix.Getsockname(fd)
+ if err != nil {
+ unix.Close(fd)
+ return -1, 0, err
+ }
+ newSockAddrInet4, ok := newSockAddr.(*unix.SockaddrInet4)
+ if !ok {
+ unix.Close(fd)
+ return -1, 0, fmt.Errorf("can't cast Getsockname result to SockaddrInet4")
+ }
+ return fd, uint16(newSockAddrInet4.Port), nil
+}
+
+// layerState stores the state of a layer of a connection.
+type layerState interface {
+ // outgoing returns an outgoing layer to be sent in a frame.
+ outgoing() Layer
+
+ // incoming creates an expected Layer for comparing against a received Layer.
+ // Because the expectation can depend on values in the received Layer, it is
+ // an input to incoming. For example, the ACK number needs to be checked in a
+ // TCP packet but only if the ACK flag is set in the received packet.
+ incoming(received Layer) Layer
+
+ // sent updates the layerState based on the Layer that was sent. The input is
+ // a Layer with all prev and next pointers populated so that the entire frame
+ // as it was sent is available.
+ sent(sent Layer) error
+
+ // received updates the layerState based on a Layer that is receieved. The
+ // input is a Layer with all prev and next pointers populated so that the
+ // entire frame as it was receieved is available.
+ received(received Layer) error
+
+ // close frees associated resources held by the LayerState.
+ close() error
+}
+
+// etherState maintains state about an Ethernet connection.
+type etherState struct {
+ out, in Ether
+}
+
+var _ layerState = (*etherState)(nil)
+
+// newEtherState creates a new etherState.
+func newEtherState(out, in Ether) (*etherState, error) {
+ lMAC, err := tcpip.ParseMACAddress(*localMAC)
+ if err != nil {
+ return nil, err
+ }
+
+ rMAC, err := tcpip.ParseMACAddress(*remoteMAC)
+ if err != nil {
+ return nil, err
+ }
+ s := etherState{
+ out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
+ in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *etherState) outgoing() Layer {
+ return &s.out
+}
+
+func (s *etherState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*etherState) sent(Layer) error {
+ return nil
+}
+
+func (*etherState) received(Layer) error {
+ return nil
+}
+
+func (*etherState) close() error {
+ return nil
+}
+
+// ipv4State maintains state about an IPv4 connection.
+type ipv4State struct {
+ out, in IPv4
+}
+
+var _ layerState = (*ipv4State)(nil)
+
+// newIPv4State creates a new ipv4State.
+func newIPv4State(out, in IPv4) (*ipv4State, error) {
+ lIP := tcpip.Address(net.ParseIP(*localIPv4).To4())
+ rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4())
+ s := ipv4State{
+ out: IPv4{SrcAddr: &lIP, DstAddr: &rIP},
+ in: IPv4{SrcAddr: &rIP, DstAddr: &lIP},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *ipv4State) outgoing() Layer {
+ return &s.out
+}
+
+func (s *ipv4State) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*ipv4State) sent(Layer) error {
+ return nil
+}
+
+func (*ipv4State) received(Layer) error {
+ return nil
+}
+
+func (*ipv4State) close() error {
+ return nil
+}
+
+// tcpState maintains state about a TCP connection.
+type tcpState struct {
+ out, in TCP
+ localSeqNum, remoteSeqNum *seqnum.Value
+ synAck *TCP
+ portPickerFD int
+ finSent bool
+}
+
+var _ layerState = (*tcpState)(nil)
+
+// SeqNumValue is a helper routine that allocates a new seqnum.Value value to
+// store v and returns a pointer to it.
+func SeqNumValue(v seqnum.Value) *seqnum.Value {
+ return &v
+}
+
+// newTCPState creates a new TCPState.
+func newTCPState(out, in TCP) (*tcpState, error) {
+ portPickerFD, localPort, err := pickPort()
+ if err != nil {
+ return nil, err
+ }
+ s := tcpState{
+ out: TCP{SrcPort: &localPort},
+ in: TCP{DstPort: &localPort},
+ localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())),
+ portPickerFD: portPickerFD,
+ finSent: false,
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *tcpState) outgoing() Layer {
+ newOutgoing := deepcopy.Copy(s.out).(TCP)
+ if s.localSeqNum != nil {
+ newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum))
+ }
+ if s.remoteSeqNum != nil {
+ newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum))
+ }
+ return &newOutgoing
+}
+
+func (s *tcpState) incoming(received Layer) Layer {
+ tcpReceived, ok := received.(*TCP)
+ if !ok {
+ return nil
+ }
+ newIn := deepcopy.Copy(s.in).(TCP)
+ if s.remoteSeqNum != nil {
+ newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum))
+ }
+ if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 {
+ // The caller didn't specify an AckNum so we'll expect the calculated one,
+ // but only if the ACK flag is set because the AckNum is not valid in a
+ // header if ACK is not set.
+ newIn.AckNum = Uint32(uint32(*s.localSeqNum))
+ }
+ return &newIn
+}
+
+func (s *tcpState) sent(sent Layer) error {
+ tcp, ok := sent.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", sent)
+ }
+ if !s.finSent {
+ // update localSeqNum by the payload only when FIN is not yet sent by us
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.localSeqNum.UpdateForward(seqnum.Size(current.length()))
+ }
+ }
+ if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.localSeqNum.UpdateForward(1)
+ }
+ if *tcp.Flags&(header.TCPFlagFin) != 0 {
+ s.finSent = true
+ }
+ return nil
+}
+
+func (s *tcpState) received(l Layer) error {
+ tcp, ok := l.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", l)
+ }
+ s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum))
+ if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.remoteSeqNum.UpdateForward(1)
+ }
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.remoteSeqNum.UpdateForward(seqnum.Size(current.length()))
+ }
+ return nil
+}
+
+// close frees the port associated with this connection.
+func (s *tcpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
+ }
+ s.portPickerFD = -1
+ return nil
+}
+
+// udpState maintains state about a UDP connection.
+type udpState struct {
+ out, in UDP
+ portPickerFD int
+}
+
+var _ layerState = (*udpState)(nil)
+
+// newUDPState creates a new udpState.
+func newUDPState(out, in UDP) (*udpState, error) {
+ portPickerFD, localPort, err := pickPort()
+ if err != nil {
+ return nil, err
+ }
+ s := udpState{
+ out: UDP{SrcPort: &localPort},
+ in: UDP{DstPort: &localPort},
+ portPickerFD: portPickerFD,
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *udpState) outgoing() Layer {
+ return &s.out
+}
+
+func (s *udpState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*udpState) sent(l Layer) error {
+ return nil
+}
+
+func (*udpState) received(l Layer) error {
+ return nil
+}
+
+// close frees the port associated with this connection.
+func (s *udpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
+ }
+ s.portPickerFD = -1
+ return nil
+}
+
+// Connection holds a collection of layer states for maintaining a connection
+// along with sockets for sniffer and injecting packets.
+type Connection struct {
+ layerStates []layerState
+ injector Injector
+ sniffer Sniffer
+ t *testing.T
+}
+
+// match tries to match each Layer in received against the incoming filter. If
+// received is longer than layerStates then that may still count as a match. The
+// reverse is never a match. override overrides the default matchers for each
+// Layer.
+func (conn *Connection) match(override, received Layers) bool {
+ if len(received) < len(conn.layerStates) {
+ return false
+ }
+ for i, s := range conn.layerStates {
+ toMatch := s.incoming(received[i])
+ if toMatch == nil {
+ return false
+ }
+ if i < len(override) {
+ toMatch.merge(override[i])
+ }
+ if !toMatch.match(received[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+// Close frees associated resources held by the Connection.
+func (conn *Connection) Close() {
+ errs := multierr.Combine(conn.sniffer.close(), conn.injector.close())
+ for _, s := range conn.layerStates {
+ if err := s.close(); err != nil {
+ errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err))
+ }
+ }
+ if errs != nil {
+ conn.t.Fatalf("unable to close %+v: %s", conn, errs)
+ }
+}
+
+// CreateFrame builds a frame for the connection with layer overriding defaults
+// of the innermost layer and additionalLayers added after it.
+func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
+ var layersToSend Layers
+ for _, s := range conn.layerStates {
+ layersToSend = append(layersToSend, s.outgoing())
+ }
+ if err := layersToSend[len(layersToSend)-1].merge(layer); err != nil {
+ conn.t.Fatalf("can't merge %+v into %+v: %s", layer, layersToSend[len(layersToSend)-1], err)
+ }
+ layersToSend = append(layersToSend, additionalLayers...)
+ return layersToSend
+}
+
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *Connection) SendFrame(frame Layers) {
+ outBytes, err := frame.toBytes()
+ if err != nil {
+ conn.t.Fatalf("can't build outgoing TCP packet: %s", err)
+ }
+ conn.injector.Send(outBytes)
+
+ // frame might have nil values where the caller wanted to use default values.
+ // sentFrame will have no nil values in it because it comes from parsing the
+ // bytes that were actually sent.
+ sentFrame := parse(parseEther, outBytes)
+ // Update the state of each layer based on what was sent.
+ for i, s := range conn.layerStates {
+ if err := s.sent(sentFrame[i]); err != nil {
+ conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err)
+ }
+ }
+}
+
+// Send a packet with reasonable defaults. Potentially override the final layer
+// in the connection with the provided layer and add additionLayers.
+func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) {
+ conn.SendFrame(conn.CreateFrame(layer, additionalLayers...))
+}
+
+// recvFrame gets the next successfully parsed frame (of type Layers) within the
+// timeout provided. If no parsable frame arrives before the timeout, it returns
+// nil.
+func (conn *Connection) recvFrame(timeout time.Duration) Layers {
+ if timeout <= 0 {
+ return nil
+ }
+ b := conn.sniffer.Recv(timeout)
+ if b == nil {
+ return nil
+ }
+ return parse(parseEther, b)
+}
+
+// Expect a frame with the final layerStates layer matching the provided Layer
+// within the timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {
+ // Make a frame that will ignore all but the final layer.
+ layers := make([]Layer, len(conn.layerStates))
+ layers[len(layers)-1] = layer
+
+ gotFrame, err := conn.ExpectFrame(layers, timeout)
+ if err != nil {
+ return nil, err
+ }
+ if len(conn.layerStates)-1 < len(gotFrame) {
+ return gotFrame[len(conn.layerStates)-1], nil
+ }
+ conn.t.Fatal("the received frame should be at least as long as the expected layers")
+ return nil, fmt.Errorf("the received frame should be at least as long as the expected layers")
+}
+
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) {
+ deadline := time.Now().Add(timeout)
+ var allLayers []string
+ for {
+ var gotLayers Layers
+ if timeout = time.Until(deadline); timeout > 0 {
+ gotLayers = conn.recvFrame(timeout)
+ }
+ if gotLayers == nil {
+ return nil, fmt.Errorf("got %d packets:\n%s", len(allLayers), strings.Join(allLayers, "\n"))
+ }
+ if conn.match(layers, gotLayers) {
+ for i, s := range conn.layerStates {
+ if err := s.received(gotLayers[i]); err != nil {
+ conn.t.Fatal(err)
+ }
+ }
+ return gotLayers, nil
+ }
+ allLayers = append(allLayers, fmt.Sprintf("%s", gotLayers))
+ }
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *Connection) Drain() {
+ conn.sniffer.Drain()
+}
+
+// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
+type TCPIPv4 Connection
+
+// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
+func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ if err != nil {
+ t.Fatalf("can't make ipv4State: %s", err)
+ }
+ tcpState, err := newTCPState(outgoingTCP, incomingTCP)
+ if err != nil {
+ t.Fatalf("can't make tcpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return TCPIPv4{
+ layerStates: []layerState{etherState, ipv4State, tcpState},
+ injector: injector,
+ sniffer: sniffer,
+ t: t,
+ }
+}
+
+// Handshake performs a TCP 3-way handshake. The input Connection should have a
+// final TCP Layer.
+func (conn *TCPIPv4) Handshake() {
+ // Send the SYN.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
+
+ // Wait for the SYN-ACK.
+ synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if synAck == nil {
+ conn.t.Fatalf("didn't get synack during handshake: %s", err)
+ }
+ conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
+
+ // Send an ACK.
+ conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
+}
+
+// ExpectData is a convenient method that expects a Layer and the Layer after
+// it. If it doens't arrive in time, it returns nil.
+func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ }
+ return (*Connection)(conn).ExpectFrame(expected, timeout)
+}
+
+// Send a packet with reasonable defaults. Potentially override the TCP layer in
+// the connection with the provided layer and add additionLayers.
+func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
+ (*Connection)(conn).Send(&tcp, additionalLayers...)
+}
+
+// Close frees associated resources held by the TCPIPv4 connection.
+func (conn *TCPIPv4) Close() {
+ (*Connection)(conn).Close()
+}
+
+// Expect a frame with the TCP layer matching the provided TCP within the
+// timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
+ layer, err := (*Connection)(conn).Expect(&tcp, timeout)
+ if layer == nil {
+ return nil, err
+ }
+ gotTCP, ok := layer.(*TCP)
+ if !ok {
+ conn.t.Fatalf("expected %s to be TCP", layer)
+ }
+ return gotTCP, err
+}
+
+func (conn *TCPIPv4) state() *tcpState {
+ state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState)
+ if !ok {
+ conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates)
+ }
+ return state
+}
+
+// RemoteSeqNum returns the next expected sequence number from the DUT.
+func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value {
+ return conn.state().remoteSeqNum
+}
+
+// LocalSeqNum returns the next sequence number to send from the testbench.
+func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value {
+ return conn.state().localSeqNum
+}
+
+// SynAck returns the SynAck that was part of the handshake.
+func (conn *TCPIPv4) SynAck() *TCP {
+ return conn.state().synAck
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *TCPIPv4) Drain() {
+ conn.sniffer.Drain()
+}
+
+// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
+type UDPIPv4 Connection
+
+// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults.
+func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ if err != nil {
+ t.Fatalf("can't make ipv4State: %s", err)
+ }
+ tcpState, err := newUDPState(outgoingUDP, incomingUDP)
+ if err != nil {
+ t.Fatalf("can't make udpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return UDPIPv4{
+ layerStates: []layerState{etherState, ipv4State, tcpState},
+ injector: injector,
+ sniffer: sniffer,
+ t: t,
+ }
+}
+
+// CreateFrame builds a frame for the connection with layer overriding defaults
+// of the innermost layer and additionalLayers added after it.
+func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
+ return (*Connection)(conn).CreateFrame(layer, additionalLayers...)
+}
+
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *UDPIPv4) SendFrame(frame Layers) {
+ (*Connection)(conn).SendFrame(frame)
+}
+
+// Close frees associated resources held by the UDPIPv4 connection.
+func (conn *UDPIPv4) Close() {
+ (*Connection)(conn).Close()
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *UDPIPv4) Drain() {
+ conn.sniffer.Drain()
+}
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
new file mode 100644
index 000000000..9335909c0
--- /dev/null
+++ b/test/packetimpact/testbench/dut.go
@@ -0,0 +1,473 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "context"
+ "flag"
+ "net"
+ "strconv"
+ "syscall"
+ "testing"
+ "time"
+
+ pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
+
+ "golang.org/x/sys/unix"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/keepalive"
+)
+
+var (
+ posixServerIP = flag.String("posix_server_ip", "", "ip address to listen to for UDP commands")
+ posixServerPort = flag.Int("posix_server_port", 40000, "port to listen to for UDP commands")
+ rpcTimeout = flag.Duration("rpc_timeout", 100*time.Millisecond, "gRPC timeout")
+ rpcKeepalive = flag.Duration("rpc_keepalive", 10*time.Second, "gRPC keepalive")
+)
+
+// DUT communicates with the DUT to force it to make POSIX calls.
+type DUT struct {
+ t *testing.T
+ conn *grpc.ClientConn
+ posixServer PosixClient
+}
+
+// NewDUT creates a new connection with the DUT over gRPC.
+func NewDUT(t *testing.T) DUT {
+ flag.Parse()
+ posixServerAddress := *posixServerIP + ":" + strconv.Itoa(*posixServerPort)
+ conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: *rpcKeepalive}))
+ if err != nil {
+ t.Fatalf("failed to grpc.Dial(%s): %s", posixServerAddress, err)
+ }
+ posixServer := NewPosixClient(conn)
+ return DUT{
+ t: t,
+ conn: conn,
+ posixServer: posixServer,
+ }
+}
+
+// TearDown closes the underlying connection.
+func (dut *DUT) TearDown() {
+ dut.conn.Close()
+}
+
+func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr {
+ dut.t.Helper()
+ switch s := sa.(type) {
+ case *unix.SockaddrInet4:
+ return &pb.Sockaddr{
+ Sockaddr: &pb.Sockaddr_In{
+ In: &pb.SockaddrIn{
+ Family: unix.AF_INET,
+ Port: uint32(s.Port),
+ Addr: s.Addr[:],
+ },
+ },
+ }
+ case *unix.SockaddrInet6:
+ return &pb.Sockaddr{
+ Sockaddr: &pb.Sockaddr_In6{
+ In6: &pb.SockaddrIn6{
+ Family: unix.AF_INET6,
+ Port: uint32(s.Port),
+ Flowinfo: 0,
+ ScopeId: s.ZoneId,
+ Addr: s.Addr[:],
+ },
+ },
+ }
+ }
+ dut.t.Fatalf("can't parse Sockaddr: %+v", sa)
+ return nil
+}
+
+func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr {
+ dut.t.Helper()
+ switch s := sa.Sockaddr.(type) {
+ case *pb.Sockaddr_In:
+ ret := unix.SockaddrInet4{
+ Port: int(s.In.GetPort()),
+ }
+ copy(ret.Addr[:], s.In.GetAddr())
+ return &ret
+ case *pb.Sockaddr_In6:
+ ret := unix.SockaddrInet6{
+ Port: int(s.In6.GetPort()),
+ ZoneId: s.In6.GetScopeId(),
+ }
+ copy(ret.Addr[:], s.In6.GetAddr())
+ }
+ dut.t.Fatalf("can't parse Sockaddr: %+v", sa)
+ return nil
+}
+
+// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol
+// proto, and bound to the IP address addr. Returns the new file descriptor and
+// the port that was selected on the DUT.
+func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) {
+ dut.t.Helper()
+ var fd int32
+ if addr.To4() != nil {
+ fd = dut.Socket(unix.AF_INET, typ, proto)
+ sa := unix.SockaddrInet4{}
+ copy(sa.Addr[:], addr.To4())
+ dut.Bind(fd, &sa)
+ } else if addr.To16() != nil {
+ fd = dut.Socket(unix.AF_INET6, typ, proto)
+ sa := unix.SockaddrInet6{}
+ copy(sa.Addr[:], addr.To16())
+ dut.Bind(fd, &sa)
+ } else {
+ dut.t.Fatal("unknown ip addr type for remoteIP")
+ }
+ sa := dut.GetSockName(fd)
+ var port int
+ switch s := sa.(type) {
+ case *unix.SockaddrInet4:
+ port = s.Port
+ case *unix.SockaddrInet6:
+ port = s.Port
+ default:
+ dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa)
+ }
+ return fd, uint16(port)
+}
+
+// CreateListener makes a new TCP connection. If it fails, the test ends.
+func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
+ fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(*remoteIPv4))
+ dut.Listen(fd, backlog)
+ return fd, remotePort
+}
+
+// All the functions that make gRPC calls to the Posix service are below, sorted
+// alphabetically.
+
+// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// AcceptWithErrno.
+func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ fd, sa, err := dut.AcceptWithErrno(ctx, sockfd)
+ if fd < 0 {
+ dut.t.Fatalf("failed to accept: %s", err)
+ }
+ return fd, sa
+}
+
+// AcceptWithErrno calls accept on the DUT.
+func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
+ dut.t.Helper()
+ req := pb.AcceptRequest{
+ Sockfd: sockfd,
+ }
+ resp, err := dut.posixServer.Accept(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Accept: %s", err)
+ }
+ return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is
+// needed, use BindWithErrno.
+func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.BindWithErrno(ctx, fd, sa)
+ if ret != 0 {
+ dut.t.Fatalf("failed to bind socket: %s", err)
+ }
+}
+
+// BindWithErrno calls bind on the DUT.
+func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
+ dut.t.Helper()
+ req := pb.BindRequest{
+ Sockfd: fd,
+ Addr: dut.sockaddrToProto(sa),
+ }
+ resp, err := dut.posixServer.Bind(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Bind: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Close calls close on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// CloseWithErrno.
+func (dut *DUT) Close(fd int32) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.CloseWithErrno(ctx, fd)
+ if ret != 0 {
+ dut.t.Fatalf("failed to close: %s", err)
+ }
+}
+
+// CloseWithErrno calls close on the DUT.
+func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.CloseRequest{
+ Fd: fd,
+ }
+ resp, err := dut.posixServer.Close(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Close: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// GetSockName calls getsockname on the DUT and causes a fatal test failure if
+// it doesn't succeed. If more control over the timeout or error handling is
+// needed, use GetSockNameWithErrno.
+func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd)
+ if ret != 0 {
+ dut.t.Fatalf("failed to getsockname: %s", err)
+ }
+ return sa
+}
+
+// GetSockNameWithErrno calls getsockname on the DUT.
+func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
+ dut.t.Helper()
+ req := pb.GetSockNameRequest{
+ Sockfd: sockfd,
+ }
+ resp, err := dut.posixServer.GetSockName(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Bind: %s", err)
+ }
+ return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// ListenWithErrno.
+func (dut *DUT) Listen(sockfd, backlog int32) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.ListenWithErrno(ctx, sockfd, backlog)
+ if ret != 0 {
+ dut.t.Fatalf("failed to listen: %s", err)
+ }
+}
+
+// ListenWithErrno calls listen on the DUT.
+func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.ListenRequest{
+ Sockfd: sockfd,
+ Backlog: backlog,
+ }
+ resp, err := dut.posixServer.Listen(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Listen: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Send calls send on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// SendWithErrno.
+func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags)
+ if ret == -1 {
+ dut.t.Fatalf("failed to send: %s", err)
+ }
+ return ret
+}
+
+// SendWithErrno calls send on the DUT.
+func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.SendRequest{
+ Sockfd: sockfd,
+ Buf: buf,
+ Flags: flags,
+ }
+ resp, err := dut.posixServer.Send(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Send: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use SetSockOptWithErrno. Because endianess and the width of values
+// might differ between the testbench and DUT architectures, prefer to use a
+// more specific SetSockOptXxx function.
+func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOpt: %s", err)
+ }
+}
+
+// SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the
+// width of values might differ between the testbench and DUT architectures,
+// prefer to use a more specific SetSockOptXxxWithErrno function.
+func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) {
+ dut.t.Helper()
+ req := pb.SetSockOptRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Optval: optval,
+ }
+ resp, err := dut.posixServer.SetSockOpt(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call SetSockOpt: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the int optval or error handling
+// is needed, use SetSockOptIntWithErrno.
+func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOptInt: %s", err)
+ }
+}
+
+// SetSockOptIntWithErrno calls setsockopt with an integer optval.
+func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.SetSockOptIntRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Intval: optval,
+ }
+ resp, err := dut.posixServer.SetSockOptInt(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call SetSockOptInt: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the timeout or error handling is
+// needed, use SetSockOptTimevalWithErrno.
+func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOptTimeval: %s", err)
+ }
+}
+
+// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to
+// bytes.
+func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
+ dut.t.Helper()
+ timeval := pb.Timeval{
+ Seconds: int64(tv.Sec),
+ Microseconds: int64(tv.Usec),
+ }
+ req := pb.SetSockOptTimevalRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Timeval: &timeval,
+ }
+ resp, err := dut.posixServer.SetSockOptTimeval(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call SetSockOptTimeval: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Socket calls socket on the DUT and returns the file descriptor. If socket
+// fails on the DUT, the test ends.
+func (dut *DUT) Socket(domain, typ, proto int32) int32 {
+ dut.t.Helper()
+ fd, err := dut.SocketWithErrno(domain, typ, proto)
+ if fd < 0 {
+ dut.t.Fatalf("failed to create socket: %s", err)
+ }
+ return fd
+}
+
+// SocketWithErrno calls socket on the DUT and returns the fd and errno.
+func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.SocketRequest{
+ Domain: domain,
+ Type: typ,
+ Protocol: proto,
+ }
+ ctx := context.Background()
+ resp, err := dut.posixServer.Socket(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Socket: %s", err)
+ }
+ return resp.GetFd(), syscall.Errno(resp.GetErrno_())
+}
+
+// Recv calls recv on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// RecvWithErrno.
+func (dut *DUT) Recv(sockfd, len, flags int32) []byte {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags)
+ if ret == -1 {
+ dut.t.Fatalf("failed to recv: %s", err)
+ }
+ return buf
+}
+
+// RecvWithErrno calls recv on the DUT.
+func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) {
+ dut.t.Helper()
+ req := pb.RecvRequest{
+ Sockfd: sockfd,
+ Len: len,
+ Flags: flags,
+ }
+ resp, err := dut.posixServer.Recv(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Recv: %s", err)
+ }
+ return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_())
+}
diff --git a/test/packetimpact/testbench/dut_client.go b/test/packetimpact/testbench/dut_client.go
new file mode 100644
index 000000000..b130a33a2
--- /dev/null
+++ b/test/packetimpact/testbench/dut_client.go
@@ -0,0 +1,28 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "google.golang.org/grpc"
+ pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
+)
+
+// PosixClient is a gRPC client for the Posix service.
+type PosixClient pb.PosixClient
+
+// NewPosixClient makes a new gRPC client for the Posix service.
+func NewPosixClient(c grpc.ClientConnInterface) PosixClient {
+ return pb.NewPosixClient(c)
+}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
new file mode 100644
index 000000000..5ce324f0d
--- /dev/null
+++ b/test/packetimpact/testbench/layers.go
@@ -0,0 +1,710 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "encoding/hex"
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "github.com/imdario/mergo"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// Layer is the interface that all encapsulations must implement.
+//
+// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A
+// Layer contains all the fields of the encapsulation. Each field is a pointer
+// and may be nil.
+type Layer interface {
+ fmt.Stringer
+
+ // toBytes converts the Layer into bytes. In places where the Layer's field
+ // isn't nil, the value that is pointed to is used. When the field is nil, a
+ // reasonable default for the Layer is used. For example, "64" for IPv4 TTL
+ // and a calculated checksum for TCP or IP. Some layers require information
+ // from the previous or next layers in order to compute a default, such as
+ // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list
+ // to the layer's neighbors.
+ toBytes() ([]byte, error)
+
+ // match checks if the current Layer matches the provided Layer. If either
+ // Layer has a nil in a given field, that field is considered matching.
+ // Otherwise, the values pointed to by the fields must match. The LayerBase is
+ // ignored.
+ match(Layer) bool
+
+ // length in bytes of the current encapsulation
+ length() int
+
+ // next gets a pointer to the encapsulated Layer.
+ next() Layer
+
+ // prev gets a pointer to the Layer encapsulating this one.
+ prev() Layer
+
+ // setNext sets the pointer to the encapsulated Layer.
+ setNext(Layer)
+
+ // setPrev sets the pointer to the Layer encapsulating this one.
+ setPrev(Layer)
+
+ // merge overrides the values in the interface with the provided values.
+ merge(Layer) error
+}
+
+// LayerBase is the common elements of all layers.
+type LayerBase struct {
+ nextLayer Layer
+ prevLayer Layer
+}
+
+func (lb *LayerBase) next() Layer {
+ return lb.nextLayer
+}
+
+func (lb *LayerBase) prev() Layer {
+ return lb.prevLayer
+}
+
+func (lb *LayerBase) setNext(l Layer) {
+ lb.nextLayer = l
+}
+
+func (lb *LayerBase) setPrev(l Layer) {
+ lb.prevLayer = l
+}
+
+// equalLayer compares that two Layer structs match while ignoring field in
+// which either input has a nil and also ignoring the LayerBase of the inputs.
+func equalLayer(x, y Layer) bool {
+ if x == nil || y == nil {
+ return true
+ }
+ // opt ignores comparison pairs where either of the inputs is a nil.
+ opt := cmp.FilterValues(func(x, y interface{}) bool {
+ for _, l := range []interface{}{x, y} {
+ v := reflect.ValueOf(l)
+ if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() {
+ return true
+ }
+ }
+ return false
+ }, cmp.Ignore())
+ return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{}))
+}
+
+// mergeLayer merges other in layer. Any non-nil value in other overrides the
+// corresponding value in layer. If other is nil, no action is performed.
+func mergeLayer(layer, other Layer) error {
+ if other == nil {
+ return nil
+ }
+ return mergo.Merge(layer, other, mergo.WithOverride)
+}
+
+func stringLayer(l Layer) string {
+ v := reflect.ValueOf(l).Elem()
+ t := v.Type()
+ var ret []string
+ for i := 0; i < v.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ v := v.Field(i)
+ if v.IsNil() {
+ continue
+ }
+ v = reflect.Indirect(v)
+ if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
+ ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes())))
+ } else {
+ ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v))
+ }
+ }
+ return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " "))
+}
+
+// Ether can construct and match an ethernet encapsulation.
+type Ether struct {
+ LayerBase
+ SrcAddr *tcpip.LinkAddress
+ DstAddr *tcpip.LinkAddress
+ Type *tcpip.NetworkProtocolNumber
+}
+
+func (l *Ether) String() string {
+ return stringLayer(l)
+}
+
+func (l *Ether) toBytes() ([]byte, error) {
+ b := make([]byte, header.EthernetMinimumSize)
+ h := header.Ethernet(b)
+ fields := &header.EthernetFields{}
+ if l.SrcAddr != nil {
+ fields.SrcAddr = *l.SrcAddr
+ }
+ if l.DstAddr != nil {
+ fields.DstAddr = *l.DstAddr
+ }
+ if l.Type != nil {
+ fields.Type = *l.Type
+ } else {
+ switch n := l.next().(type) {
+ case *IPv4:
+ fields.Type = header.IPv4ProtocolNumber
+ default:
+ // TODO(b/150301488): Support more protocols, like IPv6.
+ return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n)
+ }
+ }
+ h.Encode(fields)
+ return h, nil
+}
+
+// LinkAddress is a helper routine that allocates a new tcpip.LinkAddress value
+// to store v and returns a pointer to it.
+func LinkAddress(v tcpip.LinkAddress) *tcpip.LinkAddress {
+ return &v
+}
+
+// NetworkProtocolNumber is a helper routine that allocates a new
+// tcpip.NetworkProtocolNumber value to store v and returns a pointer to it.
+func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocolNumber {
+ return &v
+}
+
+// layerParser parses the input bytes and returns a Layer along with the next
+// layerParser to run. If there is no more parsing to do, the returned
+// layerParser is nil.
+type layerParser func([]byte) (Layer, layerParser)
+
+// parse parses bytes starting with the first layerParser and using successive
+// layerParsers until all the bytes are parsed.
+func parse(parser layerParser, b []byte) Layers {
+ var layers Layers
+ for {
+ var layer Layer
+ layer, parser = parser(b)
+ layers = append(layers, layer)
+ if parser == nil {
+ break
+ }
+ b = b[layer.length():]
+ }
+ layers.linkLayers()
+ return layers
+}
+
+// parseEther parses the bytes assuming that they start with an ethernet header
+// and continues parsing further encapsulations.
+func parseEther(b []byte) (Layer, layerParser) {
+ h := header.Ethernet(b)
+ ether := Ether{
+ SrcAddr: LinkAddress(h.SourceAddress()),
+ DstAddr: LinkAddress(h.DestinationAddress()),
+ Type: NetworkProtocolNumber(h.Type()),
+ }
+ var nextParser layerParser
+ switch h.Type() {
+ case header.IPv4ProtocolNumber:
+ nextParser = parseIPv4
+ default:
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
+ }
+ return &ether, nextParser
+}
+
+func (l *Ether) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *Ether) length() int {
+ return header.EthernetMinimumSize
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *Ether) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// IPv4 can construct and match an IPv4 encapsulation.
+type IPv4 struct {
+ LayerBase
+ IHL *uint8
+ TOS *uint8
+ TotalLength *uint16
+ ID *uint16
+ Flags *uint8
+ FragmentOffset *uint16
+ TTL *uint8
+ Protocol *uint8
+ Checksum *uint16
+ SrcAddr *tcpip.Address
+ DstAddr *tcpip.Address
+}
+
+func (l *IPv4) String() string {
+ return stringLayer(l)
+}
+
+func (l *IPv4) toBytes() ([]byte, error) {
+ b := make([]byte, header.IPv4MinimumSize)
+ h := header.IPv4(b)
+ fields := &header.IPv4Fields{
+ IHL: 20,
+ TOS: 0,
+ TotalLength: 0,
+ ID: 0,
+ Flags: 0,
+ FragmentOffset: 0,
+ TTL: 64,
+ Protocol: 0,
+ Checksum: 0,
+ SrcAddr: tcpip.Address(""),
+ DstAddr: tcpip.Address(""),
+ }
+ if l.TOS != nil {
+ fields.TOS = *l.TOS
+ }
+ if l.TotalLength != nil {
+ fields.TotalLength = *l.TotalLength
+ } else {
+ fields.TotalLength = uint16(l.length())
+ current := l.next()
+ for current != nil {
+ fields.TotalLength += uint16(current.length())
+ current = current.next()
+ }
+ }
+ if l.ID != nil {
+ fields.ID = *l.ID
+ }
+ if l.Flags != nil {
+ fields.Flags = *l.Flags
+ }
+ if l.FragmentOffset != nil {
+ fields.FragmentOffset = *l.FragmentOffset
+ }
+ if l.TTL != nil {
+ fields.TTL = *l.TTL
+ }
+ if l.Protocol != nil {
+ fields.Protocol = *l.Protocol
+ } else {
+ switch n := l.next().(type) {
+ case *TCP:
+ fields.Protocol = uint8(header.TCPProtocolNumber)
+ case *UDP:
+ fields.Protocol = uint8(header.UDPProtocolNumber)
+ default:
+ // TODO(b/150301488): Support more protocols as needed.
+ return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
+ }
+ }
+ if l.SrcAddr != nil {
+ fields.SrcAddr = *l.SrcAddr
+ }
+ if l.DstAddr != nil {
+ fields.DstAddr = *l.DstAddr
+ }
+ if l.Checksum != nil {
+ fields.Checksum = *l.Checksum
+ }
+ h.Encode(fields)
+ if l.Checksum == nil {
+ h.SetChecksum(^h.CalculateChecksum())
+ }
+ return h, nil
+}
+
+// Uint16 is a helper routine that allocates a new
+// uint16 value to store v and returns a pointer to it.
+func Uint16(v uint16) *uint16 {
+ return &v
+}
+
+// Uint8 is a helper routine that allocates a new
+// uint8 value to store v and returns a pointer to it.
+func Uint8(v uint8) *uint8 {
+ return &v
+}
+
+// Address is a helper routine that allocates a new tcpip.Address value to store
+// v and returns a pointer to it.
+func Address(v tcpip.Address) *tcpip.Address {
+ return &v
+}
+
+// parseIPv4 parses the bytes assuming that they start with an ipv4 header and
+// continues parsing further encapsulations.
+func parseIPv4(b []byte) (Layer, layerParser) {
+ h := header.IPv4(b)
+ tos, _ := h.TOS()
+ ipv4 := IPv4{
+ IHL: Uint8(h.HeaderLength()),
+ TOS: &tos,
+ TotalLength: Uint16(h.TotalLength()),
+ ID: Uint16(h.ID()),
+ Flags: Uint8(h.Flags()),
+ FragmentOffset: Uint16(h.FragmentOffset()),
+ TTL: Uint8(h.TTL()),
+ Protocol: Uint8(h.Protocol()),
+ Checksum: Uint16(h.Checksum()),
+ SrcAddr: Address(h.SourceAddress()),
+ DstAddr: Address(h.DestinationAddress()),
+ }
+ var nextParser layerParser
+ switch h.TransportProtocol() {
+ case header.TCPProtocolNumber:
+ nextParser = parseTCP
+ case header.UDPProtocolNumber:
+ nextParser = parseUDP
+ default:
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
+ }
+ return &ipv4, nextParser
+}
+
+func (l *IPv4) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *IPv4) length() int {
+ if l.IHL == nil {
+ return header.IPv4MinimumSize
+ }
+ return int(*l.IHL)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv4) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// TCP can construct and match a TCP encapsulation.
+type TCP struct {
+ LayerBase
+ SrcPort *uint16
+ DstPort *uint16
+ SeqNum *uint32
+ AckNum *uint32
+ DataOffset *uint8
+ Flags *uint8
+ WindowSize *uint16
+ Checksum *uint16
+ UrgentPointer *uint16
+}
+
+func (l *TCP) String() string {
+ return stringLayer(l)
+}
+
+func (l *TCP) toBytes() ([]byte, error) {
+ b := make([]byte, header.TCPMinimumSize)
+ h := header.TCP(b)
+ if l.SrcPort != nil {
+ h.SetSourcePort(*l.SrcPort)
+ }
+ if l.DstPort != nil {
+ h.SetDestinationPort(*l.DstPort)
+ }
+ if l.SeqNum != nil {
+ h.SetSequenceNumber(*l.SeqNum)
+ }
+ if l.AckNum != nil {
+ h.SetAckNumber(*l.AckNum)
+ }
+ if l.DataOffset != nil {
+ h.SetDataOffset(*l.DataOffset)
+ } else {
+ h.SetDataOffset(uint8(l.length()))
+ }
+ if l.Flags != nil {
+ h.SetFlags(*l.Flags)
+ }
+ if l.WindowSize != nil {
+ h.SetWindowSize(*l.WindowSize)
+ } else {
+ h.SetWindowSize(32768)
+ }
+ if l.UrgentPointer != nil {
+ h.SetUrgentPoiner(*l.UrgentPointer)
+ }
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ if err := setTCPChecksum(&h, l); err != nil {
+ return nil, err
+ }
+ return h, nil
+}
+
+// totalLength returns the length of the provided layer and all following
+// layers.
+func totalLength(l Layer) int {
+ var totalLength int
+ for ; l != nil; l = l.next() {
+ totalLength += l.length()
+ }
+ return totalLength
+}
+
+// layerChecksum calculates the checksum of the Layer header, including the
+// peusdeochecksum of the layer before it and all the bytes after it..
+func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) {
+ totalLength := uint16(totalLength(l))
+ var xsum uint16
+ switch s := l.prev().(type) {
+ case *IPv4:
+ xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength)
+ default:
+ // TODO(b/150301488): Support more protocols, like IPv6.
+ return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s)
+ }
+ var payloadBytes buffer.VectorisedView
+ for current := l.next(); current != nil; current = current.next() {
+ payload, err := current.toBytes()
+ if err != nil {
+ return 0, fmt.Errorf("can't get bytes for next header: %s", payload)
+ }
+ payloadBytes.AppendView(payload)
+ }
+ xsum = header.ChecksumVV(payloadBytes, xsum)
+ return xsum, nil
+}
+
+// setTCPChecksum calculates the checksum of the TCP header and sets it in h.
+func setTCPChecksum(h *header.TCP, tcp *TCP) error {
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(tcp, header.TCPProtocolNumber)
+ if err != nil {
+ return err
+ }
+ h.SetChecksum(^h.CalculateChecksum(xsum))
+ return nil
+}
+
+// Uint32 is a helper routine that allocates a new
+// uint32 value to store v and returns a pointer to it.
+func Uint32(v uint32) *uint32 {
+ return &v
+}
+
+// parseTCP parses the bytes assuming that they start with a tcp header and
+// continues parsing further encapsulations.
+func parseTCP(b []byte) (Layer, layerParser) {
+ h := header.TCP(b)
+ tcp := TCP{
+ SrcPort: Uint16(h.SourcePort()),
+ DstPort: Uint16(h.DestinationPort()),
+ SeqNum: Uint32(h.SequenceNumber()),
+ AckNum: Uint32(h.AckNumber()),
+ DataOffset: Uint8(h.DataOffset()),
+ Flags: Uint8(h.Flags()),
+ WindowSize: Uint16(h.WindowSize()),
+ Checksum: Uint16(h.Checksum()),
+ UrgentPointer: Uint16(h.UrgentPointer()),
+ }
+ return &tcp, parsePayload
+}
+
+func (l *TCP) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *TCP) length() int {
+ if l.DataOffset == nil {
+ return header.TCPMinimumSize
+ }
+ return int(*l.DataOffset)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *TCP) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// UDP can construct and match a UDP encapsulation.
+type UDP struct {
+ LayerBase
+ SrcPort *uint16
+ DstPort *uint16
+ Length *uint16
+ Checksum *uint16
+}
+
+func (l *UDP) String() string {
+ return stringLayer(l)
+}
+
+func (l *UDP) toBytes() ([]byte, error) {
+ b := make([]byte, header.UDPMinimumSize)
+ h := header.UDP(b)
+ if l.SrcPort != nil {
+ h.SetSourcePort(*l.SrcPort)
+ }
+ if l.DstPort != nil {
+ h.SetDestinationPort(*l.DstPort)
+ }
+ if l.Length != nil {
+ h.SetLength(*l.Length)
+ } else {
+ h.SetLength(uint16(totalLength(l)))
+ }
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ if err := setUDPChecksum(&h, l); err != nil {
+ return nil, err
+ }
+ return h, nil
+}
+
+// setUDPChecksum calculates the checksum of the UDP header and sets it in h.
+func setUDPChecksum(h *header.UDP, udp *UDP) error {
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(udp, header.UDPProtocolNumber)
+ if err != nil {
+ return err
+ }
+ h.SetChecksum(^h.CalculateChecksum(xsum))
+ return nil
+}
+
+// parseUDP parses the bytes assuming that they start with a udp header and
+// returns the parsed layer and the next parser to use.
+func parseUDP(b []byte) (Layer, layerParser) {
+ h := header.UDP(b)
+ udp := UDP{
+ SrcPort: Uint16(h.SourcePort()),
+ DstPort: Uint16(h.DestinationPort()),
+ Length: Uint16(h.Length()),
+ Checksum: Uint16(h.Checksum()),
+ }
+ return &udp, parsePayload
+}
+
+func (l *UDP) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *UDP) length() int {
+ if l.Length == nil {
+ return header.UDPMinimumSize
+ }
+ return int(*l.Length)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *UDP) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// Payload has bytes beyond OSI layer 4.
+type Payload struct {
+ LayerBase
+ Bytes []byte
+}
+
+func (l *Payload) String() string {
+ return stringLayer(l)
+}
+
+// parsePayload parses the bytes assuming that they start with a payload and
+// continue to the end. There can be no further encapsulations.
+func parsePayload(b []byte) (Layer, layerParser) {
+ payload := Payload{
+ Bytes: b,
+ }
+ return &payload, nil
+}
+
+func (l *Payload) toBytes() ([]byte, error) {
+ return l.Bytes, nil
+}
+
+func (l *Payload) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *Payload) length() int {
+ return len(l.Bytes)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *Payload) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// Layers is an array of Layer and supports similar functions to Layer.
+type Layers []Layer
+
+// linkLayers sets the linked-list ponters in ls.
+func (ls *Layers) linkLayers() {
+ for i, l := range *ls {
+ if i > 0 {
+ l.setPrev((*ls)[i-1])
+ } else {
+ l.setPrev(nil)
+ }
+ if i+1 < len(*ls) {
+ l.setNext((*ls)[i+1])
+ } else {
+ l.setNext(nil)
+ }
+ }
+}
+
+func (ls *Layers) toBytes() ([]byte, error) {
+ ls.linkLayers()
+ outBytes := []byte{}
+ for _, l := range *ls {
+ layerBytes, err := l.toBytes()
+ if err != nil {
+ return nil, err
+ }
+ outBytes = append(outBytes, layerBytes...)
+ }
+ return outBytes, nil
+}
+
+func (ls *Layers) match(other Layers) bool {
+ if len(*ls) > len(other) {
+ return false
+ }
+ for i, l := range *ls {
+ if !equalLayer(l, other[i]) {
+ return false
+ }
+ }
+ return true
+}
diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go
new file mode 100644
index 000000000..b32efda93
--- /dev/null
+++ b/test/packetimpact/testbench/layers_test.go
@@ -0,0 +1,156 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestLayerMatch(t *testing.T) {
+ var nilPayload *Payload
+ noPayload := &Payload{}
+ emptyPayload := &Payload{Bytes: []byte{}}
+ fullPayload := &Payload{Bytes: []byte{1, 2, 3}}
+ emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}}
+ fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}}
+ for _, tt := range []struct {
+ a, b Layer
+ want bool
+ }{
+ {nilPayload, nilPayload, true},
+ {nilPayload, noPayload, true},
+ {nilPayload, emptyPayload, true},
+ {nilPayload, fullPayload, true},
+ {noPayload, noPayload, true},
+ {noPayload, emptyPayload, true},
+ {noPayload, fullPayload, true},
+ {emptyPayload, emptyPayload, true},
+ {emptyPayload, fullPayload, false},
+ {fullPayload, fullPayload, true},
+ {emptyTCP, fullTCP, true},
+ } {
+ if got := tt.a.match(tt.b); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want)
+ }
+ if got := tt.b.match(tt.a); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want)
+ }
+ }
+}
+
+func TestLayerStringFormat(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ l Layer
+ want string
+ }{
+ {
+ name: "TCP",
+ l: &TCP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ SeqNum: Uint32(3452155723),
+ AckNum: Uint32(2596996163),
+ DataOffset: Uint8(5),
+ Flags: Uint8(20),
+ WindowSize: Uint16(64240),
+ Checksum: Uint16(0x2e2b),
+ },
+ want: "&testbench.TCP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "SeqNum:3452155723 " +
+ "AckNum:2596996163 " +
+ "DataOffset:5 " +
+ "Flags:20 " +
+ "WindowSize:64240 " +
+ "Checksum:11819" +
+ "}",
+ },
+ {
+ name: "UDP",
+ l: &UDP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ Length: Uint16(12),
+ },
+ want: "&testbench.UDP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "Length:12" +
+ "}",
+ },
+ {
+ name: "IPv4",
+ l: &IPv4{
+ IHL: Uint8(5),
+ TOS: Uint8(0),
+ TotalLength: Uint16(44),
+ ID: Uint16(0),
+ Flags: Uint8(2),
+ FragmentOffset: Uint16(0),
+ TTL: Uint8(64),
+ Protocol: Uint8(6),
+ Checksum: Uint16(0x2e2b),
+ SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})),
+ DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})),
+ },
+ want: "&testbench.IPv4{" +
+ "IHL:5 " +
+ "TOS:0 " +
+ "TotalLength:44 " +
+ "ID:0 " +
+ "Flags:2 " +
+ "FragmentOffset:0 " +
+ "TTL:64 " +
+ "Protocol:6 " +
+ "Checksum:11819 " +
+ "SrcAddr:197.34.63.10 " +
+ "DstAddr:197.34.63.20" +
+ "}",
+ },
+ {
+ name: "Ether",
+ l: &Ether{
+ SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})),
+ DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})),
+ Type: NetworkProtocolNumber(4),
+ },
+ want: "&testbench.Ether{" +
+ "SrcAddr:02:42:c5:22:3f:0a " +
+ "DstAddr:02:42:c5:22:3f:14 " +
+ "Type:4" +
+ "}",
+ },
+ {
+ name: "Payload",
+ l: &Payload{
+ Bytes: []byte("Hooray for packetimpact."),
+ },
+ want: "&testbench.Payload{Bytes:\n" +
+ "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" +
+ "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" +
+ "}",
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.l.String(); got != tt.want {
+ t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
new file mode 100644
index 000000000..ff722d4a6
--- /dev/null
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -0,0 +1,183 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "encoding/binary"
+ "flag"
+ "fmt"
+ "math"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var device = flag.String("device", "", "local device for test packets")
+
+// Sniffer can sniff raw packets on the wire.
+type Sniffer struct {
+ t *testing.T
+ fd int
+}
+
+func htons(x uint16) uint16 {
+ buf := [2]byte{}
+ binary.BigEndian.PutUint16(buf[:], x)
+ return usermem.ByteOrder.Uint16(buf[:])
+}
+
+// NewSniffer creates a Sniffer connected to *device.
+func NewSniffer(t *testing.T) (Sniffer, error) {
+ flag.Parse()
+ snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
+ if err != nil {
+ return Sniffer{}, err
+ }
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil {
+ t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err)
+ }
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1e7); err != nil {
+ t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err)
+ }
+ return Sniffer{
+ t: t,
+ fd: snifferFd,
+ }, nil
+}
+
+// maxReadSize should be large enough for the maximum frame size in bytes. If a
+// packet too large for the buffer arrives, the test will get a fatal error.
+const maxReadSize int = 65536
+
+// Recv tries to read one frame until the timeout is up.
+func (s *Sniffer) Recv(timeout time.Duration) []byte {
+ deadline := time.Now().Add(timeout)
+ for {
+ timeout = deadline.Sub(time.Now())
+ if timeout <= 0 {
+ return nil
+ }
+ whole, frac := math.Modf(timeout.Seconds())
+ tv := unix.Timeval{
+ Sec: int64(whole),
+ Usec: int64(frac * float64(time.Microsecond/time.Second)),
+ }
+
+ if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
+ s.t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err)
+ }
+
+ buf := make([]byte, maxReadSize)
+ nread, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC)
+ if err == unix.EINTR || err == unix.EAGAIN {
+ // There was a timeout.
+ continue
+ }
+ if err != nil {
+ s.t.Fatalf("can't read: %s", err)
+ }
+ if nread > maxReadSize {
+ s.t.Fatalf("received a truncated frame of %d bytes", nread)
+ }
+ return buf[:nread]
+ }
+}
+
+// Drain drains the Sniffer's socket receive buffer by receiving until there's
+// nothing else to receive.
+func (s *Sniffer) Drain() {
+ s.t.Helper()
+ flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0)
+ if err != nil {
+ s.t.Fatalf("failed to get sniffer socket fd flags: %s", err)
+ }
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil {
+ s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err)
+ }
+ for {
+ buf := make([]byte, maxReadSize)
+ _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC)
+ if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK {
+ break
+ }
+ }
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil {
+ s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err)
+ }
+}
+
+// close the socket that Sniffer is using.
+func (s *Sniffer) close() error {
+ if err := unix.Close(s.fd); err != nil {
+ return fmt.Errorf("can't close sniffer socket: %w", err)
+ }
+ s.fd = -1
+ return nil
+}
+
+// Injector can inject raw frames.
+type Injector struct {
+ t *testing.T
+ fd int
+}
+
+// NewInjector creates a new injector on *device.
+func NewInjector(t *testing.T) (Injector, error) {
+ flag.Parse()
+ ifInfo, err := net.InterfaceByName(*device)
+ if err != nil {
+ return Injector{}, err
+ }
+
+ var haddr [8]byte
+ copy(haddr[:], ifInfo.HardwareAddr)
+ sa := unix.SockaddrLinklayer{
+ Protocol: unix.ETH_P_IP,
+ Ifindex: ifInfo.Index,
+ Halen: uint8(len(ifInfo.HardwareAddr)),
+ Addr: haddr,
+ }
+
+ injectFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
+ if err != nil {
+ return Injector{}, err
+ }
+ if err := unix.Bind(injectFd, &sa); err != nil {
+ return Injector{}, err
+ }
+ return Injector{
+ t: t,
+ fd: injectFd,
+ }, nil
+}
+
+// Send a raw frame.
+func (i *Injector) Send(b []byte) {
+ if _, err := unix.Write(i.fd, b); err != nil {
+ i.t.Fatalf("can't write: %s", err)
+ }
+}
+
+// close the underlying socket.
+func (i *Injector) close() error {
+ if err := unix.Close(i.fd); err != nil {
+ return fmt.Errorf("can't close sniffer socket: %w", err)
+ }
+ i.fd = -1
+ return nil
+}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
new file mode 100644
index 000000000..690cee140
--- /dev/null
+++ b/test/packetimpact/tests/BUILD
@@ -0,0 +1,94 @@
+load("defs.bzl", "packetimpact_go_test")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+packetimpact_go_test(
+ name = "fin_wait2_timeout",
+ srcs = ["fin_wait2_timeout_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "udp_recv_multicast",
+ srcs = ["udp_recv_multicast_test.go"],
+ # TODO(b/152813495): Fix netstack then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_window_shrink",
+ srcs = ["tcp_window_shrink_test.go"],
+ # TODO(b/153202472): Fix netstack then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_outside_the_window",
+ srcs = ["tcp_outside_the_window_test.go"],
+ # TODO(eyalsoha): Fix #1607 then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_noaccept_close_rst",
+ srcs = ["tcp_noaccept_close_rst_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_should_piggyback",
+ srcs = ["tcp_should_piggyback_test.go"],
+ # TODO(b/153680566): Fix netstack then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_close_wait_ack",
+ srcs = ["tcp_close_wait_ack_test.go"],
+ # TODO(b/153574037): Fix netstack then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+sh_binary(
+ name = "test_runner",
+ srcs = ["test_runner.sh"],
+)
diff --git a/test/packetimpact/tests/Dockerfile b/test/packetimpact/tests/Dockerfile
new file mode 100644
index 000000000..9075bc555
--- /dev/null
+++ b/test/packetimpact/tests/Dockerfile
@@ -0,0 +1,17 @@
+FROM ubuntu:bionic
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+ # iptables to disable OS native packet processing.
+ iptables \
+ # nc to check that the posix_server is running.
+ netcat \
+ # tcpdump to log brief packet sniffing.
+ tcpdump \
+ # ip link show to display MAC addresses.
+ iproute2 \
+ # tshark to log verbose packet sniffing.
+ tshark \
+ # killall for cleanup.
+ psmisc
+RUN hash -r
+CMD /bin/bash
diff --git a/test/packetimpact/tests/defs.bzl b/test/packetimpact/tests/defs.bzl
new file mode 100644
index 000000000..8c0d058b2
--- /dev/null
+++ b/test/packetimpact/tests/defs.bzl
@@ -0,0 +1,118 @@
+"""Defines rules for packetimpact test targets."""
+
+load("//tools:defs.bzl", "go_test")
+
+def _packetimpact_test_impl(ctx):
+ test_runner = ctx.executable._test_runner
+ bench = ctx.actions.declare_file("%s-bench" % ctx.label.name)
+ bench_content = "\n".join([
+ "#!/bin/bash",
+ # This test will run part in a distinct user namespace. This can cause
+ # permission problems, because all runfiles may not be owned by the
+ # current user, and no other users will be mapped in that namespace.
+ # Make sure that everything is readable here.
+ "find . -type f -exec chmod a+rx {} \\;",
+ "find . -type d -exec chmod a+rx {} \\;",
+ "%s %s --posix_server_binary %s --testbench_binary %s $@\n" % (
+ test_runner.short_path,
+ " ".join(ctx.attr.flags),
+ ctx.files._posix_server_binary[0].short_path,
+ ctx.files.testbench_binary[0].short_path,
+ ),
+ ])
+ ctx.actions.write(bench, bench_content, is_executable = True)
+
+ transitive_files = depset()
+ if hasattr(ctx.attr._test_runner, "data_runfiles"):
+ transitive_files = depset(ctx.attr._test_runner.data_runfiles.files)
+ runfiles = ctx.runfiles(
+ files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary,
+ transitive_files = transitive_files,
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = bench, runfiles = runfiles)]
+
+_packetimpact_test = rule(
+ attrs = {
+ "_test_runner": attr.label(
+ executable = True,
+ cfg = "target",
+ default = ":test_runner",
+ ),
+ "_posix_server_binary": attr.label(
+ cfg = "target",
+ default = "//test/packetimpact/dut:posix_server",
+ ),
+ "testbench_binary": attr.label(
+ cfg = "target",
+ mandatory = True,
+ ),
+ "flags": attr.string_list(
+ mandatory = False,
+ default = [],
+ ),
+ },
+ test = True,
+ implementation = _packetimpact_test_impl,
+)
+
+PACKETIMPACT_TAGS = ["local", "manual"]
+
+def packetimpact_linux_test(name, testbench_binary, **kwargs):
+ """Add a packetimpact test on linux.
+
+ Args:
+ name: name of the test
+ testbench_binary: the testbench binary
+ **kwargs: all the other args, forwarded to _packetimpact_test
+ """
+ _packetimpact_test(
+ name = name + "_linux_test",
+ testbench_binary = testbench_binary,
+ flags = ["--dut_platform", "linux"],
+ tags = PACKETIMPACT_TAGS + ["packetimpact"],
+ **kwargs
+ )
+
+def packetimpact_netstack_test(name, testbench_binary, **kwargs):
+ """Add a packetimpact test on netstack.
+
+ Args:
+ name: name of the test
+ testbench_binary: the testbench binary
+ **kwargs: all the other args, forwarded to _packetimpact_test
+ """
+ _packetimpact_test(
+ name = name + "_netstack_test",
+ testbench_binary = testbench_binary,
+ # This is the default runtime unless
+ # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value.
+ flags = ["--dut_platform", "netstack", "--runtime=runsc-d"],
+ tags = PACKETIMPACT_TAGS + ["packetimpact"],
+ **kwargs
+ )
+
+def packetimpact_go_test(name, size = "small", pure = True, linux = True, netstack = True, **kwargs):
+ """Add packetimpact tests written in go.
+
+ Args:
+ name: name of the test
+ size: size of the test
+ pure: make a static go binary
+ linux: generate a linux test
+ netstack: generate a netstack test
+ **kwargs: all the other args, forwarded to go_test
+ """
+ testbench_binary = name + "_test"
+ go_test(
+ name = testbench_binary,
+ size = size,
+ pure = pure,
+ tags = PACKETIMPACT_TAGS,
+ **kwargs
+ )
+ if linux:
+ packetimpact_linux_test(name = name, testbench_binary = testbench_binary)
+ if netstack:
+ packetimpact_netstack_test(name = name, testbench_binary = testbench_binary)
diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go
new file mode 100644
index 000000000..b98594f94
--- /dev/null
+++ b/test/packetimpact/tests/fin_wait2_timeout_test.go
@@ -0,0 +1,70 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fin_wait2_timeout_test
+
+import (
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func TestFinWait2Timeout(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ linger2 bool
+ }{
+ {"WithLinger2", true},
+ {"WithoutLinger2", false},
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+ conn.Handshake()
+
+ acceptFd, _ := dut.Accept(listenFd)
+ if tt.linger2 {
+ tv := unix.Timeval{Sec: 1, Usec: 0}
+ dut.SetSockOptTimeval(acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv)
+ }
+ dut.Close(acceptFd)
+
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err)
+ }
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+
+ time.Sleep(5 * time.Second)
+ conn.Drain()
+
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ if tt.linger2 {
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ t.Fatalf("expected a RST packet within a second but got none: %s", err)
+ }
+ } else {
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); err == nil {
+ t.Fatalf("expected no RST packets within ten seconds but got one: %s", err)
+ }
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go
new file mode 100644
index 000000000..eb4cc7a65
--- /dev/null
+++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_close_wait_ack_test
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func TestCloseWaitAck(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ makeTestingTCP func(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP
+ seqNumOffset seqnum.Size
+ expectAck bool
+ }{
+ {"OTW", GenerateOTWSeqSegment, 0, false},
+ {"OTW", GenerateOTWSeqSegment, 1, true},
+ {"OTW", GenerateOTWSeqSegment, 2, true},
+ {"ACK", GenerateUnaccACKSegment, 0, false},
+ {"ACK", GenerateUnaccACKSegment, 1, true},
+ {"ACK", GenerateUnaccACKSegment, 2, true},
+ } {
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Handshake()
+ acceptFd, _ := dut.Accept(listenFd)
+
+ // Send a FIN to DUT to intiate the active close
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err)
+ }
+
+ // Send a segment with OTW Seq / unacc ACK and expect an ACK back
+ conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset), &tb.Payload{Bytes: []byte("Sample Data")})
+ gotAck, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second)
+ if tt.expectAck && err != nil {
+ t.Fatalf("expected an ack but got none: %s", err)
+ }
+ if !tt.expectAck && gotAck != nil {
+ t.Fatalf("expected no ack but got one: %s", gotAck)
+ }
+
+ // Now let's verify DUT is indeed in CLOSE_WAIT
+ dut.Close(acceptFd)
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
+ t.Fatalf("expected DUT to send a FIN: %s", err)
+ }
+ // Ack the FIN from DUT
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ // Send some extra data to DUT
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, &tb.Payload{Bytes: []byte("Sample Data")})
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ t.Fatalf("expected DUT to send an RST: %s", err)
+ }
+ })
+ }
+}
+
+// This generates an segment with seqnum = RCV.NXT + RCV.WND + seqNumOffset, the
+// generated segment is only acceptable when seqNumOffset is 0, otherwise an ACK
+// is expected from the receiver.
+func GenerateOTWSeqSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP {
+ windowSize := seqnum.Size(*conn.SynAck().WindowSize)
+ lastAcceptable := conn.LocalSeqNum().Add(windowSize - 1)
+ otwSeq := uint32(lastAcceptable.Add(seqNumOffset))
+ return tb.TCP{SeqNum: tb.Uint32(otwSeq), Flags: tb.Uint8(header.TCPFlagAck)}
+}
+
+// This generates an segment with acknum = SND.NXT + seqNumOffset, the generated
+// segment is only acceptable when seqNumOffset is 0, otherwise an ACK is
+// expected from the receiver.
+func GenerateUnaccACKSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP {
+ lastAcceptable := conn.RemoteSeqNum()
+ unaccAck := uint32(lastAcceptable.Add(seqNumOffset))
+ return tb.TCP{AckNum: tb.Uint32(unaccAck), Flags: tb.Uint8(header.TCPFlagAck)}
+}
diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
new file mode 100644
index 000000000..7ebdd1950
--- /dev/null
+++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_noaccept_close_rst_test
+
+import (
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func TestTcpNoAcceptCloseReset(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn.Handshake()
+ defer conn.Close()
+ dut.Close(listenFd)
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
+ t.Fatalf("expected a RST-ACK packet but got none: %s", err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go
new file mode 100644
index 000000000..db3d3273b
--- /dev/null
+++ b/test/packetimpact/tests/tcp_outside_the_window_test.go
@@ -0,0 +1,88 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_outside_the_window_test
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+// TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive
+// that are inside or outside the TCP window. Packets that are outside the
+// window should force an extra ACK, as described in RFC793 page 69:
+// https://tools.ietf.org/html/rfc793#page-69
+func TestTCPOutsideTheWindow(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ tcpFlags uint8
+ payload []tb.Layer
+ seqNumOffset seqnum.Size
+ expectACK bool
+ }{
+ {"SYN", header.TCPFlagSyn, nil, 0, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true},
+ {"ACK", header.TCPFlagAck, nil, 0, false},
+ {"FIN", header.TCPFlagFin, nil, 0, false},
+ {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 0, true},
+
+ {"SYN", header.TCPFlagSyn, nil, 1, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true},
+ {"ACK", header.TCPFlagAck, nil, 1, true},
+ {"FIN", header.TCPFlagFin, nil, 1, false},
+ {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 1, true},
+
+ {"SYN", header.TCPFlagSyn, nil, 2, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true},
+ {"ACK", header.TCPFlagAck, nil, 2, true},
+ {"FIN", header.TCPFlagFin, nil, 2, false},
+ {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 2, true},
+ } {
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+ conn.Handshake()
+ acceptFD, _ := dut.Accept(listenFD)
+ defer dut.Close(acceptFD)
+
+ windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset
+ conn.Drain()
+ // Ignore whatever incrementing that this out-of-order packet might cause
+ // to the AckNum.
+ localSeqNum := tb.Uint32(uint32(*conn.LocalSeqNum()))
+ conn.Send(tb.TCP{
+ Flags: tb.Uint8(tt.tcpFlags),
+ SeqNum: tb.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))),
+ }, tt.payload...)
+ timeout := 3 * time.Second
+ gotACK, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
+ if tt.expectACK && err != nil {
+ t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err)
+ }
+ if !tt.expectACK && gotACK != nil {
+ t.Fatalf("expected no ACK packet within %s but got one: %s", timeout, gotACK)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_should_piggyback_test.go b/test/packetimpact/tests/tcp_should_piggyback_test.go
new file mode 100644
index 000000000..b0be6ba23
--- /dev/null
+++ b/test/packetimpact/tests/tcp_should_piggyback_test.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_should_piggyback_test
+
+import (
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func TestPiggyback(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort, WindowSize: tb.Uint16(12)}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Handshake()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+
+ dut.Send(acceptFd, sampleData, 0)
+ expectedTCP := tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}
+ expectedPayload := tb.Payload{Bytes: sampleData}
+ if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err)
+ }
+
+ // Cause DUT to send us more data as soon as we ACK their first data segment because we have
+ // a small window.
+ dut.Send(acceptFd, sampleData, 0)
+
+ // DUT should ACK our segment by piggybacking ACK to their outstanding data segment instead of
+ // sending a separate ACK packet.
+ conn.Send(expectedTCP, &expectedPayload)
+ if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go
new file mode 100644
index 000000000..c9354074e
--- /dev/null
+++ b/test/packetimpact/tests/tcp_window_shrink_test.go
@@ -0,0 +1,68 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_window_shrink_test
+
+import (
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func TestWindowShrink(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Handshake()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &tb.Payload{Bytes: sampleData}
+
+ dut.Send(acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
+ }
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+
+ dut.Send(acceptFd, sampleData, 0)
+ dut.Send(acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
+ }
+ if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
+ }
+ // We close our receiving window here
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)})
+
+ dut.Send(acceptFd, []byte("Sample Data"), 0)
+ // Note: There is another kind of zero-window probing which Windows uses (by sending one
+ // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change
+ // the following lines.
+ expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1
+ if _, err := conn.ExpectData(&tb.TCP{SeqNum: tb.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil {
+ t.Fatalf("expected a packet with sequence number %v: %s", expectedRemoteSeqNum, err)
+ }
+}
diff --git a/test/packetimpact/tests/test_runner.sh b/test/packetimpact/tests/test_runner.sh
new file mode 100755
index 000000000..e99fc7d09
--- /dev/null
+++ b/test/packetimpact/tests/test_runner.sh
@@ -0,0 +1,268 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Run a packetimpact test. Two docker containers are made, one for the
+# Device-Under-Test (DUT) and one for the test bench. Each is attached with
+# two networks, one for control packets that aid the test and one for test
+# packets which are sent as part of the test and observed for correctness.
+
+set -euxo pipefail
+
+function failure() {
+ local lineno=$1
+ local msg=$2
+ local filename="$0"
+ echo "FAIL: $filename:$lineno: $msg"
+}
+trap 'failure ${LINENO} "$BASH_COMMAND"' ERR
+
+declare -r LONGOPTS="dut_platform:,posix_server_binary:,testbench_binary:,runtime:,tshark,extra_test_arg:"
+
+# Don't use declare below so that the error from getopt will end the script.
+PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@")
+
+eval set -- "$PARSED"
+
+declare -a EXTRA_TEST_ARGS
+
+while true; do
+ case "$1" in
+ --dut_platform)
+ # Either "linux" or "netstack".
+ declare -r DUT_PLATFORM="$2"
+ shift 2
+ ;;
+ --posix_server_binary)
+ declare -r POSIX_SERVER_BINARY="$2"
+ shift 2
+ ;;
+ --testbench_binary)
+ declare -r TESTBENCH_BINARY="$2"
+ shift 2
+ ;;
+ --runtime)
+ # Not readonly because there might be multiple --runtime arguments and we
+ # want to use just the last one. Only used if --dut_platform is
+ # "netstack".
+ declare RUNTIME="$2"
+ shift 2
+ ;;
+ --tshark)
+ declare -r TSHARK="1"
+ shift 1
+ ;;
+ --extra_test_arg)
+ EXTRA_TEST_ARGS+="$2"
+ shift 2
+ ;;
+ --)
+ shift
+ break
+ ;;
+ *)
+ echo "Programming error"
+ exit 3
+ esac
+done
+
+# All the other arguments are scripts.
+declare -r scripts="$@"
+
+# Check that the required flags are defined in a way that is safe for "set -u".
+if [[ "${DUT_PLATFORM-}" == "netstack" ]]; then
+ if [[ -z "${RUNTIME-}" ]]; then
+ echo "FAIL: Missing --runtime argument: ${RUNTIME-}"
+ exit 2
+ fi
+ declare -r RUNTIME_ARG="--runtime ${RUNTIME}"
+elif [[ "${DUT_PLATFORM-}" == "linux" ]]; then
+ declare -r RUNTIME_ARG=""
+else
+ echo "FAIL: Bad or missing --dut_platform argument: ${DUT_PLATFORM-}"
+ exit 2
+fi
+if [[ ! -f "${POSIX_SERVER_BINARY-}" ]]; then
+ echo "FAIL: Bad or missing --posix_server_binary: ${POSIX_SERVER-}"
+ exit 2
+fi
+if [[ ! -f "${TESTBENCH_BINARY-}" ]]; then
+ echo "FAIL: Bad or missing --testbench_binary: ${TESTBENCH_BINARY-}"
+ exit 2
+fi
+
+# Variables specific to the control network and interface start with CTRL_.
+# Variables specific to the test network and interface start with TEST_.
+# Variables specific to the DUT start with DUT_.
+# Variables specific to the test bench start with TESTBENCH_.
+# Use random numbers so that test networks don't collide.
+declare -r CTRL_NET="ctrl_net-${RANDOM}${RANDOM}"
+declare -r TEST_NET="test_net-${RANDOM}${RANDOM}"
+# On both DUT and test bench, testing packets are on the eth2 interface.
+declare -r TEST_DEVICE="eth2"
+# Number of bits in the *_NET_PREFIX variables.
+declare -r NET_MASK="24"
+function new_net_prefix() {
+ # Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24.
+ echo "$(shuf -i 192-223 -n 1).$(shuf -i 0-255 -n 1).$(shuf -i 0-255 -n 1)"
+}
+# Last bits of the DUT's IP address.
+declare -r DUT_NET_SUFFIX=".10"
+# Control port.
+declare -r CTRL_PORT="40000"
+# Last bits of the test bench's IP address.
+declare -r TESTBENCH_NET_SUFFIX=".20"
+declare -r TIMEOUT="60"
+declare -r IMAGE_TAG="gcr.io/gvisor-presubmit/packetimpact"
+# Make sure that docker is installed.
+docker --version
+
+function finish {
+ local cleanup_success=1
+
+ if [[ -z "${TSHARK-}" ]]; then
+ # Kill tcpdump so that it will flush output.
+ docker exec -t "${TESTBENCH}" \
+ killall tcpdump || \
+ cleanup_success=0
+ else
+ # Kill tshark so that it will flush output.
+ docker exec -t "${TESTBENCH}" \
+ killall tshark || \
+ cleanup_success=0
+ fi
+
+ for net in "${CTRL_NET}" "${TEST_NET}"; do
+ # Kill all processes attached to ${net}.
+ for docker_command in "kill" "rm"; do
+ (docker network inspect "${net}" \
+ --format '{{range $key, $value := .Containers}}{{$key}} {{end}}' \
+ | xargs -r docker "${docker_command}") || \
+ cleanup_success=0
+ done
+ # Remove the network.
+ docker network rm "${net}" || \
+ cleanup_success=0
+ done
+
+ if ((!$cleanup_success)); then
+ echo "FAIL: Cleanup command failed"
+ exit 4
+ fi
+}
+trap finish EXIT
+
+# Subnet for control packets between test bench and DUT.
+declare CTRL_NET_PREFIX=$(new_net_prefix)
+while ! docker network create \
+ "--subnet=${CTRL_NET_PREFIX}.0/${NET_MASK}" "${CTRL_NET}"; do
+ sleep 0.1
+ declare CTRL_NET_PREFIX=$(new_net_prefix)
+done
+
+# Subnet for the packets that are part of the test.
+declare TEST_NET_PREFIX=$(new_net_prefix)
+while ! docker network create \
+ "--subnet=${TEST_NET_PREFIX}.0/${NET_MASK}" "${TEST_NET}"; do
+ sleep 0.1
+ declare TEST_NET_PREFIX=$(new_net_prefix)
+done
+
+docker pull "${IMAGE_TAG}"
+
+# Create the DUT container and connect to network.
+DUT=$(docker create ${RUNTIME_ARG} --privileged --rm \
+ --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
+docker network connect "${CTRL_NET}" \
+ --ip "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
+ || (docker kill ${DUT}; docker rm ${DUT}; false)
+docker network connect "${TEST_NET}" \
+ --ip "${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
+ || (docker kill ${DUT}; docker rm ${DUT}; false)
+docker start "${DUT}"
+
+# Create the test bench container and connect to network.
+TESTBENCH=$(docker create --privileged --rm \
+ --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
+docker network connect "${CTRL_NET}" \
+ --ip "${CTRL_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" "${TESTBENCH}" \
+ || (docker kill ${TESTBENCH}; docker rm ${TESTBENCH}; false)
+docker network connect "${TEST_NET}" \
+ --ip "${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" "${TESTBENCH}" \
+ || (docker kill ${TESTBENCH}; docker rm ${TESTBENCH}; false)
+docker start "${TESTBENCH}"
+
+# Start the posix_server in the DUT.
+declare -r DOCKER_POSIX_SERVER_BINARY="/$(basename ${POSIX_SERVER_BINARY})"
+docker cp -L ${POSIX_SERVER_BINARY} "${DUT}:${DOCKER_POSIX_SERVER_BINARY}"
+
+docker exec -t "${DUT}" \
+ /bin/bash -c "${DOCKER_POSIX_SERVER_BINARY} \
+ --ip ${CTRL_NET_PREFIX}${DUT_NET_SUFFIX} \
+ --port ${CTRL_PORT}" &
+
+# Because the Linux kernel receives the SYN-ACK but didn't send the SYN it will
+# issue a RST. To prevent this IPtables can be used to filter those out.
+docker exec "${TESTBENCH}" \
+ iptables -A INPUT -i ${TEST_DEVICE} -j DROP
+
+# Wait for the DUT server to come up. Attempt to connect to it from the test
+# bench every 100 milliseconds until success.
+while ! docker exec "${TESTBENCH}" \
+ nc -zv "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${CTRL_PORT}"; do
+ sleep 0.1
+done
+
+declare -r REMOTE_MAC=$(docker exec -t "${DUT}" ip link show \
+ "${TEST_DEVICE}" | tail -1 | cut -d' ' -f6)
+declare -r LOCAL_MAC=$(docker exec -t "${TESTBENCH}" ip link show \
+ "${TEST_DEVICE}" | tail -1 | cut -d' ' -f6)
+
+declare -r DOCKER_TESTBENCH_BINARY="/$(basename ${TESTBENCH_BINARY})"
+docker cp -L "${TESTBENCH_BINARY}" "${TESTBENCH}:${DOCKER_TESTBENCH_BINARY}"
+
+if [[ -z "${TSHARK-}" ]]; then
+ # Run tcpdump in the test bench unbuffered, without dns resolution, just on
+ # the interface with the test packets.
+ docker exec -t "${TESTBENCH}" \
+ tcpdump -S -vvv -U -n -i "${TEST_DEVICE}" net "${TEST_NET_PREFIX}/24" &
+else
+ # Run tshark in the test bench unbuffered, without dns resolution, just on the
+ # interface with the test packets.
+ docker exec -t "${TESTBENCH}" \
+ tshark -V -l -n -i "${TEST_DEVICE}" \
+ -o tcp.check_checksum:TRUE \
+ -o udp.check_checksum:TRUE \
+ host "${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" &
+fi
+
+# tcpdump and tshark take time to startup
+sleep 3
+
+# Start a packetimpact test on the test bench. The packetimpact test sends and
+# receives packets and also sends POSIX socket commands to the posix_server to
+# be executed on the DUT.
+docker exec -t "${TESTBENCH}" \
+ /bin/bash -c "${DOCKER_TESTBENCH_BINARY} \
+ ${EXTRA_TEST_ARGS[@]-} \
+ --posix_server_ip=${CTRL_NET_PREFIX}${DUT_NET_SUFFIX} \
+ --posix_server_port=${CTRL_PORT} \
+ --remote_ipv4=${TEST_NET_PREFIX}${DUT_NET_SUFFIX} \
+ --local_ipv4=${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX} \
+ --remote_mac=${REMOTE_MAC} \
+ --local_mac=${LOCAL_MAC} \
+ --device=${TEST_DEVICE}"
+
+echo PASS: No errors.
diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go
new file mode 100644
index 000000000..61fd17050
--- /dev/null
+++ b/test/packetimpact/tests/udp_recv_multicast_test.go
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_recv_multicast_test
+
+import (
+ "net"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func TestUDPRecvMulticast(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(boundFD)
+ conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
+ defer conn.Close()
+ frame := conn.CreateFrame(&tb.UDP{}, &tb.Payload{Bytes: []byte("hello world")})
+ frame[1].(*tb.IPv4).DstAddr = tb.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4()))
+ conn.SendFrame(frame)
+ dut.Recv(boundFD, 100, 0)
+}
diff --git a/test/perf/BUILD b/test/perf/BUILD
index 346a28e16..471d8c2ab 100644
--- a/test/perf/BUILD
+++ b/test/perf/BUILD
@@ -30,6 +30,8 @@ syscall_test(
syscall_test(
size = "enormous",
+ shard_count = 10,
+ tags = ["nogotsan"],
test = "//test/perf/linux:getdents_benchmark",
)
@@ -40,6 +42,7 @@ syscall_test(
syscall_test(
size = "enormous",
+ tags = ["nogotsan"],
test = "//test/perf/linux:gettid_benchmark",
)
diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc
index b349d50bf..241f39896 100644
--- a/test/perf/linux/futex_benchmark.cc
+++ b/test/perf/linux/futex_benchmark.cc
@@ -33,24 +33,24 @@ namespace testing {
namespace {
inline int FutexWait(std::atomic<int32_t>* v, int32_t val) {
- return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, nullptr);
+ return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, nullptr);
}
-inline int FutexWaitRelativeTimeout(std::atomic<int32_t>* v, int32_t val,
- const struct timespec* reltime) {
- return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, reltime);
+inline int FutexWaitMonotonicTimeout(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* timeout) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, timeout);
}
-inline int FutexWaitAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val,
- const struct timespec* abstime) {
- return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, abstime);
+inline int FutexWaitMonotonicDeadline(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* deadline) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE, val, deadline,
+ nullptr, FUTEX_BITSET_MATCH_ANY);
}
-inline int FutexWaitBitsetAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val,
- int32_t bits,
- const struct timespec* abstime) {
+inline int FutexWaitRealtimeDeadline(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* deadline) {
return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME,
- val, abstime, nullptr, bits);
+ val, deadline, nullptr, FUTEX_BITSET_MATCH_ANY);
}
inline int FutexWake(std::atomic<int32_t>* v, int32_t count) {
@@ -62,11 +62,11 @@ void BM_FutexWakeNop(benchmark::State& state) {
std::atomic<int32_t> v(0);
for (auto _ : state) {
- EXPECT_EQ(0, FutexWake(&v, 1));
+ TEST_PCHECK(FutexWake(&v, 1) == 0);
}
}
-BENCHMARK(BM_FutexWakeNop);
+BENCHMARK(BM_FutexWakeNop)->MinTime(5);
// This just uses FUTEX_WAIT on an address whose value has changed, i.e., the
// syscall won't wait.
@@ -74,43 +74,63 @@ void BM_FutexWaitNop(benchmark::State& state) {
std::atomic<int32_t> v(0);
for (auto _ : state) {
- EXPECT_EQ(-EAGAIN, FutexWait(&v, 1));
+ TEST_PCHECK(FutexWait(&v, 1) == -1 && errno == EAGAIN);
}
}
-BENCHMARK(BM_FutexWaitNop);
+BENCHMARK(BM_FutexWaitNop)->MinTime(5);
// This uses FUTEX_WAIT with a timeout on an address whose value never
// changes, such that it always times out. Timeout overhead can be estimated by
// timer overruns for short timeouts.
-void BM_FutexWaitTimeout(benchmark::State& state) {
+void BM_FutexWaitMonotonicTimeout(benchmark::State& state) {
const int timeout_ns = state.range(0);
std::atomic<int32_t> v(0);
auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns));
for (auto _ : state) {
- EXPECT_EQ(-ETIMEDOUT, FutexWaitRelativeTimeout(&v, 0, &ts));
+ TEST_PCHECK(FutexWaitMonotonicTimeout(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
}
}
-BENCHMARK(BM_FutexWaitTimeout)
+BENCHMARK(BM_FutexWaitMonotonicTimeout)
+ ->MinTime(5)
+ ->UseRealTime()
->Arg(1)
->Arg(10)
->Arg(100)
->Arg(1000)
->Arg(10000);
-// This calls FUTEX_WAIT_BITSET with CLOCK_REALTIME.
-void BM_FutexWaitBitset(benchmark::State& state) {
+// This uses FUTEX_WAIT_BITSET with a deadline that is in the past. This allows
+// estimation of the overhead of setting up a timer for a deadline (as opposed
+// to a timeout as specified for FUTEX_WAIT).
+void BM_FutexWaitMonotonicDeadline(benchmark::State& state) {
std::atomic<int32_t> v(0);
- int timeout_ns = state.range(0);
- auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns));
+ struct timespec ts = {};
+
for (auto _ : state) {
- EXPECT_EQ(-ETIMEDOUT, FutexWaitBitsetAbsoluteTimeout(&v, 0, 1, &ts));
+ TEST_PCHECK(FutexWaitMonotonicDeadline(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
}
}
-BENCHMARK(BM_FutexWaitBitset)->Range(0, 100000);
+BENCHMARK(BM_FutexWaitMonotonicDeadline)->MinTime(5);
+
+// This is equivalent to BM_FutexWaitMonotonicDeadline, but uses CLOCK_REALTIME
+// instead of CLOCK_MONOTONIC for the deadline.
+void BM_FutexWaitRealtimeDeadline(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+ struct timespec ts = {};
+
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWaitRealtimeDeadline(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
+ }
+}
+
+BENCHMARK(BM_FutexWaitRealtimeDeadline)->MinTime(5);
int64_t GetCurrentMonotonicTimeNanos() {
struct timespec ts;
@@ -130,11 +150,10 @@ void SpinNanos(int64_t delay_ns) {
// Each iteration of FutexRoundtripDelayed involves a thread sending a futex
// wakeup to another thread, which spins for delay_us and then sends a futex
-// wakeup back. The time per iteration is 2* (delay_us + kBeforeWakeDelayNs +
+// wakeup back. The time per iteration is 2 * (delay_us + kBeforeWakeDelayNs +
// futex/scheduling overhead).
void BM_FutexRoundtripDelayed(benchmark::State& state) {
const int delay_us = state.range(0);
-
const int64_t delay_ns = delay_us * 1000;
// Spin for an extra kBeforeWakeDelayNs before invoking FUTEX_WAKE to reduce
// the probability that the wakeup comes before the wait, preventing the wait
@@ -165,83 +184,14 @@ void BM_FutexRoundtripDelayed(benchmark::State& state) {
}
BENCHMARK(BM_FutexRoundtripDelayed)
+ ->MinTime(5)
+ ->UseRealTime()
->Arg(0)
->Arg(10)
->Arg(20)
->Arg(50)
->Arg(100);
-// FutexLock is a simple, dumb futex based lock implementation.
-// It will try to acquire the lock by atomically incrementing the
-// lock word. If it did not increment the lock from 0 to 1, someone
-// else has the lock, so it will FUTEX_WAIT until it is woken in
-// the unlock path.
-class FutexLock {
- public:
- FutexLock() : lock_word_(0) {}
-
- void lock(struct timespec* deadline) {
- int32_t val;
- while ((val = lock_word_.fetch_add(1, std::memory_order_acquire) + 1) !=
- 1) {
- // If we didn't get the lock by incrementing from 0 to 1,
- // do a FUTEX_WAIT with the desired current value set to
- // val. If val is no longer what the atomic increment returned,
- // someone might have set it to 0 so we can try to acquire
- // again.
- int ret = FutexWaitAbsoluteTimeout(&lock_word_, val, deadline);
- if (ret == 0 || ret == -EWOULDBLOCK || ret == -EINTR) {
- continue;
- } else {
- FAIL() << "unexpected FUTEX_WAIT return: " << ret;
- }
- }
- }
-
- void unlock() {
- // Store 0 into the lock word and wake one waiter. We intentionally
- // ignore the return value of the FUTEX_WAKE here, since there may be
- // no waiters to wake anyway.
- lock_word_.store(0, std::memory_order_release);
- (void)FutexWake(&lock_word_, 1);
- }
-
- private:
- std::atomic<int32_t> lock_word_;
-};
-
-FutexLock* test_lock; // Used below.
-
-void FutexContend(benchmark::State& state, int thread_index,
- struct timespec* deadline) {
- int counter = 0;
- if (thread_index == 0) {
- test_lock = new FutexLock();
- }
- for (auto _ : state) {
- test_lock->lock(deadline);
- counter++;
- test_lock->unlock();
- }
- if (thread_index == 0) {
- delete test_lock;
- }
- state.SetItemsProcessed(state.iterations());
-}
-
-void BM_FutexContend(benchmark::State& state) {
- FutexContend(state, state.thread_index, nullptr);
-}
-
-BENCHMARK(BM_FutexContend)->ThreadRange(1, 1024)->UseRealTime();
-
-void BM_FutexDeadlineContend(benchmark::State& state) {
- auto deadline = absl::ToTimespec(absl::Now() + absl::Minutes(10));
- FutexContend(state, state.thread_index, &deadline);
-}
-
-BENCHMARK(BM_FutexDeadlineContend)->ThreadRange(1, 1024)->UseRealTime();
-
} // namespace
} // namespace testing
diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc
index afc599ad2..d8e81fa8c 100644
--- a/test/perf/linux/getdents_benchmark.cc
+++ b/test/perf/linux/getdents_benchmark.cc
@@ -38,7 +38,7 @@ namespace testing {
namespace {
-constexpr int kBufferSize = 16384;
+constexpr int kBufferSize = 65536;
PosixErrorOr<TempPath> CreateDirectory(int count,
std::vector<std::string>* files) {
diff --git a/test/perf/linux/signal_benchmark.cc b/test/perf/linux/signal_benchmark.cc
index a6928df58..cec679191 100644
--- a/test/perf/linux/signal_benchmark.cc
+++ b/test/perf/linux/signal_benchmark.cc
@@ -43,11 +43,13 @@ void BM_FaultSignalFixup(benchmark::State& state) {
// Fault, fault, fault.
for (auto _ : state) {
- register volatile unsigned int* ptr asm("rax");
-
// Trigger the segfault.
- ptr = nullptr;
- *ptr = 0;
+ asm volatile(
+ "movq $0, %%rax\n"
+ "movq $0x77777777, (%%rax)\n"
+ :
+ :
+ : "rax");
}
}
diff --git a/test/root/BUILD b/test/root/BUILD
index 23ce2a70f..ddc9b4955 100644
--- a/test/root/BUILD
+++ b/test/root/BUILD
@@ -16,6 +16,7 @@ go_test(
"crictl_test.go",
"main_test.go",
"oom_score_adj_test.go",
+ "runsc_test.go",
],
data = [
"//runsc",
@@ -37,7 +38,9 @@ go_test(
"//runsc/specutils",
"//runsc/testutil",
"//test/root/testdata",
+ "@com_github_cenkalti_backoff//:go_default_library",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
"@com_github_syndtr_gocapability//capability:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go
index 4038661cb..679342def 100644
--- a/test/root/cgroup_test.go
+++ b/test/root/cgroup_test.go
@@ -53,7 +53,7 @@ func verifyPid(pid int, path string) error {
if scanner.Err() != nil {
return scanner.Err()
}
- return fmt.Errorf("got: %s, want: %d", gots, pid)
+ return fmt.Errorf("got: %v, want: %d", gots, pid)
}
// TestCgroup sets cgroup options and checks that cgroup was properly configured.
@@ -106,7 +106,7 @@ func TestMemCGroup(t *testing.T) {
time.Sleep(100 * time.Millisecond)
}
- t.Fatalf("%vMB is less than %vMB: %v", memUsage>>20, allocMemSize>>20)
+ t.Fatalf("%vMB is less than %vMB", memUsage>>20, allocMemSize>>20)
}
// TestCgroup sets cgroup options and checks that cgroup was properly configured.
diff --git a/test/root/runsc_test.go b/test/root/runsc_test.go
new file mode 100644
index 000000000..90373e2db
--- /dev/null
+++ b/test/root/runsc_test.go
@@ -0,0 +1,151 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/cenkalti/backoff"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/runsc/specutils"
+ "gvisor.dev/gvisor/runsc/testutil"
+)
+
+// TestDoKill checks that when "runsc do..." is killed, the sandbox process is
+// also terminated. This ensures that parent death signal is propagate to the
+// sandbox process correctly.
+func TestDoKill(t *testing.T) {
+ // Make the sandbox process be reparented here when it's killed, so we can
+ // wait for it.
+ if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil {
+ t.Fatalf("prctl(PR_SET_CHILD_SUBREAPER): %v", err)
+ }
+
+ cmd := exec.Command(specutils.ExePath, "do", "sleep", "10000")
+ buf := &bytes.Buffer{}
+ cmd.Stdout = buf
+ cmd.Stderr = buf
+ cmd.Start()
+
+ var pid int
+ findSandbox := func() error {
+ var err error
+ pid, err = sandboxPid(cmd.Process.Pid)
+ if err != nil {
+ return &backoff.PermanentError{Err: err}
+ }
+ if pid == 0 {
+ return fmt.Errorf("sandbox process not found")
+ }
+ return nil
+ }
+ if err := testutil.Poll(findSandbox, 10*time.Second); err != nil {
+ t.Fatalf("failed to find sandbox: %v", err)
+ }
+ t.Logf("Found sandbox, pid: %d", pid)
+
+ if err := cmd.Process.Kill(); err != nil {
+ t.Fatalf("failed to kill run process: %v", err)
+ }
+ cmd.Wait()
+ t.Logf("Parent process killed (%d). Output: %s", cmd.Process.Pid, buf.String())
+
+ ch := make(chan struct{})
+ go func() {
+ defer func() { ch <- struct{}{} }()
+ t.Logf("Waiting for sandbox process (%d) termination", pid)
+ if _, err := unix.Wait4(pid, nil, 0, nil); err != nil {
+ t.Errorf("error waiting for sandbox process (%d): %v", pid, err)
+ }
+ }()
+ select {
+ case <-ch:
+ // Done
+ case <-time.After(5 * time.Second):
+ t.Fatalf("timeout waiting for sandbox process (%d) to exit", pid)
+ }
+}
+
+// sandboxPid looks for the sandbox process inside the process tree starting
+// from "pid". It returns 0 and no error if no sandbox process is found. It
+// returns error if anything failed.
+func sandboxPid(pid int) (int, error) {
+ cmd := exec.Command("pgrep", "-P", strconv.Itoa(pid))
+ buf := &bytes.Buffer{}
+ cmd.Stdout = buf
+ if err := cmd.Start(); err != nil {
+ return 0, err
+ }
+ ps, err := cmd.Process.Wait()
+ if err != nil {
+ return 0, err
+ }
+ if ps.ExitCode() == 1 {
+ // pgrep returns 1 when no process is found.
+ return 0, nil
+ }
+
+ var children []int
+ for _, line := range strings.Split(buf.String(), "\n") {
+ if len(line) == 0 {
+ continue
+ }
+ child, err := strconv.Atoi(line)
+ if err != nil {
+ return 0, err
+ }
+
+ cmdline, err := ioutil.ReadFile(filepath.Join("/proc", line, "cmdline"))
+ if err != nil {
+ if os.IsNotExist(err) {
+ // Raced with process exit.
+ continue
+ }
+ return 0, err
+ }
+ args := strings.SplitN(string(cmdline), "\x00", 2)
+ if len(args) == 0 {
+ return 0, fmt.Errorf("malformed cmdline file: %q", cmdline)
+ }
+ // The sandbox process has the first argument set to "runsc-sandbox".
+ if args[0] == "runsc-sandbox" {
+ return child, nil
+ }
+
+ children = append(children, child)
+ }
+
+ // Sandbox process wasn't found, try another level down.
+ for _, pid := range children {
+ sand, err := sandboxPid(pid)
+ if err != nil {
+ return 0, err
+ }
+ if sand != 0 {
+ return sand, nil
+ }
+ // Not found, continue the search.
+ }
+ return 0, nil
+}
diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go
index f96e2415e..869169ad5 100644
--- a/test/runner/gtest/gtest.go
+++ b/test/runner/gtest/gtest.go
@@ -66,13 +66,12 @@ func (tc TestCase) Args() []string {
}
if tc.benchmark {
return []string{
- fmt.Sprintf("%s=^$", filterTestFlag),
fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name),
+ fmt.Sprintf("%s=", filterTestFlag),
}
}
return []string{
- fmt.Sprintf("%s=^%s$", filterTestFlag, tc.FullName()),
- fmt.Sprintf("%s=^$", filterBenchmarkFlag),
+ fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()),
}
}
@@ -147,6 +146,8 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes
return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr)
}
+ out = []byte(strings.Trim(string(out), "\n"))
+
// Parse benchmark output.
for _, line := range strings.Split(string(out), "\n") {
// Strip comments.
diff --git a/test/runner/runner.go b/test/runner/runner.go
index a78ef38e0..0d3742f71 100644
--- a/test/runner/runner.go
+++ b/test/runner/runner.go
@@ -300,6 +300,7 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Test spec comes with pre-defined mounts that we don't want. Reset it.
spec.Mounts = nil
+ testTmpDir := "/tmp"
if *useTmpfs {
// Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on
// features only available in gVisor's internal tmpfs may fail.
@@ -325,11 +326,19 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
t.Fatalf("could not chmod temp dir: %v", err)
}
- spec.Mounts = append(spec.Mounts, specs.Mount{
- Type: "bind",
- Destination: "/tmp",
- Source: tmpDir,
- })
+ // "/tmp" is not replaced with a tmpfs mount inside the sandbox
+ // when it's not empty. This ensures that testTmpDir uses gofer
+ // in exclusive mode.
+ testTmpDir = tmpDir
+ if *fileAccess == "shared" {
+ // All external mounts except the root mount are shared.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Type: "bind",
+ Destination: "/tmp",
+ Source: tmpDir,
+ })
+ testTmpDir = "/tmp"
+ }
}
// Set environment variables that indicate we are
@@ -349,12 +358,8 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to
// be backed by tmpfs.
- for i, kv := range env {
- if strings.HasPrefix(kv, "TEST_TMPDIR=") {
- env[i] = "TEST_TMPDIR=/tmp"
- break
- }
- }
+ env = filterEnv(env, []string{"TEST_TMPDIR"})
+ env = append(env, fmt.Sprintf("TEST_TMPDIR=%s", testTmpDir))
spec.Process.Env = env
diff --git a/test/runtimes/blacklist_test.go b/test/runtimes/blacklist_test.go
index 52f49b984..0ff69ab18 100644
--- a/test/runtimes/blacklist_test.go
+++ b/test/runtimes/blacklist_test.go
@@ -32,6 +32,6 @@ func TestBlacklists(t *testing.T) {
t.Fatalf("error parsing blacklist: %v", err)
}
if *blacklistFile != "" && len(bl) == 0 {
- t.Errorf("got empty blacklist for file %q", blacklistFile)
+ t.Errorf("got empty blacklist for file %q", *blacklistFile)
}
}
diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go
index ddb890dbc..3c98f4570 100644
--- a/test/runtimes/runner.go
+++ b/test/runtimes/runner.go
@@ -114,7 +114,7 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int
F: func(t *testing.T) {
// Is the test blacklisted?
if _, ok := blacklist[tc]; ok {
- t.Skip("SKIP: blacklisted test %q", tc)
+ t.Skipf("SKIP: blacklisted test %q", tc)
}
var (
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 3518e862d..9800a0cdf 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -318,10 +318,14 @@ syscall_test(
test = "//test/syscalls/linux:proc_test",
)
-syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test")
-
syscall_test(test = "//test/syscalls/linux:proc_net_test")
+syscall_test(test = "//test/syscalls/linux:proc_pid_oomscore_test")
+
+syscall_test(test = "//test/syscalls/linux:proc_pid_smaps_test")
+
+syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test")
+
syscall_test(
size = "medium",
test = "//test/syscalls/linux:pselect_test",
@@ -680,6 +684,11 @@ syscall_test(
syscall_test(test = "//test/syscalls/linux:tuntap_test")
+syscall_test(
+ add_hostinet = True,
+ test = "//test/syscalls/linux:tuntap_hostinet_test",
+)
+
syscall_test(test = "//test/syscalls/linux:udp_bind_test")
syscall_test(
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 704bae17b..d9095c95f 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -138,7 +138,6 @@ cc_library(
hdrs = ["socket_netlink_route_util.h"],
deps = [
":socket_netlink_util",
- "@com_google_absl//absl/types:optional",
],
)
@@ -167,11 +166,6 @@ cc_library(
)
cc_library(
- name = "temp_umask",
- hdrs = ["temp_umask.h"],
-)
-
-cc_library(
name = "unix_domain_socket_test_util",
testonly = 1,
srcs = ["unix_domain_socket_test_util.cc"],
@@ -608,7 +602,10 @@ cc_binary(
cc_binary(
name = "exceptions_test",
testonly = 1,
- srcs = ["exceptions.cc"],
+ srcs = select_arch(
+ amd64 = ["exceptions.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
gtest,
@@ -665,10 +662,7 @@ cc_binary(
cc_binary(
name = "exec_binary_test",
testonly = 1,
- srcs = select_arch(
- amd64 = ["exec_binary.cc"],
- arm64 = [],
- ),
+ srcs = ["exec_binary.cc"],
linkstatic = 1,
deps = [
"//test/util:cleanup",
@@ -1140,11 +1134,11 @@ cc_binary(
srcs = ["mkdir.cc"],
linkstatic = 1,
deps = [
- ":temp_umask",
"//test/util:capability_util",
"//test/util:fs_util",
gtest,
"//test/util:temp_path",
+ "//test/util:temp_umask",
"//test/util:test_main",
"//test/util:test_util",
],
@@ -1299,12 +1293,12 @@ cc_binary(
srcs = ["open_create.cc"],
linkstatic = 1,
deps = [
- ":temp_umask",
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
gtest,
"//test/util:temp_path",
+ "//test/util:temp_umask",
"//test/util:test_main",
"//test/util:test_util",
],
@@ -1475,7 +1469,10 @@ cc_binary(
cc_binary(
name = "arch_prctl_test",
testonly = 1,
- srcs = ["arch_prctl.cc"],
+ srcs = select_arch(
+ amd64 = ["arch_prctl.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
@@ -1631,6 +1628,19 @@ cc_binary(
)
cc_binary(
+ name = "proc_pid_oomscore_test",
+ testonly = 1,
+ srcs = ["proc_pid_oomscore.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:fs_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
name = "proc_pid_smaps_test",
testonly = 1,
srcs = ["proc_pid_smaps.cc"],
@@ -2012,6 +2022,8 @@ cc_binary(
"//test/util:file_descriptor",
"@com_google_absl//absl/strings",
gtest,
+ ":ip_socket_test_util",
+ ":unix_domain_socket_test_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
@@ -2788,13 +2800,13 @@ cc_binary(
srcs = ["socket_netlink_route.cc"],
linkstatic = 1,
deps = [
+ ":socket_netlink_route_util",
":socket_netlink_util",
":socket_test_util",
"//test/util:capability_util",
"//test/util:cleanup",
"//test/util:file_descriptor",
"@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:optional",
gtest,
"//test/util:test_main",
"//test/util:test_util",
@@ -3460,6 +3472,18 @@ cc_binary(
],
)
+cc_binary(
+ name = "tuntap_hostinet_test",
+ testonly = 1,
+ srcs = ["tuntap_hostinet.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
cc_library(
name = "udp_socket_test_cases",
testonly = 1,
@@ -3678,11 +3702,10 @@ cc_binary(
":socket_test_util",
gtest,
"//test/util:capability_util",
- "//test/util:memory_util",
+ "//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/synchronization",
],
)
diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc
index a33daff17..806d5729e 100644
--- a/test/syscalls/linux/aio.cc
+++ b/test/syscalls/linux/aio.cc
@@ -89,6 +89,7 @@ class AIOTest : public FileTest {
FileTest::TearDown();
if (ctx_ != 0) {
ASSERT_THAT(DestroyContext(), SyscallSucceeds());
+ ctx_ = 0;
}
}
@@ -188,14 +189,19 @@ TEST_F(AIOTest, BadWrite) {
}
TEST_F(AIOTest, ExitWithPendingIo) {
- // Setup a context that is 5 entries deep.
- ASSERT_THAT(SetupContext(5), SyscallSucceeds());
+ // Setup a context that is 100 entries deep.
+ ASSERT_THAT(SetupContext(100), SyscallSucceeds());
struct iocb cb = CreateCallback();
struct iocb* cbs[] = {&cb};
// Submit a request but don't complete it to make it pending.
- EXPECT_THAT(Submit(1, cbs), SyscallSucceeds());
+ for (int i = 0; i < 100; ++i) {
+ EXPECT_THAT(Submit(1, cbs), SyscallSucceeds());
+ }
+
+ ASSERT_THAT(DestroyContext(), SyscallSucceeds());
+ ctx_ = 0;
}
int Submitter(void* arg) {
diff --git a/test/syscalls/linux/bad.cc b/test/syscalls/linux/bad.cc
index adfb149df..a26fc6af3 100644
--- a/test/syscalls/linux/bad.cc
+++ b/test/syscalls/linux/bad.cc
@@ -28,7 +28,7 @@ namespace {
constexpr uint32_t kNotImplementedSyscall = SYS_get_kernel_syms;
#elif __aarch64__
// Use the last of arch_specific_syscalls which are not implemented on arm64.
-constexpr uint32_t kNotImplementedSyscall = SYS_arch_specific_syscall + 15;
+constexpr uint32_t kNotImplementedSyscall = __NR_arch_specific_syscall + 15;
#endif
TEST(BadSyscallTest, NotImplemented) {
diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc
index 4e473268c..4dd302eed 100644
--- a/test/syscalls/linux/dev.cc
+++ b/test/syscalls/linux/dev.cc
@@ -153,13 +153,6 @@ TEST(DevTest, TTYExists) {
EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
}
-TEST(DevTest, NetTunExists) {
- struct stat statbuf = {};
- ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallSucceeds());
- // Check that it's a character device with rw-rw-rw- permissions.
- EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
-}
-
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc
index a4f8f3cec..f57d38dc7 100644
--- a/test/syscalls/linux/epoll.cc
+++ b/test/syscalls/linux/epoll.cc
@@ -56,10 +56,6 @@ TEST(EpollTest, AllWritable) {
struct epoll_event result[kFDsPerEpoll];
ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
SyscallSucceedsWithValue(kFDsPerEpoll));
- // TODO(edahlgren): Why do some tests check epoll_event::data, and others
- // don't? Does Linux actually guarantee that, in any of these test cases,
- // epoll_wait will necessarily write out the epoll_events in the order that
- // they were registered?
for (int i = 0; i < kFDsPerEpoll; i++) {
ASSERT_EQ(result[i].events, EPOLLOUT);
}
diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc
index 07bd527e6..12c9b05ca 100644
--- a/test/syscalls/linux/exec.cc
+++ b/test/syscalls/linux/exec.cc
@@ -812,26 +812,28 @@ void ExecFromThread() {
bool ValidateProcCmdlineVsArgv(const int argc, const char* const* argv) {
auto contents_or = GetContents("/proc/self/cmdline");
if (!contents_or.ok()) {
- std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error();
+ std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error()
+ << std::endl;
return false;
}
auto contents = contents_or.ValueOrDie();
if (contents.back() != '\0') {
- std::cerr << "Non-null terminated /proc/self/cmdline!";
+ std::cerr << "Non-null terminated /proc/self/cmdline!" << std::endl;
return false;
}
contents.pop_back();
std::vector<std::string> procfs_cmdline = absl::StrSplit(contents, '\0');
if (static_cast<int>(procfs_cmdline.size()) != argc) {
- std::cerr << "argc = " << argc << " != " << procfs_cmdline.size();
+ std::cerr << "argc = " << argc << " != " << procfs_cmdline.size()
+ << std::endl;
return false;
}
for (int i = 0; i < argc; ++i) {
if (procfs_cmdline[i] != argv[i]) {
std::cerr << "Procfs command line argument " << i << " mismatch "
- << procfs_cmdline[i] << " != " << argv[i];
+ << procfs_cmdline[i] << " != " << argv[i] << std::endl;
return false;
}
}
diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc
index 736452b0c..1a9f203b9 100644
--- a/test/syscalls/linux/exec_binary.cc
+++ b/test/syscalls/linux/exec_binary.cc
@@ -48,10 +48,17 @@ namespace {
using ::testing::AnyOf;
using ::testing::Eq;
-#ifndef __x86_64__
+#if !defined(__x86_64__) && !defined(__aarch64__)
// The assembly stub and ELF internal details must be ported to other arches.
-#error "Test only supported on x86-64"
-#endif // __x86_64__
+#error "Test only supported on x86-64/arm64"
+#endif // __x86_64__ || __aarch64__
+
+#if defined(__x86_64__)
+#define EM_TYPE EM_X86_64
+#define IP_REG(p) ((p).rip)
+#define RAX_REG(p) ((p).rax)
+#define RDI_REG(p) ((p).rdi)
+#define RETURN_REG(p) ((p).rax)
// amd64 stub that calls PTRACE_TRACEME and sends itself SIGSTOP.
const char kPtraceCode[] = {
@@ -139,6 +146,76 @@ const char kPtraceCode[] = {
// Size of a syscall instruction.
constexpr int kSyscallSize = 2;
+#elif defined(__aarch64__)
+#define EM_TYPE EM_AARCH64
+#define IP_REG(p) ((p).pc)
+#define RAX_REG(p) ((p).regs[8])
+#define RDI_REG(p) ((p).regs[0])
+#define RETURN_REG(p) ((p).regs[0])
+
+const char kPtraceCode[] = {
+ // MOVD $117, R8 /* ptrace */
+ '\xa8',
+ '\x0e',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R0 /* PTRACE_TRACEME */
+ '\x00',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R1 /* pid */
+ '\x01',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R2 /* addr */
+ '\x02',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R3 /* data */
+ '\x03',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+ // MOVD $172, R8 /* getpid */
+ '\x88',
+ '\x15',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+ // MOVD $129, R8 /* kill, R0=pid */
+ '\x28',
+ '\x10',
+ '\x80',
+ '\xd2',
+ // MOVD $19, R1 /* SIGSTOP */
+ '\x61',
+ '\x02',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+};
+// Size of a syscall instruction.
+constexpr int kSyscallSize = 4;
+#else
+#error "Unknown architecture"
+#endif
+
// This test suite tests executable loading in the kernel (ELF and interpreter
// scripts).
@@ -281,7 +358,7 @@ ElfBinary<64> StandardElf() {
elf.header.e_ident[EI_DATA] = ELFDATA2LSB;
elf.header.e_ident[EI_VERSION] = EV_CURRENT;
elf.header.e_type = ET_EXEC;
- elf.header.e_machine = EM_X86_64;
+ elf.header.e_machine = EM_TYPE;
elf.header.e_version = EV_CURRENT;
elf.header.e_phoff = sizeof(elf.header);
elf.header.e_phentsize = sizeof(decltype(elf)::ElfPhdr);
@@ -327,9 +404,15 @@ TEST(ElfTest, Execute) {
ASSERT_NO_ERRNO(WaitStopped(child));
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
- // RIP is just beyond the final syscall instruction.
- EXPECT_EQ(regs.rip, elf.header.e_entry + sizeof(kPtraceCode));
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
+ // RIP/PC is just beyond the final syscall instruction.
+ EXPECT_EQ(IP_REG(regs), elf.header.e_entry + sizeof(kPtraceCode));
EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
{0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
@@ -718,9 +801,16 @@ TEST(ElfTest, PIE) {
// RIP tells us which page the first segment was loaded into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1);
EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
// text page.
@@ -787,9 +877,15 @@ TEST(ElfTest, PIENonZeroStart) {
// RIP tells us which page the first segment was loaded into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1);
// The ELF is loaded at an arbitrary address, not the first PT_LOAD vaddr.
//
@@ -910,9 +1006,15 @@ TEST(ElfTest, ELFInterpreter) {
// RIP tells us which page the first segment of the interpreter was loaded
// into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1);
EXPECT_THAT(
child, ContainsMappings(std::vector<ProcMapsEntry>({
@@ -1084,9 +1186,15 @@ TEST(ElfTest, ELFInterpreterRelative) {
// RIP tells us which page the first segment of the interpreter was loaded
// into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1);
EXPECT_THAT(
child, ContainsMappings(std::vector<ProcMapsEntry>({
@@ -1480,14 +1588,21 @@ TEST(ExecveTest, BrkAfterBinary) {
ASSERT_NO_ERRNO(WaitStopped(child));
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
// RIP is just beyond the final syscall instruction. Rewind to execute a brk
// syscall.
- regs.rip -= kSyscallSize;
- regs.rax = __NR_brk;
- regs.rdi = 0;
- ASSERT_THAT(ptrace(PTRACE_SETREGS, child, 0, &regs), SyscallSucceeds());
+ IP_REG(regs) -= kSyscallSize;
+ RAX_REG(regs) = __NR_brk;
+ RDI_REG(regs) = 0;
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
// Resume the child, waiting for syscall entry.
ASSERT_THAT(ptrace(PTRACE_SYSCALL, child, 0, 0), SyscallSucceeds());
@@ -1504,7 +1619,12 @@ TEST(ExecveTest, BrkAfterBinary) {
ASSERT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP)
<< "status = " << status;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
// brk is after the text page.
//
@@ -1512,7 +1632,7 @@ TEST(ExecveTest, BrkAfterBinary) {
// address will be, but it is always beyond the final page in the binary.
// i.e., it does not start immediately after memsz in the middle of a page.
// Userspace may expect to use that space.
- EXPECT_GE(regs.rax, 0x41000);
+ EXPECT_GE(RETURN_REG(regs), 0x41000);
}
} // namespace
diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h
index 6f80bc97c..fb418e052 100644
--- a/test/syscalls/linux/file_base.h
+++ b/test/syscalls/linux/file_base.h
@@ -52,17 +52,6 @@ class FileTest : public ::testing::Test {
test_file_fd_ = ASSERT_NO_ERRNO_AND_VALUE(
Open(test_file_name_, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR));
- // FIXME(edahlgren): enable when mknod syscall is supported.
- // test_fifo_name_ = NewTempAbsPath();
- // ASSERT_THAT(mknod(test_fifo_name_.c_str()), S_IFIFO|0644, 0,
- // SyscallSucceeds());
- // ASSERT_THAT(test_fifo_[1] = open(test_fifo_name_.c_str(),
- // O_WRONLY),
- // SyscallSucceeds());
- // ASSERT_THAT(test_fifo_[0] = open(test_fifo_name_.c_str(),
- // O_RDONLY),
- // SyscallSucceeds());
-
ASSERT_THAT(pipe(test_pipe_), SyscallSucceeds());
ASSERT_THAT(fcntl(test_pipe_[0], F_SETFL, O_NONBLOCK), SyscallSucceeds());
}
@@ -96,18 +85,12 @@ class FileTest : public ::testing::Test {
CloseFile();
UnlinkFile();
ClosePipes();
-
- // FIXME(edahlgren): enable when mknod syscall is supported.
- // close(test_fifo_[0]);
- // close(test_fifo_[1]);
- // unlink(test_fifo_name_.c_str());
}
+ protected:
std::string test_file_name_;
- std::string test_fifo_name_;
FileDescriptor test_file_fd_;
- int test_fifo_[2];
int test_pipe_[2];
};
diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc
index ff8bdfeb0..853f6231a 100644
--- a/test/syscalls/linux/fork.cc
+++ b/test/syscalls/linux/fork.cc
@@ -431,7 +431,6 @@ TEST(CloneTest, NewUserNamespacePermitsAllOtherNamespaces) {
<< "status = " << status;
}
-#ifdef __x86_64__
// Clone with CLONE_SETTLS and a non-canonical TLS address is rejected.
TEST(CloneTest, NonCanonicalTLS) {
constexpr uintptr_t kNonCanonical = 1ull << 48;
@@ -440,11 +439,25 @@ TEST(CloneTest, NonCanonicalTLS) {
// on this.
char stack;
+ // The raw system call interface on x86-64 is:
+ // long clone(unsigned long flags, void *stack,
+ // int *parent_tid, int *child_tid,
+ // unsigned long tls);
+ //
+ // While on arm64, the order of the last two arguments is reversed:
+ // long clone(unsigned long flags, void *stack,
+ // int *parent_tid, unsigned long tls,
+ // int *child_tid);
+#if defined(__x86_64__)
EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr,
nullptr, kNonCanonical),
SyscallFailsWithErrno(EPERM));
-}
+#elif defined(__aarch64__)
+ EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr,
+ kNonCanonical, nullptr),
+ SyscallFailsWithErrno(EPERM));
#endif
+}
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/getrandom.cc b/test/syscalls/linux/getrandom.cc
index f97f60029..f87cdd7a1 100644
--- a/test/syscalls/linux/getrandom.cc
+++ b/test/syscalls/linux/getrandom.cc
@@ -29,6 +29,8 @@ namespace {
#define SYS_getrandom 318
#elif defined(__i386__)
#define SYS_getrandom 355
+#elif defined(__aarch64__)
+#define SYS_getrandom 278
#else
#error "Unknown architecture"
#endif
diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc
index bba022a41..98d07ae85 100644
--- a/test/syscalls/linux/ip_socket_test_util.cc
+++ b/test/syscalls/linux/ip_socket_test_util.cc
@@ -16,7 +16,6 @@
#include <net/if.h>
#include <netinet/in.h>
-#include <sys/ioctl.h>
#include <sys/socket.h>
#include <cstring>
@@ -35,12 +34,11 @@ uint16_t PortFromInetSockaddr(const struct sockaddr* addr) {
}
PosixErrorOr<int> InterfaceIndex(std::string name) {
- // TODO(igudger): Consider using netlink.
- ifreq req = {};
- memcpy(req.ifr_name, name.c_str(), name.size());
- ASSIGN_OR_RETURN_ERRNO(auto sock, Socket(AF_INET, SOCK_DGRAM, 0));
- RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(sock.get(), SIOCGIFINDEX, &req));
- return req.ifr_ifindex;
+ int index = if_nametoindex(name.c_str());
+ if (index) {
+ return index;
+ }
+ return PosixError(errno);
}
namespace {
@@ -177,17 +175,17 @@ SocketKind IPv6TCPUnboundSocket(int type) {
PosixError IfAddrHelper::Load() {
Release();
RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_));
- return PosixError(0);
+ return NoError();
}
void IfAddrHelper::Release() {
if (ifaddr_) {
freeifaddrs(ifaddr_);
+ ifaddr_ = nullptr;
}
- ifaddr_ = nullptr;
}
-std::vector<std::string> IfAddrHelper::InterfaceList(int family) {
+std::vector<std::string> IfAddrHelper::InterfaceList(int family) const {
std::vector<std::string> names;
for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) {
@@ -198,7 +196,7 @@ std::vector<std::string> IfAddrHelper::InterfaceList(int family) {
return names;
}
-sockaddr* IfAddrHelper::GetAddr(int family, std::string name) {
+const sockaddr* IfAddrHelper::GetAddr(int family, std::string name) const {
for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) {
continue;
@@ -210,7 +208,7 @@ sockaddr* IfAddrHelper::GetAddr(int family, std::string name) {
return nullptr;
}
-PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) {
+PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) const {
return InterfaceIndex(name);
}
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
index 39fd6709d..9c3859fcd 100644
--- a/test/syscalls/linux/ip_socket_test_util.h
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -110,10 +110,10 @@ class IfAddrHelper {
PosixError Load();
void Release();
- std::vector<std::string> InterfaceList(int family);
+ std::vector<std::string> InterfaceList(int family) const;
- struct sockaddr* GetAddr(int family, std::string name);
- PosixErrorOr<int> GetIndex(std::string name);
+ const sockaddr* GetAddr(int family, std::string name) const;
+ PosixErrorOr<int> GetIndex(std::string name) const;
private:
struct ifaddrs* ifaddr_;
diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc
index 8b48f0804..dd981a278 100644
--- a/test/syscalls/linux/itimer.cc
+++ b/test/syscalls/linux/itimer.cc
@@ -246,7 +246,7 @@ int TestSIGPROFFairness(absl::Duration sleep) {
// The number of samples on the main thread should be very low as it did
// nothing.
- TEST_CHECK(result.main_thread_samples < 60);
+ TEST_CHECK(result.main_thread_samples < 80);
// Both workers should get roughly equal number of samples.
TEST_CHECK(result.worker_samples.size() == 2);
diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc
index a8af8e545..6ce1e6cc3 100644
--- a/test/syscalls/linux/lseek.cc
+++ b/test/syscalls/linux/lseek.cc
@@ -53,7 +53,7 @@ TEST(LseekTest, NegativeOffset) {
// A 32-bit off_t is not large enough to represent an offset larger than
// maximum file size on standard file systems, so it isn't possible to cause
// overflow.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(__aarch64__)
TEST(LseekTest, Overflow) {
// HA! Classic Linux. We really should have an EOVERFLOW
// here, since we're seeking to something that cannot be
diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc
index e57b49a4a..f8b7f7938 100644
--- a/test/syscalls/linux/memfd.cc
+++ b/test/syscalls/linux/memfd.cc
@@ -16,6 +16,7 @@
#include <fcntl.h>
#include <linux/magic.h>
#include <linux/memfd.h>
+#include <linux/unistd.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/statfs.h>
diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc
index cf138d328..4036a9275 100644
--- a/test/syscalls/linux/mkdir.cc
+++ b/test/syscalls/linux/mkdir.cc
@@ -18,10 +18,10 @@
#include <unistd.h>
#include "gtest/gtest.h"
-#include "test/syscalls/linux/temp_umask.h"
#include "test/util/capability_util.h"
#include "test/util/fs_util.h"
#include "test/util/temp_path.h"
+#include "test/util/temp_umask.h"
#include "test/util/test_util.h"
namespace gvisor {
@@ -36,21 +36,12 @@ class MkdirTest : public ::testing::Test {
// TearDown unlinks created files.
void TearDown() override {
- // FIXME(edahlgren): We don't currently implement rmdir.
- // We do this unconditionally because there's no harm in trying.
- rmdir(dirname_.c_str());
+ EXPECT_THAT(rmdir(dirname_.c_str()), SyscallSucceeds());
}
std::string dirname_;
};
-TEST_F(MkdirTest, DISABLED_CanCreateReadbleDir) {
- ASSERT_THAT(mkdir(dirname_.c_str(), 0444), SyscallSucceeds());
- ASSERT_THAT(
- open(JoinPath(dirname_, "anything").c_str(), O_RDWR | O_CREAT, 0666),
- SyscallFailsWithErrno(EACCES));
-}
-
TEST_F(MkdirTest, CanCreateWritableDir) {
ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds());
std::string filename = JoinPath(dirname_, "anything");
@@ -84,10 +75,11 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) {
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
- auto parent = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555));
- auto dir = JoinPath(parent.path(), "foo");
- ASSERT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES));
+ ASSERT_THAT(mkdir(dirname_.c_str(), 0555), SyscallSucceeds());
+ auto dir = JoinPath(dirname_.c_str(), "foo");
+ EXPECT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(open(JoinPath(dirname_, "file").c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallFailsWithErrno(EACCES));
}
} // namespace
diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc
index 367a90fe1..78ac96bed 100644
--- a/test/syscalls/linux/mlock.cc
+++ b/test/syscalls/linux/mlock.cc
@@ -199,8 +199,10 @@ TEST(MunlockallTest, Basic) {
}
#ifndef SYS_mlock2
-#ifdef __x86_64__
+#if defined(__x86_64__)
#define SYS_mlock2 325
+#elif defined(__aarch64__)
+#define SYS_mlock2 284
#endif
#endif
diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc
index 11fb1b457..6d3227ab6 100644
--- a/test/syscalls/linux/mmap.cc
+++ b/test/syscalls/linux/mmap.cc
@@ -361,7 +361,7 @@ TEST_F(MMapTest, MapFixed) {
}
// 64-bit addresses work too
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(__aarch64__)
TEST_F(MMapTest, MapFixed64) {
EXPECT_THAT(Map(0x300000000000, kPageSize, PROT_NONE,
MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0),
@@ -571,6 +571,12 @@ const uint8_t machine_code[] = {
0xb8, 0x2a, 0x00, 0x00, 0x00, // movl $42, %eax
0xc3, // retq
};
+#elif defined(__aarch64__)
+const uint8_t machine_code[] = {
+ 0x40, 0x05, 0x80, 0x52, // mov w0, #42
+ 0xc0, 0x03, 0x5f, 0xd6, // ret
+};
+#endif
// PROT_EXEC allows code execution
TEST_F(MMapTest, ProtExec) {
@@ -605,7 +611,6 @@ TEST_F(MMapTest, NoProtExecDeath) {
EXPECT_EXIT(func(), ::testing::KilledBySignal(SIGSEGV), "");
}
-#endif
TEST_F(MMapTest, NoExceedLimitData) {
void* prevbrk;
@@ -1644,6 +1649,7 @@ TEST(MMapNoFixtureTest, MapReadOnlyAfterCreateWriteOnly) {
}
// Conditional on MAP_32BIT.
+// This flag is supported only on x86-64, for 64-bit programs.
#ifdef __x86_64__
TEST(MMapNoFixtureTest, Map32Bit) {
diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc
index 6ea48c263..133fdecf0 100644
--- a/test/syscalls/linux/network_namespace.cc
+++ b/test/syscalls/linux/network_namespace.cc
@@ -20,102 +20,33 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-#include "absl/synchronization/notification.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/capability_util.h"
-#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
-
namespace {
-using TestFunc = std::function<PosixError()>;
-using RunFunc = std::function<PosixError(TestFunc)>;
-
-struct NamespaceStrategy {
- RunFunc run;
-
- static NamespaceStrategy Of(RunFunc run) {
- NamespaceStrategy s;
- s.run = run;
- return s;
- }
-};
+TEST(NetworkNamespaceTest, LoopbackExists) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
-PosixError RunWithUnshare(TestFunc fn) {
- PosixError err = PosixError(-1, "function did not return a value");
ScopedThread t([&] {
- if (unshare(CLONE_NEWNET) != 0) {
- err = PosixError(errno);
- return;
- }
- err = fn();
- });
- t.Join();
- return err;
-}
+ ASSERT_THAT(unshare(CLONE_NEWNET), SyscallSucceedsWithValue(0));
-PosixError RunWithClone(TestFunc fn) {
- struct Args {
- absl::Notification n;
- TestFunc fn;
- PosixError err;
- };
- Args args;
- args.fn = fn;
- args.err = PosixError(-1, "function did not return a value");
-
- ASSIGN_OR_RETURN_ERRNO(
- Mapping child_stack,
- MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
- pid_t child = clone(
- +[](void *arg) {
- Args *args = reinterpret_cast<Args *>(arg);
- args->err = args->fn();
- args->n.Notify();
- syscall(SYS_exit, 0); // Exit manually. No return address on stack.
- return 0;
- },
- reinterpret_cast<void *>(child_stack.addr() + kPageSize),
- CLONE_NEWNET | CLONE_THREAD | CLONE_SIGHAND | CLONE_VM, &args);
- if (child < 0) {
- return PosixError(errno, "clone() failed");
- }
- args.n.WaitForNotification();
- return args.err;
-}
-
-class NetworkNamespaceTest
- : public ::testing::TestWithParam<NamespaceStrategy> {};
-
-TEST_P(NetworkNamespaceTest, LoopbackExists) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
-
- EXPECT_NO_ERRNO(GetParam().run([]() {
// TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists.
// Check loopback device exists.
int sock = socket(AF_INET, SOCK_DGRAM, 0);
- if (sock < 0) {
- return PosixError(errno, "socket() failed");
- }
+ ASSERT_THAT(sock, SyscallSucceeds());
struct ifreq ifr;
- snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
- if (ioctl(sock, SIOCGIFINDEX, &ifr) < 0) {
- return PosixError(errno, "ioctl() failed, lo cannot be found");
- }
- return NoError();
- }));
+ strncpy(ifr.ifr_name, "lo", IFNAMSIZ);
+ EXPECT_THAT(ioctl(sock, SIOCGIFINDEX, &ifr), SyscallSucceeds())
+ << "lo cannot be found";
+ });
}
-INSTANTIATE_TEST_SUITE_P(
- AllNetworkNamespaceTest, NetworkNamespaceTest,
- ::testing::Values(NamespaceStrategy::Of(RunWithUnshare),
- NamespaceStrategy::Of(RunWithClone)));
-
} // namespace
-
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc
index 267ae19f6..640fe6bfc 100644
--- a/test/syscalls/linux/open.cc
+++ b/test/syscalls/linux/open.cc
@@ -186,6 +186,28 @@ TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) {
ASSERT_NO_ERRNO_AND_VALUE(Open(path_via_symlink, O_RDONLY | O_NOFOLLOW));
}
+// Test that open(2) can follow symlinks that point back to the same tree.
+// Test sets up files as follows:
+// root/child/symlink => redirects to ../..
+// root/child/target => regular file
+//
+// open("root/child/symlink/root/child/file")
+TEST_F(OpenTest, SymlinkRecurse) {
+ auto root =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir()));
+ auto child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path()));
+ auto symlink = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(child.path(), "../.."));
+ auto target = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(child.path(), "abc", 0644));
+ auto path_via_symlink =
+ JoinPath(symlink.path(), Basename(root.path()), Basename(child.path()),
+ Basename(target.path()));
+ const auto contents =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents(path_via_symlink));
+ ASSERT_EQ(contents, "abc");
+}
+
TEST_F(OpenTest, Fault) {
char* totally_not_null = nullptr;
ASSERT_THAT(open(totally_not_null, O_RDONLY), SyscallFailsWithErrno(EFAULT));
diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc
index 902d0a0dc..51eacf3f2 100644
--- a/test/syscalls/linux/open_create.cc
+++ b/test/syscalls/linux/open_create.cc
@@ -19,11 +19,11 @@
#include <unistd.h>
#include "gtest/gtest.h"
-#include "test/syscalls/linux/temp_umask.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
#include "test/util/temp_path.h"
+#include "test/util/temp_umask.h"
#include "test/util/test_util.h"
namespace gvisor {
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index 92ae55eec..5ac68feb4 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <arpa/inet.h>
+#include <ifaddrs.h>
#include <linux/capability.h>
#include <linux/if_arp.h>
#include <linux/if_packet.h>
@@ -163,16 +164,11 @@ int CookedPacketTest::GetLoopbackIndex() {
return ifr.ifr_ifindex;
}
-// Receive via a packet socket.
-TEST_P(CookedPacketTest, Receive) {
- // Let's use a simple IP payload: a UDP datagram.
- FileDescriptor udp_sock =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
- SendUDPMessage(udp_sock.get());
-
+// Receive and verify the message via packet socket on interface.
+void ReceiveMessage(int sock, int ifindex) {
// Wait for the socket to become readable.
struct pollfd pfd = {};
- pfd.fd = socket_;
+ pfd.fd = sock;
pfd.events = POLLIN;
EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1));
@@ -182,9 +178,10 @@ TEST_P(CookedPacketTest, Receive) {
char buf[64];
struct sockaddr_ll src = {};
socklen_t src_len = sizeof(src);
- ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
+ ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0,
reinterpret_cast<struct sockaddr*>(&src), &src_len),
SyscallSucceedsWithValue(packet_size));
+
// sockaddr_ll ends with an 8 byte physical address field, but ethernet
// addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
// here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
@@ -194,7 +191,7 @@ TEST_P(CookedPacketTest, Receive) {
// TODO(b/129292371): Verify protocol once we return it.
// Verify the source address.
EXPECT_EQ(src.sll_family, AF_PACKET);
- EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex());
+ EXPECT_EQ(src.sll_ifindex, ifindex);
EXPECT_EQ(src.sll_halen, ETH_ALEN);
// This came from the loopback device, so the address is all 0s.
for (int i = 0; i < src.sll_halen; i++) {
@@ -222,6 +219,18 @@ TEST_P(CookedPacketTest, Receive) {
EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
}
+// Receive via a packet socket.
+TEST_P(CookedPacketTest, Receive) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ int loopback_index = GetLoopbackIndex();
+ ReceiveMessage(socket_, loopback_index);
+}
+
// Send via a packet socket.
TEST_P(CookedPacketTest, Send) {
// TODO(b/129292371): Remove once we support packet socket writing.
@@ -313,6 +322,115 @@ TEST_P(CookedPacketTest, Send) {
EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK));
}
+// Bind and receive via packet socket.
+TEST_P(CookedPacketTest, BindReceive) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ ReceiveMessage(socket_, bind_addr.sll_ifindex);
+}
+
+// Double Bind socket.
+TEST_P(CookedPacketTest, DoubleBind) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Binding socket again should fail.
+ ASSERT_THAT(
+ bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ // Linux 4.09 returns EINVAL here, but some time before 4.19 it switched
+ // to EADDRINUSE.
+ AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL)));
+}
+
+// Bind and verify we do not receive data on interface which is not bound
+TEST_P(CookedPacketTest, BindDrop) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct ifaddrs* if_addr_list = nullptr;
+ auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
+
+ ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
+
+ // Get interface other than loopback.
+ struct ifreq ifr = {};
+ for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
+ if (strcmp(i->ifa_name, "lo") != 0) {
+ strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name));
+ break;
+ }
+ }
+
+ // Skip if no interface is available other than loopback.
+ if (strlen(ifr.ifr_name) == 0) {
+ GTEST_SKIP();
+ }
+
+ // Get interface index.
+ EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+
+ // Bind to packet socket requires only family, protocol and ifindex.
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = ifr.ifr_ifindex;
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Send to loopback interface.
+ struct sockaddr_in dest = {};
+ dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ dest.sin_family = AF_INET;
+ dest.sin_port = kPort;
+ EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+
+ // Wait and make sure the socket never receives any data.
+ struct pollfd pfd = {};
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
+}
+
+// Bind with invalid address.
+TEST_P(CookedPacketTest, BindFail) {
+ // Null address.
+ ASSERT_THAT(
+ bind(socket_, nullptr, sizeof(struct sockaddr)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallFailsWithErrno(EINVAL)));
+
+ // Address of size 1.
+ uint8_t addr = 0;
+ ASSERT_THAT(
+ bind(socket_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest,
::testing::Values(ETH_P_IP, ETH_P_ALL));
diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc
index d8e19e910..67228b66b 100644
--- a/test/syscalls/linux/pipe.cc
+++ b/test/syscalls/linux/pipe.cc
@@ -265,6 +265,8 @@ TEST_P(PipeTest, OffsetCalls) {
SyscallFailsWithErrno(ESPIPE));
struct iovec iov;
+ iov.iov_base = &buf;
+ iov.iov_len = sizeof(buf);
EXPECT_THAT(preadv(wfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE));
EXPECT_THAT(pwritev(rfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE));
}
diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc
index c42472474..1e35a4a8b 100644
--- a/test/syscalls/linux/poll.cc
+++ b/test/syscalls/linux/poll.cc
@@ -266,7 +266,7 @@ TEST_F(PollTest, Nfds) {
}
rlim_t max_fds = rlim.rlim_cur;
- std::cout << "Using limit: " << max_fds;
+ std::cout << "Using limit: " << max_fds << std::endl;
// Create an eventfd. Since its value is initially zero, it is writable.
FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc
index 2cecf2e5f..bcdbbb044 100644
--- a/test/syscalls/linux/pread64.cc
+++ b/test/syscalls/linux/pread64.cc
@@ -14,6 +14,7 @@
#include <errno.h>
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/types.h>
@@ -118,6 +119,21 @@ TEST_F(Pread64Test, EndOfFile) {
EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallSucceedsWithValue(0));
}
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
+TEST_F(Pread64Test, Overflow) {
+ int f = memfd_create("negative", 0);
+ const FileDescriptor fd(f);
+
+ EXPECT_THAT(ftruncate(fd.get(), 0x7fffffffffffffffull), SyscallSucceeds());
+
+ char buf[10];
+ EXPECT_THAT(pread64(fd.get(), buf, sizeof(buf), 0x7fffffffffffffffull),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST(Pread64TestNoTempFile, CantReadSocketPair_NoRandomSave) {
int sock_fds[2];
EXPECT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds), SyscallSucceeds());
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index f91187e75..79a625ebc 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -994,7 +994,7 @@ constexpr uint64_t kMappingSize = 100 << 20;
// Tolerance on RSS comparisons to account for background thread mappings,
// reclaimed pages, newly faulted pages, etc.
-constexpr uint64_t kRSSTolerance = 5 << 20;
+constexpr uint64_t kRSSTolerance = 10 << 20;
// Capture RSS before and after an anonymous mapping with passed prot.
void MapPopulateRSS(int prot, uint64_t* before, uint64_t* after) {
@@ -1326,8 +1326,6 @@ TEST(ProcPidSymlink, SubprocessRunning) {
SyscallSucceedsWithValue(sizeof(buf)));
}
-// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
-// on proc files.
TEST(ProcPidSymlink, SubprocessZombied) {
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
@@ -1337,7 +1335,7 @@ TEST(ProcPidSymlink, SubprocessZombied) {
int want = EACCES;
if (!IsRunningOnGvisor()) {
auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion());
- if (version.major == 4 && version.minor > 3) {
+ if (version.major > 4 || (version.major == 4 && version.minor > 3)) {
want = ENOENT;
}
}
@@ -1350,30 +1348,25 @@ TEST(ProcPidSymlink, SubprocessZombied) {
SyscallFailsWithErrno(want));
}
- // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
- // on proc files.
+ // FIXME(gvisor.dev/issue/164): Inconsistent behavior between linux on proc
+ // files.
//
// ~4.3: Syscall fails with EACCES.
- // 4.17 & gVisor: Syscall succeeds and returns 1.
+ // 4.17: Syscall succeeds and returns 1.
//
- // EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)),
- // SyscallFailsWithErrno(EACCES));
+ if (!IsRunningOnGvisor()) {
+ return;
+ }
- // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
- // on proc files.
- //
- // ~4.3: Syscall fails with EACCES.
- // 4.17 & gVisor: Syscall succeeds and returns 1.
- //
- // EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)),
- // SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)),
+ SyscallFailsWithErrno(want));
+
+ EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)),
+ SyscallFailsWithErrno(want));
}
// Test whether /proc/PID/ symlinks can be read for an exited process.
TEST(ProcPidSymlink, SubprocessExited) {
- // FIXME(gvisor.dev/issue/164): These all succeed on gVisor.
- SKIP_IF(IsRunningOnGvisor());
-
char buf[1];
EXPECT_THAT(ReadlinkWhileExited("exe", buf, sizeof(buf)),
@@ -1431,6 +1424,12 @@ TEST(ProcPidFile, SubprocessRunning) {
EXPECT_THAT(ReadWhileRunning("uid_map", buf, sizeof(buf)),
SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("oom_score", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("oom_score_adj", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
}
// Test whether /proc/PID/ files can be read for a zombie process.
@@ -1466,6 +1465,12 @@ TEST(ProcPidFile, SubprocessZombie) {
EXPECT_THAT(ReadWhileZombied("uid_map", buf, sizeof(buf)),
SyscallSucceedsWithValue(sizeof(buf)));
+ EXPECT_THAT(ReadWhileZombied("oom_score", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileZombied("oom_score_adj", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
// on proc files.
//
@@ -1527,6 +1532,15 @@ TEST(ProcPidFile, SubprocessExited) {
EXPECT_THAT(ReadWhileExited("uid_map", buf, sizeof(buf)),
SyscallSucceedsWithValue(sizeof(buf)));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(gvisor.dev/issue/164): Succeeds on gVisor.
+ EXPECT_THAT(ReadWhileExited("oom_score", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+ }
+
+ EXPECT_THAT(ReadWhileExited("oom_score_adj", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
}
PosixError DirContainsImpl(absl::string_view path,
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index 3a611a86f..cac394910 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -33,6 +33,31 @@ namespace gvisor {
namespace testing {
namespace {
+constexpr const char kProcNet[] = "/proc/net";
+
+TEST(ProcNetSymlinkTarget, FileMode) {
+ struct stat s;
+ ASSERT_THAT(stat(kProcNet, &s), SyscallSucceeds());
+ EXPECT_EQ(s.st_mode & S_IFMT, S_IFDIR);
+ EXPECT_EQ(s.st_mode & 0777, 0555);
+}
+
+TEST(ProcNetSymlink, FileMode) {
+ struct stat s;
+ ASSERT_THAT(lstat(kProcNet, &s), SyscallSucceeds());
+ EXPECT_EQ(s.st_mode & S_IFMT, S_IFLNK);
+ EXPECT_EQ(s.st_mode & 0777, 0777);
+}
+
+TEST(ProcNetSymlink, Contents) {
+ char buf[40] = {};
+ int n = readlink(kProcNet, buf, sizeof(buf));
+ ASSERT_THAT(n, SyscallSucceeds());
+
+ buf[n] = 0;
+ EXPECT_STREQ(buf, "self/net");
+}
+
TEST(ProcNetIfInet6, Format) {
auto ifinet6 = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/if_inet6"));
EXPECT_THAT(ifinet6,
@@ -67,6 +92,59 @@ TEST(ProcSysNetIpv4Sack, CanReadAndWrite) {
EXPECT_EQ(buf, to_write);
}
+// DeviceEntry is an entry in /proc/net/dev
+struct DeviceEntry {
+ std::string name;
+ uint64_t stats[16];
+};
+
+PosixErrorOr<std::vector<DeviceEntry>> GetDeviceMetricsFromProc(
+ const std::string dev) {
+ std::vector<std::string> lines = absl::StrSplit(dev, '\n');
+ std::vector<DeviceEntry> entries;
+
+ // /proc/net/dev prints 2 lines of headers followed by a line of metrics for
+ // each network interface.
+ for (unsigned i = 2; i < lines.size(); i++) {
+ // Ignore empty lines.
+ if (lines[i].empty()) {
+ continue;
+ }
+
+ std::vector<std::string> values =
+ absl::StrSplit(lines[i], ' ', absl::SkipWhitespace());
+
+ // Interface name + 16 values.
+ if (values.size() != 17) {
+ return PosixError(EINVAL, "invalid line: " + lines[i]);
+ }
+
+ DeviceEntry entry;
+ entry.name = values[0];
+ // Skip the interface name and read only the values.
+ for (unsigned j = 1; j < 17; j++) {
+ uint64_t num;
+ if (!absl::SimpleAtoi(values[j], &num)) {
+ return PosixError(EINVAL, "invalid value: " + values[j]);
+ }
+ entry.stats[j - 1] = num;
+ }
+
+ entries.push_back(entry);
+ }
+
+ return entries;
+}
+
+// TEST(ProcNetDev, Format) tests that /proc/net/dev is parsable and
+// contains at least one entry.
+TEST(ProcNetDev, Format) {
+ auto dev = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/dev"));
+ auto entries = ASSERT_NO_ERRNO_AND_VALUE(GetDeviceMetricsFromProc(dev));
+
+ EXPECT_GT(entries.size(), 0);
+}
+
PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp,
const std::string& type,
const std::string& item) {
@@ -275,7 +353,7 @@ TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) {
EXPECT_EQ(oldNoPorts, newNoPorts - 1);
}
-TEST(ProcNetSnmp, UdpIn) {
+TEST(ProcNetSnmp, UdpIn_NoRandomSave) {
// TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
const DisableSave ds;
diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc
index 66db0acaa..a63067586 100644
--- a/test/syscalls/linux/proc_net_unix.cc
+++ b/test/syscalls/linux/proc_net_unix.cc
@@ -106,7 +106,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() {
std::vector<UnixEntry> entries;
std::vector<std::string> lines = absl::StrSplit(content, '\n');
std::cerr << "<contents of /proc/net/unix>" << std::endl;
- for (std::string line : lines) {
+ for (const std::string& line : lines) {
// Emit the proc entry to the test output to provide context for the test
// results.
std::cerr << line << std::endl;
@@ -374,7 +374,7 @@ TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) {
// corresponding entries, as they don't have an address yet.
if (IsRunningOnGvisor()) {
ASSERT_EQ(entries.size(), 2);
- for (auto e : entries) {
+ for (const auto& e : entries) {
ASSERT_EQ(e.state, SS_DISCONNECTING);
}
}
@@ -403,7 +403,7 @@ TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) {
// corresponding entries, as they don't have an address yet.
if (IsRunningOnGvisor()) {
ASSERT_EQ(entries.size(), 2);
- for (auto e : entries) {
+ for (const auto& e : entries) {
ASSERT_EQ(e.state, SS_DISCONNECTING);
}
}
diff --git a/test/syscalls/linux/proc_pid_oomscore.cc b/test/syscalls/linux/proc_pid_oomscore.cc
new file mode 100644
index 000000000..707821a3f
--- /dev/null
+++ b/test/syscalls/linux/proc_pid_oomscore.cc
@@ -0,0 +1,72 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+
+#include <exception>
+#include <iostream>
+#include <string>
+
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+PosixErrorOr<int> ReadProcNumber(std::string path) {
+ ASSIGN_OR_RETURN_ERRNO(std::string contents, GetContents(path));
+ EXPECT_EQ(contents[contents.length() - 1], '\n');
+
+ int num;
+ if (!absl::SimpleAtoi(contents, &num)) {
+ return PosixError(EINVAL, "invalid value: " + contents);
+ }
+
+ return num;
+}
+
+TEST(ProcPidOomscoreTest, BasicRead) {
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score"));
+ EXPECT_LE(oom_score, 1000);
+ EXPECT_GE(oom_score, -1000);
+}
+
+TEST(ProcPidOomscoreAdjTest, BasicRead) {
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj"));
+
+ // oom_score_adj defaults to 0.
+ EXPECT_EQ(oom_score, 0);
+}
+
+TEST(ProcPidOomscoreAdjTest, BasicWrite) {
+ constexpr int test_value = 7;
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/oom_score_adj", O_WRONLY));
+ ASSERT_THAT(
+ RetryEINTR(write)(fd.get(), std::to_string(test_value).c_str(), 1),
+ SyscallSucceeds());
+
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj"));
+ EXPECT_EQ(oom_score, test_value);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc
index 7f2e8f203..9fb1b3a2c 100644
--- a/test/syscalls/linux/proc_pid_smaps.cc
+++ b/test/syscalls/linux/proc_pid_smaps.cc
@@ -173,7 +173,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps(
return;
}
unknown_fields.insert(std::string(key));
- std::cerr << "skipping unknown smaps field " << key;
+ std::cerr << "skipping unknown smaps field " << key << std::endl;
};
auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty());
@@ -191,7 +191,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps(
// amount of whitespace).
if (!entry) {
std::cerr << "smaps line not considered a maps line: "
- << maybe_maps_entry.error_message();
+ << maybe_maps_entry.error_message() << std::endl;
return PosixError(
EINVAL,
absl::StrCat("smaps field line without preceding maps line: ", l));
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc
index bfe3e2603..926690eb8 100644
--- a/test/syscalls/linux/ptrace.cc
+++ b/test/syscalls/linux/ptrace.cc
@@ -400,9 +400,11 @@ TEST(PtraceTest, GetRegSet) {
// Read exactly the full register set.
EXPECT_EQ(iov.iov_len, sizeof(regs));
-#ifdef __x86_64__
+#if defined(__x86_64__)
// Child called kill(2), with SIGSTOP as arg 2.
EXPECT_EQ(regs.rsi, SIGSTOP);
+#elif defined(__aarch64__)
+ EXPECT_EQ(regs.regs[1], SIGSTOP);
#endif
// Suppress SIGSTOP and resume the child.
@@ -752,15 +754,23 @@ TEST(PtraceTest,
SyscallSucceeds());
EXPECT_TRUE(siginfo.si_code == SIGTRAP || siginfo.si_code == (SIGTRAP | 0x80))
<< "si_code = " << siginfo.si_code;
-#ifdef __x86_64__
+
{
struct user_regs_struct regs = {};
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+#if defined(__x86_64__)
EXPECT_TRUE(regs.orig_rax == SYS_vfork || regs.orig_rax == SYS_clone)
<< "orig_rax = " << regs.orig_rax;
EXPECT_EQ(grandchild_pid, regs.rax);
- }
+#elif defined(__aarch64__)
+ EXPECT_TRUE(regs.regs[8] == SYS_clone) << "regs[8] = " << regs.regs[8];
+ EXPECT_EQ(grandchild_pid, regs.regs[0]);
#endif // defined(__x86_64__)
+ }
// After this point, the child will be making wait4 syscalls that will be
// interrupted by saving, so saving is not permitted. Note that this is
@@ -805,14 +815,21 @@ TEST(PtraceTest,
SyscallSucceedsWithValue(child_pid));
EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80))
<< " status " << status;
-#ifdef __x86_64__
{
struct user_regs_struct regs = {};
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+#if defined(__x86_64__)
EXPECT_EQ(SYS_wait4, regs.orig_rax);
EXPECT_EQ(grandchild_pid, regs.rax);
- }
+#elif defined(__aarch64__)
+ EXPECT_EQ(SYS_wait4, regs.regs[8]);
+ EXPECT_EQ(grandchild_pid, regs.regs[0]);
#endif // defined(__x86_64__)
+ }
// Detach from the child and wait for it to exit.
ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds());
@@ -1188,7 +1205,7 @@ TEST(PtraceTest, SeizeSetOptions) {
// gVisor is not susceptible to this race because
// kernel.Task.waitCollectTraceeStopLocked() checks specifically for an
// active ptraceStop, which is not initiated if SIGKILL is pending.
- std::cout << "Observed syscall-exit after SIGKILL";
+ std::cout << "Observed syscall-exit after SIGKILL" << std::endl;
ASSERT_THAT(waitpid(child_pid, &status, 0),
SyscallSucceedsWithValue(child_pid));
}
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
index dafe64d20..b8a0159ba 100644
--- a/test/syscalls/linux/pty.cc
+++ b/test/syscalls/linux/pty.cc
@@ -1126,7 +1126,7 @@ TEST_F(PtyTest, SwitchTwiceMultiline) {
std::string kExpected = "GO\nBLUE\n!";
// Write each line.
- for (std::string input : kInputs) {
+ for (const std::string& input : kInputs) {
ASSERT_THAT(WriteFd(master_.get(), input.c_str(), input.size()),
SyscallSucceedsWithValue(input.size()));
}
diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc
index b48fe540d..e69794910 100644
--- a/test/syscalls/linux/pwrite64.cc
+++ b/test/syscalls/linux/pwrite64.cc
@@ -14,6 +14,7 @@
#include <errno.h>
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
@@ -27,14 +28,7 @@ namespace testing {
namespace {
-// This test is currently very rudimentary.
-//
-// TODO(edahlgren):
-// * bad buffer states (EFAULT).
-// * bad fds (wrong permission, wrong type of file, EBADF).
-// * check offset is not incremented.
-// * check for EOF.
-// * writing to pipes, symlinks, special files.
+// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary.
class Pwrite64 : public ::testing::Test {
void SetUp() override {
name_ = NewTempAbsPath();
@@ -72,6 +66,17 @@ TEST_F(Pwrite64, InvalidArgs) {
EXPECT_THAT(close(fd), SyscallSucceeds());
}
+TEST_F(Pwrite64, Overflow) {
+ int fd;
+ ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds());
+ constexpr int64_t kBufSize = 1024;
+ std::vector<char> buf(kBufSize);
+ std::fill(buf.begin(), buf.end(), 'a');
+ EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), 0x7fffffffffffffffull),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc
index cf6499f8b..ce88d90dd 100644
--- a/test/syscalls/linux/seccomp.cc
+++ b/test/syscalls/linux/seccomp.cc
@@ -53,7 +53,7 @@ namespace {
constexpr uint32_t kFilteredSyscall = SYS_vserver;
#elif __aarch64__
// Use the last of arch_specific_syscalls which are not implemented on arm64.
-constexpr uint32_t kFilteredSyscall = SYS_arch_specific_syscall + 15;
+constexpr uint32_t kFilteredSyscall = __NR_arch_specific_syscall + 15;
#endif
// Applies a seccomp-bpf filter that returns `filtered_result` for
@@ -70,20 +70,27 @@ void ApplySeccompFilter(uint32_t sysno, uint32_t filtered_result,
MaybeSave();
struct sock_filter filter[] = {
- // A = seccomp_data.arch
- BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4),
- // if (A != AUDIT_ARCH_X86_64) goto kill
- BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4),
- // A = seccomp_data.nr
- BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0),
- // if (A != sysno) goto allow
- BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1),
- // return filtered_result
- BPF_STMT(BPF_RET | BPF_K, filtered_result),
- // allow: return SECCOMP_RET_ALLOW
- BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW),
- // kill: return SECCOMP_RET_KILL
- BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL),
+ // A = seccomp_data.arch
+ BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4),
+#if defined(__x86_64__)
+ // if (A != AUDIT_ARCH_X86_64) goto kill
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4),
+#elif defined(__aarch64__)
+ // if (A != AUDIT_ARCH_AARCH64) goto kill
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_AARCH64, 0, 4),
+#else
+#error "Unknown architecture"
+#endif
+ // A = seccomp_data.nr
+ BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0),
+ // if (A != sysno) goto allow
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1),
+ // return filtered_result
+ BPF_STMT(BPF_RET | BPF_K, filtered_result),
+ // allow: return SECCOMP_RET_ALLOW
+ BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW),
+ // kill: return SECCOMP_RET_KILL
+ BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL),
};
struct sock_fprog prog;
prog.len = ABSL_ARRAYSIZE(filter);
@@ -179,9 +186,12 @@ TEST(SeccompTest, RetTrapCausesSIGSYS) {
TEST_CHECK(info->si_errno == kTrapValue);
TEST_CHECK(info->si_call_addr != nullptr);
TEST_CHECK(info->si_syscall == kFilteredSyscall);
-#ifdef __x86_64__
+#if defined(__x86_64__)
TEST_CHECK(info->si_arch == AUDIT_ARCH_X86_64);
TEST_CHECK(uc->uc_mcontext.gregs[REG_RAX] == kFilteredSyscall);
+#elif defined(__aarch64__)
+ TEST_CHECK(info->si_arch == AUDIT_ARCH_AARCH64);
+ TEST_CHECK(uc->uc_mcontext.regs[8] == kFilteredSyscall);
#endif // defined(__x86_64__)
_exit(0);
});
diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc
index 580ab5193..64123e904 100644
--- a/test/syscalls/linux/sendfile.cc
+++ b/test/syscalls/linux/sendfile.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/eventfd.h>
#include <sys/sendfile.h>
#include <unistd.h>
@@ -70,6 +71,28 @@ TEST(SendFileTest, InvalidOffset) {
SyscallFailsWithErrno(EINVAL));
}
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
+TEST(SendFileTest, Overflow) {
+ // Create input file.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file.
+ int fd;
+ EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds());
+ const FileDescriptor outf(fd);
+
+ // out_offset + kSize overflows INT64_MAX.
+ loff_t out_offset = 0x7ffffffffffffffeull;
+ constexpr int kSize = 3;
+ EXPECT_THAT(sendfile(outf.get(), inf.get(), &out_offset, kSize),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST(SendFileTest, SendTrivially) {
// Create temp files.
constexpr char kData[] = "To be, or not to be, that is the question:";
@@ -530,6 +553,34 @@ TEST(SendFileTest, SendToSpecialFile) {
SyscallSucceedsWithValue(kSize & (~7)));
}
+TEST(SendFileTest, SendFileToPipe) {
+ // Create temp file.
+ constexpr char kData[] = "<insert-quote-here>";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Create a pipe for sending to a pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Expect to read up to the given size.
+ std::vector<char> buf(kDataSize);
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kDataSize));
+ });
+
+ // Send with twice the size of the file, which should hit EOF.
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize * 2),
+ SyscallSucceedsWithValue(kDataSize));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc
index 8f7ee4163..c101fe9d2 100644
--- a/test/syscalls/linux/sendfile_socket.cc
+++ b/test/syscalls/linux/sendfile_socket.cc
@@ -23,6 +23,7 @@
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
@@ -35,61 +36,39 @@ namespace {
class SendFileTest : public ::testing::TestWithParam<int> {
protected:
- PosixErrorOr<std::tuple<int, int>> Sockets() {
+ PosixErrorOr<std::unique_ptr<SocketPair>> Sockets(int type) {
// Bind a server socket.
int family = GetParam();
- struct sockaddr server_addr = {};
switch (family) {
case AF_INET: {
- struct sockaddr_in* server_addr_in =
- reinterpret_cast<struct sockaddr_in*>(&server_addr);
- server_addr_in->sin_family = family;
- server_addr_in->sin_addr.s_addr = INADDR_ANY;
- break;
+ if (type == SOCK_STREAM) {
+ return SocketPairKind{
+ "TCP", AF_INET, type, 0,
+ TCPAcceptBindSocketPairCreator(AF_INET, type, 0, false)}
+ .Create();
+ } else {
+ return SocketPairKind{
+ "UDP", AF_INET, type, 0,
+ UDPBidirectionalBindSocketPairCreator(AF_INET, type, 0, false)}
+ .Create();
+ }
}
case AF_UNIX: {
- struct sockaddr_un* server_addr_un =
- reinterpret_cast<struct sockaddr_un*>(&server_addr);
- server_addr_un->sun_family = family;
- server_addr_un->sun_path[0] = '\0';
- break;
+ if (type == SOCK_STREAM) {
+ return SocketPairKind{
+ "UNIX", AF_UNIX, type, 0,
+ FilesystemAcceptBindSocketPairCreator(AF_UNIX, type, 0)}
+ .Create();
+ } else {
+ return SocketPairKind{
+ "UNIX", AF_UNIX, type, 0,
+ FilesystemBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)}
+ .Create();
+ }
}
default:
return PosixError(EINVAL);
}
- int server = socket(family, SOCK_STREAM, 0);
- if (bind(server, &server_addr, sizeof(server_addr)) < 0) {
- return PosixError(errno);
- }
- if (listen(server, 1) < 0) {
- close(server);
- return PosixError(errno);
- }
-
- // Fetch the address; both are anonymous.
- socklen_t length = sizeof(server_addr);
- if (getsockname(server, &server_addr, &length) < 0) {
- close(server);
- return PosixError(errno);
- }
-
- // Connect the client.
- int client = socket(family, SOCK_STREAM, 0);
- if (connect(client, &server_addr, length) < 0) {
- close(server);
- close(client);
- return PosixError(errno);
- }
-
- // Accept on the server.
- int server_client = accept(server, nullptr, 0);
- if (server_client < 0) {
- close(server);
- close(client);
- return PosixError(errno);
- }
- close(server);
- return std::make_tuple(client, server_client);
}
};
@@ -106,9 +85,7 @@ TEST_P(SendFileTest, SendMultiple) {
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
// Create sockets.
- std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets());
- const FileDescriptor server(std::get<0>(fds));
- FileDescriptor client(std::get<1>(fds)); // non-const, reset is used.
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM));
// Thread that reads data from socket and dumps to a file.
ScopedThread th([&] {
@@ -118,7 +95,7 @@ TEST_P(SendFileTest, SendMultiple) {
// Read until socket is closed.
char buf[10240];
for (int cnt = 0;; cnt++) {
- int r = RetryEINTR(read)(server.get(), buf, sizeof(buf));
+ int r = RetryEINTR(read)(socks->first_fd(), buf, sizeof(buf));
// We cannot afford to save on every read() call.
if (cnt % 1000 == 0) {
ASSERT_THAT(r, SyscallSucceeds());
@@ -149,10 +126,10 @@ TEST_P(SendFileTest, SendMultiple) {
for (size_t sent = 0; sent < data.size(); cnt++) {
const size_t remain = data.size() - sent;
std::cout << "sendfile, size=" << data.size() << ", sent=" << sent
- << ", remain=" << remain;
+ << ", remain=" << remain << std::endl;
// Send data and verify that sendfile returns the correct value.
- int res = sendfile(client.get(), inf.get(), nullptr, remain);
+ int res = sendfile(socks->second_fd(), inf.get(), nullptr, remain);
// We cannot afford to save on every sendfile() call.
if (cnt % 120 == 0) {
MaybeSave();
@@ -169,7 +146,7 @@ TEST_P(SendFileTest, SendMultiple) {
}
// Close socket to stop thread.
- client.reset();
+ close(socks->release_second_fd());
th.Join();
// Verify that the output file has the correct data.
@@ -183,9 +160,7 @@ TEST_P(SendFileTest, SendMultiple) {
TEST_P(SendFileTest, Shutdown) {
// Create a socket.
- std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets());
- const FileDescriptor client(std::get<0>(fds));
- FileDescriptor server(std::get<1>(fds)); // non-const, reset below.
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM));
// If this is a TCP socket, then turn off linger.
if (GetParam() == AF_INET) {
@@ -193,7 +168,7 @@ TEST_P(SendFileTest, Shutdown) {
sl.l_onoff = 1;
sl.l_linger = 0;
ASSERT_THAT(
- setsockopt(server.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)),
+ setsockopt(socks->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)),
SyscallSucceeds());
}
@@ -212,12 +187,12 @@ TEST_P(SendFileTest, Shutdown) {
ScopedThread t([&]() {
size_t done = 0;
while (done < data.size()) {
- int n = RetryEINTR(read)(server.get(), data.data(), data.size());
+ int n = RetryEINTR(read)(socks->first_fd(), data.data(), data.size());
ASSERT_THAT(n, SyscallSucceeds());
done += n;
}
// Close the server side socket.
- server.reset();
+ close(socks->release_first_fd());
});
// Continuously stream from the file to the socket. Note we do not assert
@@ -225,7 +200,7 @@ TEST_P(SendFileTest, Shutdown) {
// data is written. Eventually, we should get a connection reset error.
while (1) {
off_t offset = 0; // Always read from the start.
- int n = sendfile(client.get(), inf.get(), &offset, data.size());
+ int n = sendfile(socks->second_fd(), inf.get(), &offset, data.size());
EXPECT_THAT(n, AnyOf(SyscallFailsWithErrno(ECONNRESET),
SyscallFailsWithErrno(EPIPE), SyscallSucceeds()));
if (n <= 0) {
@@ -234,6 +209,20 @@ TEST_P(SendFileTest, Shutdown) {
}
}
+TEST_P(SendFileTest, SendpageFromEmptyFileToUDP) {
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_DGRAM));
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ // The value to the count argument has to be so that it is impossible to
+ // allocate a buffer of this size. In Linux, sendfile transfer at most
+ // 0x7ffff000 (MAX_RW_COUNT) bytes.
+ EXPECT_THAT(sendfile(socks->first_fd(), fd.get(), 0x0, 0x8000000000004),
+ SyscallSucceedsWithValue(0));
+}
+
INSTANTIATE_TEST_SUITE_P(AddressFamily, SendFileTest,
::testing::Values(AF_UNIX, AF_INET));
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index e8f24a59e..f7d6139f1 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -447,6 +447,60 @@ TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgSucceeds) {
SyscallFailsWithErrno(EAGAIN));
}
+TEST_P(AllSocketPairTest, SendTimeoutDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval tv = {.tv_sec = 89, .tv_usec = 42000};
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 89);
+ EXPECT_EQ(actual_tv.tv_usec, 42000);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval_with_extra {
+ struct timeval tv;
+ int64_t extra_data;
+ } ABSL_ATTRIBUTE_PACKED;
+
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 123000},
+ };
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &tv_extra, sizeof(tv_extra)),
+ SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv.tv_usec, 123000);
+}
+
TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -491,18 +545,36 @@ TEST_P(AllSocketPairTest, SendTimeoutAllowsSendmsg) {
ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf)));
}
-TEST_P(AllSocketPairTest, SoRcvTimeoIsSet) {
+TEST_P(AllSocketPairTest, RecvTimeoutDefault) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
- struct timeval tv {
- .tv_sec = 0, .tv_usec = 35
- };
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetRecvTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval tv = {.tv_sec = 123, .tv_usec = 456000};
EXPECT_THAT(
setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 123);
+ EXPECT_EQ(actual_tv.tv_usec, 456000);
}
-TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) {
+TEST_P(AllSocketPairTest, SetGetRecvTimeoutLargerArg) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct timeval_with_extra {
@@ -510,13 +582,21 @@ TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) {
int64_t extra_data;
} ABSL_ATTRIBUTE_PACKED;
- timeval_with_extra tv_extra;
- tv_extra.tv.tv_sec = 0;
- tv_extra.tv.tv_usec = 25;
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 432000},
+ };
EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
&tv_extra, sizeof(tv_extra)),
SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv.tv_usec, 432000);
}
TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgOneSecondSucceeds) {
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index b24618a88..d3000dbc6 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -234,7 +234,7 @@ TEST_P(DualStackSocketTest, AddressOperations) {
}
}
-// TODO(gvisor.dev/issues/1556): uncomment V4MappedAny.
+// TODO(gvisor.dev/issue/1556): uncomment V4MappedAny.
INSTANTIATE_TEST_SUITE_P(
All, DualStackSocketTest,
::testing::Combine(
@@ -319,17 +319,14 @@ TEST_P(SocketInetLoopbackTest, TCPListenUnbound) {
tcpSimpleConnectTest(listener, connector, false);
}
-TEST_P(SocketInetLoopbackTest, TCPListenClose) {
+TEST_P(SocketInetLoopbackTest, TCPListenShutdown) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
- constexpr int kAcceptCount = 32;
- constexpr int kBacklog = kAcceptCount * 2;
- constexpr int kFDs = 128;
- constexpr int kThreadCount = 4;
- constexpr int kFDsPerThread = kFDs / kThreadCount;
+ constexpr int kBacklog = 2;
+ constexpr int kFDs = kBacklog + 1;
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
@@ -348,39 +345,167 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
- DisableSave ds; // Too many system calls.
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- FileDescriptor clients[kFDs];
- std::unique_ptr<ScopedThread> threads[kThreadCount];
+
+ // Shutdown the write of the listener, expect to not have any effect.
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_WR), SyscallSucceeds());
+
for (int i = 0; i < kFDs; i++) {
- clients[i] = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds());
}
- for (int i = 0; i < kThreadCount; i++) {
- threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr,
- &clients, i]() {
- for (int j = 0; j < kFDsPerThread; j++) {
- int k = i * kFDsPerThread + j;
- int ret =
- connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len);
- if (ret != 0) {
- EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
- }
- }
- });
+
+ // Shutdown the read of the listener, expect to fail subsequent
+ // server accepts, binds and client connects.
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds());
+
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Check that shutdown did not release the port.
+ FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ bind(new_listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Check that subsequent connection attempts receive a RST.
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ for (int i = 0; i < kFDs; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallFailsWithErrno(ECONNREFUSED));
}
- for (int i = 0; i < kThreadCount; i++) {
- threads[i]->Join();
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenClose) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kAcceptCount = 2;
+ constexpr int kBacklog = kAcceptCount + 2;
+ constexpr int kFDs = kBacklog * 3;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ for (int i = 0; i < kFDs; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ }
+ clients.push_back(std::move(client));
}
for (int i = 0; i < kAcceptCount; i++) {
auto accepted =
ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
}
- // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked
- // before function end.
- // ds.reset();
+}
+
+void TestListenWhileConnect(const TestParam& param,
+ void (*stopListen)(FileDescriptor&)) {
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 2;
+ constexpr int kClients = kBacklog + 1;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ clients.push_back(std::move(client));
+ }
+ }
+
+ stopListen(listen_fd);
+
+ for (auto& client : clients) {
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = client.get(),
+ .events = POLLIN,
+ };
+ // When the listening socket is closed, then we expect the remote to reset
+ // the connection.
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR);
+ char c;
+ // Subsequent read can fail with:
+ // ECONNRESET: If the client connection was established and was reset by the
+ // remote.
+ // ECONNREFUSED: If the client connection failed to be established.
+ ASSERT_THAT(read(client.get(), &c, sizeof(c)),
+ AnyOf(SyscallFailsWithErrno(ECONNRESET),
+ SyscallFailsWithErrno(ECONNREFUSED)));
+ }
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
+ TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(close(f.release()), SyscallSucceeds());
+ });
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) {
+ TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds());
+ });
}
TEST_P(SocketInetLoopbackTest, TCPbacklog) {
@@ -605,15 +730,23 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) {
&conn_addrlen),
SyscallSucceeds());
- constexpr int kTCPLingerTimeout = 5;
- EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2,
- &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)),
- SyscallSucceedsWithValue(0));
+ // Disable cooperative saves after this point as TCP timers are not restored
+ // across a S/R.
+ {
+ DisableSave ds;
+ constexpr int kTCPLingerTimeout = 5;
+ EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2,
+ &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)),
+ SyscallSucceedsWithValue(0));
- // close the connecting FD to trigger FIN_WAIT2 on the connected fd.
- conn_fd.reset();
+ // close the connecting FD to trigger FIN_WAIT2 on the connected fd.
+ conn_fd.reset();
+
+ absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1));
- absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1));
+ // ds going out of scope will Re-enable S/R's since at this point the timer
+ // must have fired and cleaned up the endpoint.
+ }
// Now bind and connect a new socket and verify that we can immediately
// rebind the address bound by the conn_fd as it never entered TIME_WAIT.
@@ -1082,6 +1215,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
if (connects_received >= kConnectAttempts) {
// Another thread have shutdown our read side causing the
// accept to fail.
+ ASSERT_EQ(errno, EINVAL);
break;
}
ASSERT_NO_ERRNO(fd);
@@ -1149,7 +1283,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
}
-TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) {
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -1262,7 +1396,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) {
EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
}
-TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) {
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -2138,8 +2272,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) {
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(connect(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&bound_addr), bound_addr_len),
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
SyscallSucceeds());
// Get the ephemeral port.
@@ -2204,7 +2339,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) {
setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)),
SyscallSucceeds());
- std::cout << portreuse1 << " " << portreuse2;
+ std::cout << portreuse1 << " " << portreuse2 << std::endl;
int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen);
// Verify that two sockets can be bound to the same port only if
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
index 40e673625..d690d9564 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
@@ -45,37 +45,31 @@ void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() {
got_if_infos_ = false;
// Get interface list.
- std::vector<std::string> if_names;
ASSERT_NO_ERRNO(if_helper_.Load());
- if_names = if_helper_.InterfaceList(AF_INET);
+ std::vector<std::string> if_names = if_helper_.InterfaceList(AF_INET);
if (if_names.size() != 2) {
return;
}
// Figure out which interface is where.
- int lo = 0, eth = 1;
- if (if_names[lo] != "lo") {
- lo = 1;
- eth = 0;
- }
-
- if (if_names[lo] != "lo") {
- return;
- }
-
- lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[lo]));
- lo_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[lo]);
- if (lo_if_addr_ == nullptr) {
+ std::string lo = if_names[0];
+ std::string eth = if_names[1];
+ if (lo != "lo") std::swap(lo, eth);
+ if (lo != "lo") return;
+
+ lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(lo));
+ auto lo_if_addr = if_helper_.GetAddr(AF_INET, lo);
+ if (lo_if_addr == nullptr) {
return;
}
- lo_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(lo_if_addr_)->sin_addr;
+ lo_if_addr_ = *reinterpret_cast<const sockaddr_in*>(lo_if_addr);
- eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[eth]));
- eth_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[eth]);
- if (eth_if_addr_ == nullptr) {
+ eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(eth));
+ auto eth_if_addr = if_helper_.GetAddr(AF_INET, eth);
+ if (eth_if_addr == nullptr) {
return;
}
- eth_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(eth_if_addr_)->sin_addr;
+ eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr);
got_if_infos_ = true;
}
@@ -242,7 +236,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Bind the non-receiving socket to the unicast ethernet address.
auto norecv_addr = rcv1_addr;
reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr =
- eth_if_sin_addr_;
+ eth_if_addr_.sin_addr;
ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr),
norecv_addr.addr_len),
SyscallSucceedsWithValue(0));
@@ -1028,7 +1022,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreqn iface = {};
iface.imr_ifindex = lo_if_idx_;
- iface.imr_address = eth_if_sin_addr_;
+ iface.imr_address = eth_if_addr_.sin_addr;
ASSERT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
sizeof(iface)),
SyscallSucceeds());
@@ -1058,7 +1052,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
SKIP_IF(IsRunningOnGvisor());
// Verify the received source address.
- EXPECT_EQ(eth_if_sin_addr_.s_addr, src_addr_in->sin_addr.s_addr);
+ EXPECT_EQ(eth_if_addr_.sin_addr.s_addr, src_addr_in->sin_addr.s_addr);
}
// Check that when we are bound to one interface we can set IP_MULTICAST_IF to
@@ -1075,7 +1069,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Create sender and bind to eth interface.
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- ASSERT_THAT(bind(sender->get(), eth_if_addr_, sizeof(sockaddr_in)),
+ ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&eth_if_addr_),
+ sizeof(eth_if_addr_)),
SyscallSucceeds());
// Run through all possible combinations of index and address for
@@ -1085,9 +1080,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
struct in_addr imr_address;
} test_data[] = {
{lo_if_idx_, {}},
- {0, lo_if_sin_addr_},
- {lo_if_idx_, lo_if_sin_addr_},
- {lo_if_idx_, eth_if_sin_addr_},
+ {0, lo_if_addr_.sin_addr},
+ {lo_if_idx_, lo_if_addr_.sin_addr},
+ {lo_if_idx_, eth_if_addr_.sin_addr},
};
for (auto t : test_data) {
ip_mreqn iface = {};
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
index bec2e96ee..10b90b1e0 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
@@ -36,10 +36,8 @@ class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest {
// Interface infos.
int lo_if_idx_;
int eth_if_idx_;
- sockaddr* lo_if_addr_;
- sockaddr* eth_if_addr_;
- in_addr lo_if_sin_addr_;
- in_addr eth_if_sin_addr_;
+ sockaddr_in lo_if_addr_;
+ sockaddr_in eth_if_addr_;
};
} // namespace testing
diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc
index e5aed1eec..fbe61c5a0 100644
--- a/test/syscalls/linux/socket_netlink_route.cc
+++ b/test/syscalls/linux/socket_netlink_route.cc
@@ -26,7 +26,7 @@
#include "gtest/gtest.h"
#include "absl/strings/str_format.h"
-#include "absl/types/optional.h"
+#include "test/syscalls/linux/socket_netlink_route_util.h"
#include "test/syscalls/linux/socket_netlink_util.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/capability_util.h"
@@ -118,24 +118,6 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) {
// TODO(mpratt): Check ifinfomsg contents and following attrs.
}
-PosixError DumpLinks(
- const FileDescriptor& fd, uint32_t seq,
- const std::function<void(const struct nlmsghdr* hdr)>& fn) {
- struct request {
- struct nlmsghdr hdr;
- struct ifinfomsg ifm;
- };
-
- struct request req = {};
- req.hdr.nlmsg_len = sizeof(req);
- req.hdr.nlmsg_type = RTM_GETLINK;
- req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
- req.hdr.nlmsg_seq = seq;
- req.ifm.ifi_family = AF_UNSPEC;
-
- return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false);
-}
-
TEST(NetlinkRouteTest, GetLinkDump) {
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
@@ -152,7 +134,7 @@ TEST(NetlinkRouteTest, GetLinkDump) {
const struct ifinfomsg* msg =
reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
std::cout << "Found interface idx=" << msg->ifi_index
- << ", type=" << std::hex << msg->ifi_type;
+ << ", type=" << std::hex << msg->ifi_type << std::endl;
if (msg->ifi_type == ARPHRD_LOOPBACK) {
loopbackFound = true;
EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0);
@@ -161,37 +143,6 @@ TEST(NetlinkRouteTest, GetLinkDump) {
EXPECT_TRUE(loopbackFound);
}
-struct Link {
- int index;
- std::string name;
-};
-
-PosixErrorOr<absl::optional<Link>> FindLoopbackLink() {
- ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
-
- absl::optional<Link> link;
- RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) {
- if (hdr->nlmsg_type != RTM_NEWLINK ||
- hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) {
- return;
- }
- const struct ifinfomsg* msg =
- reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
- if (msg->ifi_type == ARPHRD_LOOPBACK) {
- const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME);
- if (rta == nullptr) {
- // Ignore links that do not have a name.
- return;
- }
-
- link = Link();
- link->index = msg->ifi_index;
- link->name = std::string(reinterpret_cast<const char*>(RTA_DATA(rta)));
- }
- }));
- return link;
-}
-
// CheckLinkMsg checks a netlink message against an expected link.
void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) {
ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK));
@@ -209,9 +160,7 @@ void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) {
}
TEST(NetlinkRouteTest, GetLinkByIndex) {
- absl::optional<Link> loopback_link =
- ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink());
- ASSERT_TRUE(loopback_link.has_value());
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
@@ -227,13 +176,13 @@ TEST(NetlinkRouteTest, GetLinkByIndex) {
req.hdr.nlmsg_flags = NLM_F_REQUEST;
req.hdr.nlmsg_seq = kSeq;
req.ifm.ifi_family = AF_UNSPEC;
- req.ifm.ifi_index = loopback_link->index;
+ req.ifm.ifi_index = loopback_link.index;
bool found = false;
ASSERT_NO_ERRNO(NetlinkRequestResponse(
fd, &req, sizeof(req),
[&](const struct nlmsghdr* hdr) {
- CheckLinkMsg(hdr, *loopback_link);
+ CheckLinkMsg(hdr, loopback_link);
found = true;
},
false));
@@ -241,9 +190,7 @@ TEST(NetlinkRouteTest, GetLinkByIndex) {
}
TEST(NetlinkRouteTest, GetLinkByName) {
- absl::optional<Link> loopback_link =
- ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink());
- ASSERT_TRUE(loopback_link.has_value());
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
@@ -262,8 +209,8 @@ TEST(NetlinkRouteTest, GetLinkByName) {
req.hdr.nlmsg_seq = kSeq;
req.ifm.ifi_family = AF_UNSPEC;
req.rtattr.rta_type = IFLA_IFNAME;
- req.rtattr.rta_len = RTA_LENGTH(loopback_link->name.size() + 1);
- strncpy(req.ifname, loopback_link->name.c_str(), sizeof(req.ifname));
+ req.rtattr.rta_len = RTA_LENGTH(loopback_link.name.size() + 1);
+ strncpy(req.ifname, loopback_link.name.c_str(), sizeof(req.ifname));
req.hdr.nlmsg_len =
NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len);
@@ -271,7 +218,7 @@ TEST(NetlinkRouteTest, GetLinkByName) {
ASSERT_NO_ERRNO(NetlinkRequestResponse(
fd, &req, sizeof(req),
[&](const struct nlmsghdr* hdr) {
- CheckLinkMsg(hdr, *loopback_link);
+ CheckLinkMsg(hdr, loopback_link);
found = true;
},
false));
@@ -523,9 +470,7 @@ TEST(NetlinkRouteTest, LookupAll) {
TEST(NetlinkRouteTest, AddAddr) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
- absl::optional<Link> loopback_link =
- ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink());
- ASSERT_TRUE(loopback_link.has_value());
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
@@ -545,7 +490,7 @@ TEST(NetlinkRouteTest, AddAddr) {
req.ifa.ifa_prefixlen = 24;
req.ifa.ifa_flags = 0;
req.ifa.ifa_scope = 0;
- req.ifa.ifa_index = loopback_link->index;
+ req.ifa.ifa_index = loopback_link.index;
req.rtattr.rta_type = IFA_LOCAL;
req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr));
inet_pton(AF_INET, "10.0.0.1", &req.addr);
diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc
index 53eb3b6b2..bde1dbb4d 100644
--- a/test/syscalls/linux/socket_netlink_route_util.cc
+++ b/test/syscalls/linux/socket_netlink_route_util.cc
@@ -18,7 +18,6 @@
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
-#include "absl/types/optional.h"
#include "test/syscalls/linux/socket_netlink_util.h"
namespace gvisor {
@@ -73,14 +72,14 @@ PosixErrorOr<std::vector<Link>> DumpLinks() {
return links;
}
-PosixErrorOr<absl::optional<Link>> FindLoopbackLink() {
+PosixErrorOr<Link> LoopbackLink() {
ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
for (const auto& link : links) {
if (link.type == ARPHRD_LOOPBACK) {
- return absl::optional<Link>(link);
+ return link;
}
}
- return absl::optional<Link>();
+ return PosixError(ENOENT, "loopback link not found");
}
PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h
index 2c018e487..149c4a7f6 100644
--- a/test/syscalls/linux/socket_netlink_route_util.h
+++ b/test/syscalls/linux/socket_netlink_route_util.h
@@ -20,7 +20,6 @@
#include <vector>
-#include "absl/types/optional.h"
#include "test/syscalls/linux/socket_netlink_util.h"
namespace gvisor {
@@ -37,7 +36,8 @@ PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq,
PosixErrorOr<std::vector<Link>> DumpLinks();
-PosixErrorOr<absl::optional<Link>> FindLoopbackLink();
+// Returns the loopback link on the system. ENOENT if not found.
+PosixErrorOr<Link> LoopbackLink();
// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface.
PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index 5d3a39868..53b678e94 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -364,11 +364,6 @@ CreateTCPConnectAcceptSocketPair(int bound, int connected, int type,
}
MaybeSave(); // Successful accept.
- // FIXME(b/110484944)
- if (connect_result == -1) {
- absl::SleepFor(absl::Seconds(1));
- }
-
T extra_addr = {};
LocalhostAddr(&extra_addr, dual_stack);
return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr,
diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc
index 4cf1f76f1..8bf663e8b 100644
--- a/test/syscalls/linux/socket_unix.cc
+++ b/test/syscalls/linux/socket_unix.cc
@@ -257,6 +257,8 @@ TEST_P(UnixSocketPairTest, ShutdownWrite) {
TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) {
// TODO(b/122310852): We should be returning ENXIO and NOT EIO.
+ // TODO(github.dev/issue/1624): This should be resolved in VFS2. Verify
+ // that this is the case and delete the SKIP_IF once we delete VFS1.
SKIP_IF(IsRunningOnGvisor());
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index faa1247f6..f103e2e56 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/eventfd.h>
#include <sys/resource.h>
#include <sys/sendfile.h>
diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc
index c951ac3b3..2503960f3 100644
--- a/test/syscalls/linux/stat.cc
+++ b/test/syscalls/linux/stat.cc
@@ -34,6 +34,13 @@
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
+#ifndef AT_STATX_FORCE_SYNC
+#define AT_STATX_FORCE_SYNC 0x2000
+#endif
+#ifndef AT_STATX_DONT_SYNC
+#define AT_STATX_DONT_SYNC 0x4000
+#endif
+
namespace gvisor {
namespace testing {
@@ -607,7 +614,7 @@ int statx(int dirfd, const char* pathname, int flags, unsigned int mask,
}
TEST_F(StatTest, StatxAbsPath) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
struct kernel_statx stx;
@@ -617,7 +624,7 @@ TEST_F(StatTest, StatxAbsPath) {
}
TEST_F(StatTest, StatxRelPathDirFD) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
struct kernel_statx stx;
@@ -631,7 +638,7 @@ TEST_F(StatTest, StatxRelPathDirFD) {
}
TEST_F(StatTest, StatxRelPathCwd) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds());
@@ -643,7 +650,7 @@ TEST_F(StatTest, StatxRelPathCwd) {
}
TEST_F(StatTest, StatxEmptyPath) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY));
@@ -653,6 +660,60 @@ TEST_F(StatTest, StatxEmptyPath) {
EXPECT_TRUE(S_ISREG(stx.stx_mode));
}
+TEST_F(StatTest, StatxDoesNotRejectExtraneousMaskBits) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ // Set all mask bits except for STATX__RESERVED.
+ uint mask = 0xffffffff & ~0x80000000;
+ EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, mask, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxRejectsReservedMaskBit) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ // Set STATX__RESERVED in the mask.
+ EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, 0x80000000, &stx),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(StatTest, StatxSymlink) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ std::string parent_dir = "/tmp";
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(parent_dir, test_file_name_));
+ std::string p = link.path();
+
+ struct kernel_statx stx;
+ EXPECT_THAT(statx(AT_FDCWD, p.c_str(), AT_SYMLINK_NOFOLLOW, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISLNK(stx.stx_mode));
+ EXPECT_THAT(statx(AT_FDCWD, p.c_str(), 0, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxInvalidFlags) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(), 12345, 0, &stx),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Sync flags are mutually exclusive.
+ EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(),
+ AT_STATX_FORCE_SYNC | AT_STATX_DONT_SYNC, 0, &stx),
+ SyscallFailsWithErrno(EINVAL));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc
index 7e73325bf..92eec0449 100644
--- a/test/syscalls/linux/sticky.cc
+++ b/test/syscalls/linux/sticky.cc
@@ -42,8 +42,9 @@ TEST(StickyTest, StickyBitPermDenied) {
auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
- std::string path = JoinPath(dir.path(), "NewDir");
- ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds());
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY));
+ ASSERT_THAT(mkdirat(dirfd.get(), "NewDir", 0755), SyscallSucceeds());
// Drop privileges and change IDs only in child thread, or else this parent
// thread won't be able to open some log files after the test ends.
@@ -61,7 +62,8 @@ TEST(StickyTest, StickyBitPermDenied) {
syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1),
SyscallSucceeds());
- EXPECT_THAT(rmdir(path.c_str()), SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(unlinkat(dirfd.get(), "NewDir", AT_REMOVEDIR),
+ SyscallFailsWithErrno(EPERM));
});
}
@@ -96,8 +98,9 @@ TEST(StickyTest, StickyBitCapFOWNER) {
auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
- std::string path = JoinPath(dir.path(), "NewDir");
- ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds());
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY));
+ ASSERT_THAT(mkdirat(dirfd.get(), "NewDir", 0755), SyscallSucceeds());
// Drop privileges and change IDs only in child thread, or else this parent
// thread won't be able to open some log files after the test ends.
@@ -114,7 +117,8 @@ TEST(StickyTest, StickyBitCapFOWNER) {
SyscallSucceeds());
EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true));
- EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds());
+ EXPECT_THAT(unlinkat(dirfd.get(), "NewDir", AT_REMOVEDIR),
+ SyscallSucceeds());
});
}
} // namespace
diff --git a/test/syscalls/linux/sysret.cc b/test/syscalls/linux/sysret.cc
index 819fa655a..19ffbd85b 100644
--- a/test/syscalls/linux/sysret.cc
+++ b/test/syscalls/linux/sysret.cc
@@ -14,6 +14,8 @@
// Tests to verify that the behavior of linux and gvisor matches when
// 'sysret' returns to bad (aka non-canonical) %rip or %rsp.
+
+#include <linux/elf.h>
#include <sys/ptrace.h>
#include <sys/user.h>
@@ -32,6 +34,7 @@ constexpr uint64_t kNonCanonicalRsp = 0xFFFF000000000000;
class SysretTest : public ::testing::Test {
protected:
struct user_regs_struct regs_;
+ struct iovec iov;
pid_t child_;
void SetUp() override {
@@ -48,10 +51,15 @@ class SysretTest : public ::testing::Test {
// Parent.
int status;
+ memset(&iov, 0, sizeof(iov));
ASSERT_THAT(pid, SyscallSucceeds()); // Might still be < 0.
ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP);
- ASSERT_THAT(ptrace(PTRACE_GETREGS, pid, 0, &regs_), SyscallSucceeds());
+
+ iov.iov_base = &regs_;
+ iov.iov_len = sizeof(regs_);
+ ASSERT_THAT(ptrace(PTRACE_GETREGSET, pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
child_ = pid;
}
@@ -61,13 +69,27 @@ class SysretTest : public ::testing::Test {
}
void SetRip(uint64_t newrip) {
+#if defined(__x86_64__)
regs_.rip = newrip;
- ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, &regs_), SyscallSucceeds());
+#elif defined(__aarch64__)
+ regs_.pc = newrip;
+#else
+#error "Unknown architecture"
+#endif
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
}
void SetRsp(uint64_t newrsp) {
+#if defined(__x86_64__)
regs_.rsp = newrsp;
- ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, &regs_), SyscallSucceeds());
+#elif defined(__aarch64__)
+ regs_.sp = newrsp;
+#else
+#error "Unknown architecture"
+#endif
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
}
// Wait waits for the child pid and returns the exit status.
@@ -104,8 +126,15 @@ TEST_F(SysretTest, BadRsp) {
SetRsp(kNonCanonicalRsp);
Detach();
int status = Wait();
+#if defined(__x86_64__)
EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGBUS)
<< "status = " << status;
+#elif defined(__aarch64__)
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV)
+ << "status = " << status;
+#else
+#error "Unknown architecture"
+#endif
}
} // namespace
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index c4591a3b9..d9c1ac0e1 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -143,6 +143,20 @@ TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) {
SyscallFailsWithErrno(EISCONN));
}
+TEST_P(TcpSocketTest, ShutdownWriteInTimeWait) {
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ EXPECT_THAT(shutdown(s_, SHUT_RDWR), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1)); // Wait to enter TIME_WAIT.
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(TcpSocketTest, ShutdownWriteInFinWait1) {
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1)); // Wait to enter FIN-WAIT2.
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+}
+
TEST_P(TcpSocketTest, DataCoalesced) {
char buf[10];
@@ -1349,6 +1363,21 @@ TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) {
SyscallFailsWithErrno(ENOTCONN));
}
+TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ addrlen);
+ int buf_sz = 1 << 18;
+ EXPECT_THAT(
+ setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)),
+ SyscallSucceedsWithValue(0));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));
diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc
index 1ccb95733..e75bba669 100644
--- a/test/syscalls/linux/time.cc
+++ b/test/syscalls/linux/time.cc
@@ -26,6 +26,7 @@ namespace {
constexpr long kFudgeSeconds = 5;
+#if defined(__x86_64__) || defined(__i386__)
// Mimics the time(2) wrapper from glibc prior to 2.15.
time_t vsyscall_time(time_t* t) {
constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400;
@@ -98,6 +99,7 @@ TEST(TimeTest, VsyscallGettimeofday_InvalidAddressSIGSEGV) {
reinterpret_cast<struct timezone*>(0x1)),
::testing::KilledBySignal(SIGSEGV), "");
}
+#endif
} // namespace
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc
index f6ac9d7b8..6195b11e1 100644
--- a/test/syscalls/linux/tuntap.cc
+++ b/test/syscalls/linux/tuntap.cc
@@ -56,14 +56,14 @@ PosixErrorOr<std::set<std::string>> DumpLinkNames() {
return names;
}
-PosixErrorOr<absl::optional<Link>> GetLinkByName(const std::string& name) {
+PosixErrorOr<Link> GetLinkByName(const std::string& name) {
ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
for (const auto& link : links) {
if (link.name == name) {
- return absl::optional<Link>(link);
+ return link;
}
}
- return absl::optional<Link>();
+ return PosixError(ENOENT, "interface not found");
}
struct pihdr {
@@ -153,6 +153,13 @@ std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
} // namespace
+TEST(TuntapStaticTest, NetTunExists) {
+ struct stat statbuf;
+ ASSERT_THAT(stat(kDevNetTun, &statbuf), SyscallSucceeds());
+ // Check that it's a character device with rw-rw-rw- permissions.
+ EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
+}
+
class TuntapTest : public ::testing::Test {
protected:
void TearDown() override {
@@ -235,7 +242,7 @@ TEST_F(TuntapTest, InvalidReadWrite) {
TEST_F(TuntapTest, WriteToDownDevice) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
- // FIXME: gVisor always creates enabled/up'd interfaces.
+ // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces.
SKIP_IF(IsRunningOnGvisor());
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
@@ -249,6 +256,38 @@ TEST_F(TuntapTest, WriteToDownDevice) {
EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO));
}
+PosixErrorOr<FileDescriptor> OpenAndAttachTap(
+ const std::string& dev_name, const std::string& dev_ipv4_addr) {
+ // Interface creation.
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr_set = {};
+ ifr_set.ifr_flags = IFF_TAP;
+ strncpy(ifr_set.ifr_name, dev_name.c_str(), IFNAMSIZ);
+ if (ioctl(fd.get(), TUNSETIFF, &ifr_set) < 0) {
+ return PosixError(errno);
+ }
+
+ ASSIGN_OR_RETURN_ERRNO(auto link, GetLinkByName(dev_name));
+
+ // Interface setup.
+ struct in_addr addr;
+ inet_pton(AF_INET, dev_ipv4_addr.c_str(), &addr);
+ EXPECT_NO_ERRNO(LinkAddLocalAddr(link.index, AF_INET, /*prefixlen=*/24, &addr,
+ sizeof(addr)));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(b/110961832): gVisor doesn't support setting MAC address on
+ // interfaces yet.
+ RETURN_IF_ERRNO(LinkSetMacAddr(link.index, kMacA, sizeof(kMacA)));
+
+ // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces.
+ RETURN_IF_ERRNO(LinkChangeFlags(link.index, IFF_UP, IFF_UP));
+ }
+
+ return fd;
+}
+
// This test sets up a TAP device and pings kernel by sending ICMP echo request.
//
// It works as the following:
@@ -266,33 +305,8 @@ TEST_F(TuntapTest, WriteToDownDevice) {
TEST_F(TuntapTest, PingKernel) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
- // Interface creation.
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
-
- struct ifreq ifr_set = {};
- ifr_set.ifr_flags = IFF_TAP;
- strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ);
- EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set),
- SyscallSucceedsWithValue(0));
-
- absl::optional<Link> link =
- ASSERT_NO_ERRNO_AND_VALUE(GetLinkByName(kTapName));
- ASSERT_TRUE(link.has_value());
-
- // Interface setup.
- struct in_addr addr;
- inet_pton(AF_INET, "10.0.0.1", &addr);
- EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24,
- &addr, sizeof(addr)));
-
- if (!IsRunningOnGvisor()) {
- // FIXME: gVisor doesn't support setting MAC address on interfaces yet.
- EXPECT_NO_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA)));
-
- // FIXME: gVisor always creates enabled/up'd interfaces.
- EXPECT_NO_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP));
- }
-
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1"));
ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
@@ -342,5 +356,47 @@ TEST_F(TuntapTest, PingKernel) {
}
}
+TEST_F(TuntapTest, SendUdpTriggersArpResolution) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1"));
+
+ // Send a UDP packet to remote.
+ int sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_IP);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_in remote = {};
+ remote.sin_family = AF_INET;
+ remote.sin_port = htons(42);
+ inet_pton(AF_INET, "10.0.0.2", &remote.sin_addr);
+ int ret = sendto(sock, "hello", 5, 0, reinterpret_cast<sockaddr*>(&remote),
+ sizeof(remote));
+ ASSERT_THAT(ret, ::testing::AnyOf(SyscallSucceeds(),
+ SyscallFailsWithErrno(EHOSTDOWN)));
+
+ struct inpkt {
+ union {
+ pihdr pi;
+ arp_pkt arp;
+ };
+ };
+ while (1) {
+ inpkt r = {};
+ int n = read(fd.get(), &r, sizeof(r));
+ EXPECT_THAT(n, SyscallSucceeds());
+
+ if (n < sizeof(pihdr)) {
+ std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
+ << " len: " << n << std::endl;
+ continue;
+ }
+
+ if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) {
+ break;
+ }
+ }
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/tuntap_hostinet.cc b/test/syscalls/linux/tuntap_hostinet.cc
new file mode 100644
index 000000000..1513fb9d5
--- /dev/null
+++ b/test/syscalls/linux/tuntap_hostinet.cc
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(TuntapHostInetTest, NoNetTun) {
+ SKIP_IF(!IsRunningOnGvisor());
+ SKIP_IF(!IsRunningWithHostinet());
+
+ struct stat statbuf;
+ ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallFailsWithErrno(ENOENT));
+}
+
+} // namespace
+} // namespace testing
+
+} // namespace gvisor
diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc
index 3a927a430..22e6d1a85 100644
--- a/test/syscalls/linux/utimes.cc
+++ b/test/syscalls/linux/utimes.cc
@@ -34,17 +34,10 @@ namespace testing {
namespace {
-// TODO(b/36516566): utimes(nullptr) does not pick the "now" time in the
-// application's time domain, so when asserting that times are within a window,
-// we expand the window to allow for differences between the time domains.
-constexpr absl::Duration kClockSlack = absl::Milliseconds(100);
-
// TimeBoxed runs fn, setting before and after to (coarse realtime) times
// guaranteed* to come before and after fn started and completed, respectively.
//
// fn may be called more than once if the clock is adjusted.
-//
-// * See the comment on kClockSlack. gVisor breaks this guarantee.
void TimeBoxed(absl::Time* before, absl::Time* after,
std::function<void()> const& fn) {
do {
@@ -69,12 +62,6 @@ void TimeBoxed(absl::Time* before, absl::Time* after,
// which could lead to test failures, but that is very unlikely to happen.
continue;
}
-
- if (IsRunningOnGvisor()) {
- // See comment on kClockSlack.
- *before -= kClockSlack;
- *after += kClockSlack;
- }
} while (*after < *before);
}
@@ -235,10 +222,7 @@ void TestUtimensat(int dirFd, std::string const& path) {
EXPECT_GE(mtime3, before);
EXPECT_LE(mtime3, after);
- if (!IsRunningOnGvisor()) {
- // FIXME(b/36516566): Gofers set atime and mtime to different "now" times.
- EXPECT_EQ(atime3, mtime3);
- }
+ EXPECT_EQ(atime3, mtime3);
}
TEST(UtimensatTest, OnAbsPath) {
diff --git a/test/syscalls/linux/vsyscall.cc b/test/syscalls/linux/vsyscall.cc
index 2c2303358..ae4377108 100644
--- a/test/syscalls/linux/vsyscall.cc
+++ b/test/syscalls/linux/vsyscall.cc
@@ -24,6 +24,7 @@ namespace testing {
namespace {
+#if defined(__x86_64__) || defined(__i386__)
time_t vsyscall_time(time_t* t) {
constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400;
return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t);
@@ -37,6 +38,7 @@ TEST(VsyscallTest, VsyscallAlwaysAvailableOnGvisor) {
time_t t;
EXPECT_THAT(vsyscall_time(&t), SyscallSucceeds());
}
+#endif
} // namespace
diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc
index 9b219cfd6..39b5b2f56 100644
--- a/test/syscalls/linux/write.cc
+++ b/test/syscalls/linux/write.cc
@@ -31,14 +31,8 @@ namespace gvisor {
namespace testing {
namespace {
-// This test is currently very rudimentary.
-//
-// TODO(edahlgren):
-// * bad buffer states (EFAULT).
-// * bad fds (wrong permission, wrong type of file, EBADF).
-// * check offset is incremented.
-// * check for EOF.
-// * writing to pipes, symlinks, special files.
+
+// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary.
class WriteTest : public ::testing::Test {
public:
ssize_t WriteBytes(int fd, int bytes) {
diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc
index 8b00ef44c..3231732ec 100644
--- a/test/syscalls/linux/xattr.cc
+++ b/test/syscalls/linux/xattr.cc
@@ -41,12 +41,12 @@ class XattrTest : public FileTest {};
TEST_F(XattrTest, XattrNonexistentFile) {
const char* path = "/does/not/exist";
- EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0),
- SyscallFailsWithErrno(ENOENT));
- EXPECT_THAT(getxattr(path, nullptr, nullptr, 0),
+ const char* name = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0),
SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENOENT));
EXPECT_THAT(listxattr(path, nullptr, 0), SyscallFailsWithErrno(ENOENT));
- EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENOENT));
}
TEST_F(XattrTest, XattrNullName) {
diff --git a/test/util/BUILD b/test/util/BUILD
index 8b5a0f25c..2a17c33ee 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -350,3 +350,9 @@ cc_library(
":save_util",
],
)
+
+cc_library(
+ name = "temp_umask",
+ testonly = 1,
+ hdrs = ["temp_umask.h"],
+)
diff --git a/test/util/capability_util.cc b/test/util/capability_util.cc
index 9fee52fbb..a1b994c45 100644
--- a/test/util/capability_util.cc
+++ b/test/util/capability_util.cc
@@ -63,13 +63,13 @@ PosixErrorOr<bool> CanCreateUserNamespace() {
// is in a chroot environment (i.e., the caller's root directory does
// not match the root directory of the mount namespace in which it
// resides)."
- std::cerr << "clone(CLONE_NEWUSER) failed with EPERM";
+ std::cerr << "clone(CLONE_NEWUSER) failed with EPERM" << std::endl;
return false;
} else if (errno == EUSERS) {
// "(since Linux 3.11) CLONE_NEWUSER was specified in flags, and the call
// would cause the limit on the number of nested user namespaces to be
// exceeded. See user_namespaces(7)."
- std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS";
+ std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS" << std::endl;
return false;
} else {
// Unexpected error code; indicate an actual error.
diff --git a/test/syscalls/linux/temp_umask.h b/test/util/temp_umask.h
index 81a25440c..e7de84a54 100644
--- a/test/syscalls/linux/temp_umask.h
+++ b/test/util/temp_umask.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_
-#define GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_
+#ifndef GVISOR_TEST_UTIL_TEMP_UMASK_H_
+#define GVISOR_TEST_UTIL_TEMP_UMASK_H_
#include <sys/stat.h>
#include <sys/types.h>
@@ -36,4 +36,4 @@ class TempUmask {
} // namespace testing
} // namespace gvisor
-#endif // GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_
+#endif // GVISOR_TEST_UTIL_TEMP_UMASK_H_
diff --git a/tools/BUILD b/tools/BUILD
index e73a9c885..ba3506c04 100644
--- a/tools/BUILD
+++ b/tools/BUILD
@@ -1,3 +1,3 @@
package(licenses = ["notice"])
-exports_files(["nogo.js"])
+exports_files(["nogo.json"])
diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl
index 905b16d41..0a74370a6 100644
--- a/tools/bazeldefs/defs.bzl
+++ b/tools/bazeldefs/defs.bzl
@@ -2,15 +2,15 @@
load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier")
load("@io_bazel_rules_go//go:def.bzl", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_test = "go_test", _go_tool_library = "go_tool_library")
-load("@io_bazel_rules_go//proto:def.bzl", _go_proto_library = "go_proto_library")
+load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", _go_proto_library = "go_proto_library")
load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test")
load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image")
load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image")
load("@pydeps//:requirements.bzl", _py_requirement = "requirement")
+load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library")
container_image = _container_image
-cc_binary = _cc_binary
cc_library = _cc_library
cc_flags_supplier = _cc_flags_supplier
cc_proto_library = _cc_proto_library
@@ -19,16 +19,77 @@ cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
go_image = _go_image
go_embed_data = _go_embed_data
gtest = "@com_google_googletest//:gtest"
+grpcpp = "@com_github_grpc_grpc//:grpc++"
gbenchmark = "@com_google_benchmark//:benchmark"
loopback = "//tools/bazeldefs:loopback"
-proto_library = native.proto_library
pkg_deb = _pkg_deb
pkg_tar = _pkg_tar
py_library = native.py_library
py_binary = native.py_binary
py_test = native.py_test
+def proto_library(name, has_services = None, **kwargs):
+ native.proto_library(
+ name = name,
+ **kwargs
+ )
+
+def cc_grpc_library(name, **kwargs):
+ _cc_grpc_library(name = name, grpc_only = True, **kwargs)
+
+def _go_proto_or_grpc_library(go_library_func, name, **kwargs):
+ deps = [
+ dep.replace("_proto", "_go_proto")
+ for dep in (kwargs.pop("deps", []) or [])
+ ]
+ go_library_func(
+ name = name + "_go_proto",
+ importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name + "_go_proto",
+ proto = ":" + name + "_proto",
+ deps = deps,
+ **kwargs
+ )
+
+def go_proto_library(name, **kwargs):
+ _go_proto_or_grpc_library(_go_proto_library, name, **kwargs)
+
+def go_grpc_and_proto_libraries(name, **kwargs):
+ _go_proto_or_grpc_library(_go_grpc_library, name, **kwargs)
+
+def cc_binary(name, static = False, **kwargs):
+ """Run cc_binary.
+
+ Args:
+ name: name of the target.
+ static: make a static binary if True
+ **kwargs: the rest of the args.
+ """
+ if static:
+ # How to statically link a c++ program that uses threads, like for gRPC:
+ # https://gcc.gnu.org/legacy-ml/gcc-help/2010-05/msg00029.html
+ if "linkopts" not in kwargs:
+ kwargs["linkopts"] = []
+ kwargs["linkopts"] += [
+ "-static",
+ "-lstdc++",
+ "-Wl,--whole-archive",
+ "-lpthread",
+ "-Wl,--no-whole-archive",
+ ]
+ _cc_binary(
+ name = name,
+ **kwargs
+ )
+
def go_binary(name, static = False, pure = False, **kwargs):
+ """Build a go binary.
+
+ Args:
+ name: name of the target.
+ static: build a static binary.
+ pure: build without cgo.
+ **kwargs: rest of the arguments are passed to _go_binary.
+ """
if static:
kwargs["static"] = "on"
if pure:
@@ -52,18 +113,17 @@ def go_tool_library(name, **kwargs):
**kwargs
)
-def go_proto_library(name, proto, **kwargs):
- deps = kwargs.pop("deps", [])
- _go_proto_library(
- name = name,
- importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name,
- proto = proto,
- deps = [dep.replace("_proto", "_go_proto") for dep in deps],
- **kwargs
- )
+def go_test(name, pure = False, library = None, **kwargs):
+ """Build a go test.
-def go_test(name, **kwargs):
- library = kwargs.pop("library", None)
+ Args:
+ name: name of the output binary.
+ pure: should it be built without cgo.
+ library: the library to embed.
+ **kwargs: rest of the arguments to pass to _go_test.
+ """
+ if pure:
+ kwargs["pure"] = "on"
if library:
kwargs["embed"] = [library]
_go_test(
diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl
index 92b0b5fc0..132040c20 100644
--- a/tools/bazeldefs/platforms.bzl
+++ b/tools/bazeldefs/platforms.bzl
@@ -2,15 +2,10 @@
# Platform to associated tags.
platforms = {
- "ptrace": [
- # TODO(b/120560048): Make the tests run without this tag.
- "no-sandbox",
- ],
+ "ptrace": [],
"kvm": [
"manual",
"local",
- # TODO(b/120560048): Make the tests run without this tag.
- "no-sandbox",
],
}
diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD
new file mode 100644
index 000000000..5748fb390
--- /dev/null
+++ b/tools/bigquery/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "bigquery",
+ testonly = 1,
+ srcs = ["bigquery.go"],
+ deps = ["@com_google_cloud_go_bigquery//:go_default_library"],
+)
diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go
new file mode 100644
index 000000000..56f0dc5c9
--- /dev/null
+++ b/tools/bigquery/bigquery.go
@@ -0,0 +1,121 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package bigquery defines a BigQuery schema for benchmarks.
+//
+// This package contains a schema for BigQuery and methods for publishing
+// benchmark data into tables.
+package bigquery
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "time"
+
+ bq "cloud.google.com/go/bigquery"
+)
+
+// Benchmark is the top level structure of recorded benchmark data. BigQuery
+// will infer the schema from this.
+type Benchmark struct {
+ Name string `bq:"name"`
+ Timestamp time.Time `bq:"timestamp"`
+ Official bool `bq:"official"`
+ Metric []*Metric `bq:"metric"`
+ Metadata *Metadata `bq:"metadata"`
+}
+
+// Metric holds the actual metric data and unit information for this benchmark.
+type Metric struct {
+ Name string `bq:"name"`
+ Unit string `bq:"unit"`
+ Sample float64 `bq:"sample"`
+}
+
+// Metadata about this benchmark.
+type Metadata struct {
+ CL string `bq:"changelist"`
+ IterationID string `bq:"iteration_id"`
+ PendingCL string `bq:"pending_cl"`
+ Workflow string `bq:"workflow"`
+ Platform string `bq:"platform"`
+ Gofer string `bq:"gofer"`
+}
+
+// InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated.
+func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) error {
+ client, err := bq.NewClient(ctx, projectID)
+ if err != nil {
+ return fmt.Errorf("failed to initialize client on project %s: %v", projectID, err)
+ }
+ defer client.Close()
+
+ dataset := client.Dataset(datasetID)
+ if err := dataset.Create(ctx, nil); err != nil && !checkDuplicateError(err) {
+ return fmt.Errorf("failed to create dataset: %s: %v", datasetID, err)
+ }
+
+ table := dataset.Table(tableID)
+ schema, err := bq.InferSchema(Benchmark{})
+ if err != nil {
+ return fmt.Errorf("failed to infer schema: %v", err)
+ }
+
+ if err := table.Create(ctx, &bq.TableMetadata{Schema: schema}); err != nil && !checkDuplicateError(err) {
+ return fmt.Errorf("failed to create table: %s: %v", tableID, err)
+ }
+ return nil
+}
+
+// AddMetric adds a metric to an existing Benchmark.
+func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) {
+ m := &Metric{
+ Name: metricName,
+ Unit: unit,
+ Sample: sample,
+ }
+ bm.Metric = append(bm.Metric, m)
+}
+
+// NewBenchmark initializes a new benchmark.
+func NewBenchmark(name string, official bool) *Benchmark {
+ return &Benchmark{
+ Name: name,
+ Timestamp: time.Now().UTC(),
+ Official: official,
+ Metric: make([]*Metric, 0),
+ }
+}
+
+// SendBenchmarks sends the slice of benchmarks to the BigQuery dataset/table.
+func SendBenchmarks(ctx context.Context, benchmarks []*Benchmark, projectID, datasetID, tableID string) error {
+ client, err := bq.NewClient(ctx, projectID)
+ if err != nil {
+ return fmt.Errorf("Failed to initialize client on project: %s: %v", projectID, err)
+ }
+ defer client.Close()
+
+ uploader := client.Dataset(datasetID).Table(tableID).Uploader()
+ if err = uploader.Put(ctx, benchmarks); err != nil {
+ return fmt.Errorf("failed to upload benchmarks to proejct %s, table %s.%s: %v", projectID, datasetID, tableID, err)
+ }
+
+ return nil
+}
+
+// BigQuery will error "409" for duplicate tables and datasets.
+func checkDuplicateError(err error) bool {
+ return strings.Contains(err.Error(), "googleapi: Error 409: Already Exists")
+}
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 15a310403..91d689a82 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -7,36 +7,39 @@ change for Google-internal and bazel-compatible rules.
load("//tools/go_stateify:defs.bzl", "go_stateify")
load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
-load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")
+load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")
load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms")
load("//tools/bazeldefs:tags.bzl", "go_suffixes")
# Delegate directly.
cc_binary = _cc_binary
+cc_flags_supplier = _cc_flags_supplier
+cc_grpc_library = _cc_grpc_library
cc_library = _cc_library
cc_test = _cc_test
cc_toolchain = _cc_toolchain
-cc_flags_supplier = _cc_flags_supplier
container_image = _container_image
+default_installer = _default_installer
+default_net_util = _default_net_util
+gbenchmark = _gbenchmark
go_embed_data = _go_embed_data
go_image = _go_image
go_test = _go_test
go_tool_library = _go_tool_library
gtest = _gtest
-gbenchmark = _gbenchmark
+grpcpp = _grpcpp
+loopback = _loopback
pkg_deb = _pkg_deb
pkg_tar = _pkg_tar
-py_library = _py_library
py_binary = _py_binary
-py_test = _py_test
+py_library = _py_library
py_requirement = _py_requirement
+py_test = _py_test
select_arch = _select_arch
select_system = _select_system
-loopback = _loopback
-default_installer = _default_installer
-default_net_util = _default_net_util
-platforms = _platforms
+
default_platform = _default_platform
+platforms = _platforms
def go_binary(name, **kwargs):
"""Wraps the standard go_binary.
@@ -190,33 +193,52 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
**kwargs
)
-def proto_library(name, srcs, **kwargs):
+def proto_library(name, srcs, deps = None, has_services = 0, **kwargs):
"""Wraps the standard proto_library.
- Given a proto_library named "foo", this produces three different targets:
+ Given a proto_library named "foo", this produces up to five different
+ targets:
- foo_proto: proto_library rule.
- foo_go_proto: go_proto_library rule.
- foo_cc_proto: cc_proto_library rule.
+ - foo_go_grpc_proto: go_grpc_library rule.
+ - foo_cc_grpc_proto: cc_grpc_library rule.
Args:
+ name: the name to which _proto, _go_proto, etc, will be appended.
srcs: the proto sources.
+ deps: for the proto library and the go_proto_library.
+ has_services: 1 to build gRPC code, otherwise 0.
**kwargs: standard proto_library arguments.
"""
- deps = kwargs.pop("deps", [])
_proto_library(
name = name + "_proto",
srcs = srcs,
deps = deps,
+ has_services = has_services,
**kwargs
)
- _go_proto_library(
- name = name + "_go_proto",
- proto = ":" + name + "_proto",
- deps = deps,
- **kwargs
- )
+ if has_services:
+ _go_grpc_and_proto_libraries(
+ name = name,
+ deps = deps,
+ **kwargs
+ )
+ else:
+ _go_proto_library(
+ name = name,
+ deps = deps,
+ **kwargs
+ )
_cc_proto_library(
name = name + "_cc_proto",
deps = [":" + name + "_proto"],
**kwargs
)
+ if has_services:
+ _cc_grpc_library(
+ name = name + "_cc_grpc_proto",
+ srcs = [":" + name + "_proto"],
+ deps = [":" + name + "_cc_proto"],
+ **kwargs
+ )
diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl
index c5be52ecd..8c9995fd4 100644
--- a/tools/go_generics/defs.bzl
+++ b/tools/go_generics/defs.bzl
@@ -105,7 +105,6 @@ def _go_template_instance_impl(ctx):
executable = ctx.executable._tool,
)
- # TODO: How can we get the dependencies out?
return struct(
files = depset([output]),
)
diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go
index 9a9a4f298..cd55cf5cb 100644
--- a/tools/go_marshal/analysis/analysis_unsafe.go
+++ b/tools/go_marshal/analysis/analysis_unsafe.go
@@ -161,6 +161,10 @@ func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) {
if typ.NumField() > 0 && nextXOff != int(typ.Size()) {
implicitPad := int(typ.Size()) - nextXOff
f := typ.Field(typ.NumField() - 1) // Final field
+ if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" {
+ // Final field explicitly marked unaligned.
+ break
+ }
t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.",
typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name)
}
diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl
index d79786a68..323e33882 100644
--- a/tools/go_marshal/defs.bzl
+++ b/tools/go_marshal/defs.bzl
@@ -53,9 +53,10 @@ go_marshal = rule(
# marshal_deps are the dependencies requied by generated code.
marshal_deps = [
- "//tools/go_marshal/marshal",
+ "//pkg/gohacks",
"//pkg/safecopy",
"//pkg/usermem",
+ "//tools/go_marshal/marshal",
]
# marshal_test_deps are required by test targets.
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
index b5d5a4487..44cb33ae4 100644
--- a/tools/go_marshal/gomarshal/BUILD
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -7,6 +7,9 @@ go_library(
srcs = [
"generator.go",
"generator_interfaces.go",
+ "generator_interfaces_array_newtype.go",
+ "generator_interfaces_primitive_newtype.go",
+ "generator_interfaces_struct.go",
"generator_tests.go",
"util.go",
],
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index d365a1f3c..177013dbb 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -28,12 +28,6 @@ import (
"gvisor.dev/gvisor/tools/tags"
)
-const (
- marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal"
- safecopyImport = "gvisor.dev/gvisor/pkg/safecopy"
- usermemImport = "gvisor.dev/gvisor/pkg/usermem"
-)
-
// List of identifiers we use in generated code that may conflict with a
// similarly-named source identifier. Abort gracefully when we see these to
// avoid potentially confusing compilation failures in generated code.
@@ -44,8 +38,8 @@ const (
// All recievers are single letters, so we don't allow import aliases to be a
// single letter.
var badIdents = []string{
- "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len",
- "ptr", "src", "srcs", "task", "val",
+ "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner",
+ "length", "limit", "ptr", "size", "src", "srcs", "task", "val",
// All single-letter identifiers.
}
@@ -110,9 +104,10 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G
g.imports.add("reflect")
g.imports.add("runtime")
g.imports.add("unsafe")
- g.imports.add(marshalImport)
- g.imports.add(safecopyImport)
- g.imports.add(usermemImport)
+ g.imports.add("gvisor.dev/gvisor/pkg/gohacks")
+ g.imports.add("gvisor.dev/gvisor/pkg/safecopy")
+ g.imports.add("gvisor.dev/gvisor/pkg/usermem")
+ g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal")
return &g, nil
}
@@ -194,10 +189,73 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
return files, fsets, nil
}
+// sliceAPI carries information about the '+marshal slice' directive.
+type sliceAPI struct {
+ // Comment node in the AST containing the +marshal tag.
+ comment *ast.Comment
+ // Identifier fragment to use when naming generated functions for the slice
+ // API.
+ ident string
+ // Whether the generated functions should reference the newtype name, or the
+ // inner type name. Only meaningful on newtype declarations on primitives.
+ inner bool
+}
+
+// marshallableType carries information about a type marked with the '+marshal'
+// directive.
+type marshallableType struct {
+ spec *ast.TypeSpec
+ slice *sliceAPI
+}
+
+func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) marshallableType {
+ mt := marshallableType{
+ spec: spec,
+ slice: nil,
+ }
+
+ var unhandledTags []string
+
+ for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) {
+ if strings.HasPrefix(tag, "slice:") {
+ tokens := strings.Split(tag, ":")
+ if len(tokens) < 2 || len(tokens) > 3 {
+ abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag))
+ }
+ if len(tokens[1]) == 0 {
+ abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'")
+ }
+
+ sa := &sliceAPI{
+ comment: tagLine,
+ ident: tokens[1],
+ }
+ mt.slice = sa
+
+ if len(tokens) == 3 {
+ if tokens[2] != "inner" {
+ abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'")
+ }
+ sa.inner = true
+ }
+
+ continue
+ }
+
+ unhandledTags = append(unhandledTags, tag)
+ }
+
+ if len(unhandledTags) > 0 {
+ abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " ")))
+ }
+
+ return mt
+}
+
// collectMarshallableTypes walks the parsed AST and collects a list of type
// declarations for which we need to generate the Marshallable interface.
-func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
- var types []*ast.TypeSpec
+func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType {
+ var types []marshallableType
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
// Type declaration?
@@ -212,9 +270,11 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a
}
// Does the comment contain a "+marshal" line?
marked := false
+ var tagLine *ast.Comment
for _, c := range gdecl.Doc.List {
- if c.Text == "// +marshal" {
+ if strings.HasPrefix(c.Text, "// +marshal") {
marked = true
+ tagLine = c
break
}
}
@@ -229,16 +289,17 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a
switch t.Type.(type) {
case *ast.StructType:
debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name)
- types = append(types, t)
- continue
case *ast.Ident: // Newtype on primitive.
debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
- types = append(types, t)
- continue
+ case *ast.ArrayType: // Newtype on array.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name)
+ default:
+ // A user specifically requested marshalling on this type, but we
+ // don't support it.
+ abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
}
- // A user specifically requested marshalling on this type, but we
- // don't support it.
- abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
+ types = append(types, newMarshallableType(f, tagLine, t))
+
}
}
return types
@@ -265,7 +326,7 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
// Make sure we have an import that doesn't use any local names that
// would conflict with identifiers in the generated code.
- if len(i.name) == 1 {
+ if len(i.name) == 1 && i.name != "_" {
abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
}
if _, ok := badIdentsMap[i.name]; ok {
@@ -277,28 +338,40 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
}
-func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- i := newInterfaceGenerator(t, fset)
- switch ty := t.Type.(type) {
+func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator {
+ i := newInterfaceGenerator(t.spec, fset)
+ switch ty := t.spec.Type.(type) {
case *ast.StructType:
- i.validateStruct()
- i.emitMarshallableForStruct()
- return i
+ i.validateStruct(t.spec, ty)
+ i.emitMarshallableForStruct(ty)
+ if t.slice != nil {
+ i.emitMarshallableSliceForStruct(ty, t.slice)
+ }
case *ast.Ident:
i.validatePrimitiveNewtype(ty)
- i.emitMarshallableForPrimitiveNewtype()
- return i
+ i.emitMarshallableForPrimitiveNewtype(ty)
+ if t.slice != nil {
+ i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice)
+ }
+ case *ast.ArrayType:
+ i.validateArrayNewtype(t.spec.Name, ty)
+ // After validate, we can safely call arrayLen.
+ i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
+ if t.slice != nil {
+ abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?"))
+ }
default:
// This should've been filtered out by collectMarshallabeTypes.
panic(fmt.Sprintf("Unexpected type %+v", ty))
}
+ return i
}
// generateOneTestSuite generates a test suite for the automatically generated
// implementations type t.
-func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator {
- i := newTestGenerator(t)
- i.emitTests()
+func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator {
+ i := newTestGenerator(t.spec)
+ i.emitTests(t.slice)
return i
}
@@ -348,7 +421,7 @@ func (g *Generator) Run() error {
// the list of imports we need to copy to the generated code.
for name, _ := range impl.is {
if !g.imports.markUsed(name) {
- panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name))
+ panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name))
}
}
ts = append(ts, g.generateOneTestSuite(t))
@@ -406,7 +479,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error {
// empty example instead.
if len(ts) == 0 {
b.reset()
- b.emit("func ExampleEmptyTestSuite() {\n")
+ b.emit("func Example() {\n")
b.inIndent(func() {
b.emit("// This example is intentionally empty to ensure this file contains at least\n")
b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n")
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index ea1af998e..e3c3dac63 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -74,25 +74,12 @@ func (g *interfaceGenerator) recordUsedMarshallable(m string) {
func (g *interfaceGenerator) recordUsedImport(i string) {
g.is[i] = struct{}{}
-
}
func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
g.as[fieldName] = struct{}{}
}
-func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) {
- // This is guaranteed to succeed because g.t is always a struct.
- st := g.t.Type.(*ast.StructType)
- for _, field := range st.Fields.List {
- fn(field)
- }
-}
-
-func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
- return fmt.Sprintf("%s.%s", g.r, n.Name)
-}
-
// abortAt aborts the go_marshal tool with the given error message, with a
// reference position to the input source. Same as abortAt, but uses g to
// resolve p to position.
@@ -100,71 +87,6 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
abortAt(g.f.Position(p), msg)
}
-func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
- switch t.Name {
- case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
- // These are the only primitive types we're allow. Below, we provide
- // suggestions for some disallowed types and reject them, then attempt
- // to marshal any remaining types by invoking the marshal.Marshallable
- // interface on them. If these types don't actually implement
- // marshal.Marshallable, compilation of the generated code will fail
- // with an appropriate error message.
- return
- case "int":
- g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
- case "uint":
- g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
- case "string":
- g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
- default:
- debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
- }
-}
-
-// validateStruct ensures the type we're working with can be marshalled. These
-// checks are done ahead of time and in one place so we can make assumptions
-// later.
-func (g *interfaceGenerator) validateStruct() {
- g.forEachField(func(f *ast.Field) {
- if len(f.Names) == 0 {
- g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
- }
- })
-
- g.forEachField(func(f *ast.Field) {
- fieldDispatcher{
- primitive: func(_, t *ast.Ident) {
- g.validatePrimitiveNewtype(t)
- },
- selector: func(_, _, _ *ast.Ident) {
- // No validation to perform on selector fields. However this
- // callback must still be provided.
- },
- array: func(n, _ *ast.Ident, len int) {
- a := f.Type.(*ast.ArrayType)
- if a.Len == nil {
- g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
- }
-
- if _, ok := a.Len.(*ast.BasicLit); !ok {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions"))
- }
-
- if _, ok := a.Elt.(*ast.Ident); !ok {
- g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
- }
-
- if len <= 0 {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
- }
- },
- unhandled: func(_ *ast.Ident) {
- g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
- },
- }.dispatch(f)
- })
-}
-
// scalarSize returns the size of type identified by t. If t isn't a primitive
// type, the size isn't known at code generation time, and must be resolved via
// the marshal.Marshallable interface.
@@ -191,8 +113,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
}
-// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice.
-func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) {
+// marshalScalar writes a single scalar to a byte slice.
+func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) {
switch typ {
case "int8", "uint8", "byte":
g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
@@ -215,9 +137,8 @@ func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar stri
}
}
-// unmarshalStructFieldScalar reads a single scalar field from a struct, from a
-// byte slice.
-func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) {
+// unmarshalScalar reads a single scalar from a byte slice.
+func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) {
switch typ {
case "byte":
g.emit("%s = %s[0]\n", accessor, bufVar)
@@ -244,579 +165,112 @@ func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar st
}
}
-// marshalPrimitiveScalar writes a single primitive variable to a byte slice.
-func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
- switch typ {
- case "int8", "uint8", "byte":
- g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
- case "int16", "uint16":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
- case "int32", "uint32":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
- case "int64", "uint64":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
- default:
- g.emit("inner := (*%s)(%s)\n", typ, accessor)
- g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
- }
+// emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a
+// byte slice, bypassing escape analysis. The caller is responsible for ensuring
+// srcPtr lives until they're done with dstVar, the runtime does not consider
+// dstVar dependent on srcPtr due to the escape analysis bypass.
+//
+// srcPtr must be a pointer.
+//
+// This function uses internally uses the identifier "hdr", and cannot be used
+// in a context where it is already bound.
+func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) {
+ g.recordUsedImport("gohacks")
+ g.emit("// Construct a slice backed by dst's underlying memory.\n")
+ g.emit("var %s []byte\n", dstVar)
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
+ g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr)
+ g.emit("hdr.Len = %s\n", lenExpr)
+ g.emit("hdr.Cap = %s\n\n", lenExpr)
}
-// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
-func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
- switch typ {
- case "byte":
- g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
- case "int8", "uint8":
- g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
- case "int16", "uint16":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
- case "int32", "uint32":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
-
- case "int64", "uint64":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
- default:
- g.emit("inner := (*%s)(%s)\n", typ, accessor)
- g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
- }
+// emitCastToByteSlice unsafely casts a slice with elements of an abitrary type
+// to a byte slice. As part of the cast, the byte slice is made to look
+// independent of the src slice by bypassing escape analysis. This means the
+// byte slice can be used without causing the source to escape. The caller is
+// responsible for ensuring srcPtr lives until they're done with dstVar, as the
+// runtime no longer considers dstVar dependent on srcPtr and is free to GC it.
+//
+// srcPtr must be a pointer.
+//
+// This function uses internally uses the identifiers "ptr", "val" and "hdr",
+// and cannot be used in a context where these identifiers are already bound.
+func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) {
+ g.emitNoEscapeSliceDataPointer(srcPtr, "val")
+
+ g.emit("// Construct a slice backed by dst's underlying memory.\n")
+ g.emit("var %s []byte\n", dstVar)
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
+ g.emit("hdr.Data = uintptr(val)\n")
+ g.emit("hdr.Len = %s\n", lenExpr)
+ g.emit("hdr.Cap = %s\n\n", lenExpr)
}
-// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
-// packed. Returns "", false if g.t has no fields that may be potentially
-// packed, otherwise returns <clause>, true, where <clause> is an expression
-// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
-func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
- if len(g.as) == 0 {
- return "", false
- }
-
- cs := make([]string, 0, len(g.as))
- for accessor, _ := range g.as {
- cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
- }
- return strings.Join(cs, " && "), true
+// emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an
+// unsafe.Pointer, bypassing escape analysis. The caller is responsible for
+// ensuring srcPtr lives until they're done with dstVar, as the runtime no
+// longer considers dstVar dependent on srcPtr and is free to GC it.
+//
+// srcPtr must be a pointer.
+//
+// This function uses internally uses the identifier "ptr" cannot be used in a
+// context where this identifier is already bound.
+func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) {
+ g.recordUsedImport("gohacks")
+ g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr)
+ g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar)
}
-func (g *interfaceGenerator) emitMarshallableForStruct() {
- // Is g.t a packed struct without consideing field types?
- thisPacked := true
- g.forEachField(func(f *ast.Field) {
- if f.Tag != nil {
- if f.Tag.Value == "`marshal:\"unaligned\"`" {
- if thisPacked {
- debugfAt(g.f.Position(g.t.Pos()),
- fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
- thisPacked = false
- }
- }
- }
- })
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- primitiveSize := 0
- var dynamicSizeTerms []string
-
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
- }
- },
- selector: func(n, tX, tSel *ast.Ident) {
- tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
- g.recordUsedImport(tX.Name)
- g.recordUsedMarshallable(tName)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
- },
- array: func(n, t *ast.Ident, len int) {
- if len < 1 {
- // Zero-length arrays should've been rejected by validate().
- panic("unreachable")
- }
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size * len
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
- }
- },
- }.dispatch)
- g.emit("return %d", primitiveSize)
- if len(dynamicSizeTerms) > 0 {
- g.incIndent()
- }
- {
- for _, d := range dynamicSizeTerms {
- g.emitNoIndent(" +\n")
- g.emit(d)
- }
- }
- if len(dynamicSizeTerms) > 0 {
- g.decIndent()
- }
- })
- g.emit("\n}\n\n")
-
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
- }
- return
- }
- g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
- },
- array: func(n, t *ast.Ident, size int) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len*size)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %d; idx++ {\n", size)
- g.inIndent(func() {
- g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
- g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
- }
- return
- }
- g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
- },
- array: func(n, t *ast.Ident, size int) {
- if n.Name == "_" {
- g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len*size)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %d; idx++ {\n", size)
- g.inIndent(func() {
- g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- expr, fieldsMaybePacked := g.areFieldsPackedExpression()
- switch {
- case !thisPacked:
- g.emit("return false\n")
- case fieldsMaybePacked:
- g.emit("return %s\n", expr)
- default:
- g.emit("return true\n")
-
- }
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- if thisPacked {
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- })
- g.emit("} else {\n")
- g.inIndent(func() {
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- })
- g.emit("}\n")
- } else {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- }
- } else {
- g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- }
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- if thisPacked {
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- })
- g.emit("} else {\n")
- g.inIndent(func() {
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- })
- g.emit("}\n")
- } else {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- }
- } else {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- }
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
- g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("return err\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyOutBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("if err != nil {\n")
- g.inIndent(func() {
- g.emit("return err\n")
- })
- g.emit("}\n")
-
- g.emit("%s.UnmarshalBytes(buf)\n", g.r)
- g.emit("return nil\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast deserialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyInBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
- g.recordUsedImport("io")
- g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
- g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("n, err := w.Write(buf)\n")
- g.emit("return int64(n), err\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("len, err := w.Write(buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the Write.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return int64(len), err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
+func (g *interfaceGenerator) emitKeepAlive(ptrVar string) {
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar)
+ g.emit("// must live until the use above.\n")
+ g.emit("runtime.KeepAlive(%s)\n", ptrVar)
}
-// emitMarshallableForPrimitiveNewtype outputs code to implement the
-// marshal.Marshallable interface for a newtype on a primitive. Primitive
-// newtypes are always packed, so we can omit the various fallbacks required for
-// non-packed structs.
-func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() {
- g.recordUsedImport("io")
- g.recordUsedImport("marshal")
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- g.recordUsedImport("usermem")
-
- nt := g.t.Type.(*ast.Ident)
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- if size, dynamic := g.scalarSize(nt); !dynamic {
- g.emit("return %d\n", size)
- } else {
- g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
- }
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
- })
- g.emit("}\n\n")
-
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Scalar newtypes are always packed.\n")
- g.emit("return true\n")
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- // Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyOutBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyInBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
- g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) {
+ switch x := e.X.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, x)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", x.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", x.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
- g.emit("len, err := w.Write(buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the Write.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return int64(len), err\n")
+ fmt.Fprintf(b, "%s", e.Op)
- })
- g.emit("}\n\n")
+ switch y := e.Y.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, y)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", y.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", y.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+}
+// arrayLenExpr returns a string containing a valid golang expression
+// representing the length of array a. The returned expression should be treated
+// as a single value, and will be already parenthesized as required.
+func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string {
+ var b strings.Builder
+
+ switch l := a.Len.(type) {
+ case *ast.Ident:
+ fmt.Fprintf(&b, "%s", l.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(&b, "%s", l.Value)
+ case *ast.BinaryExpr:
+ g.expandBinaryExpr(&b, l)
+ return fmt.Sprintf("(%s)", b.String())
+ default:
+ g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+ return b.String()
}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
new file mode 100644
index 000000000..8d6f102d5
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
@@ -0,0 +1,141 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on arrays.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) {
+ if a.Len == nil {
+ g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
+ }
+
+ if _, ok := a.Elt.(*ast.Ident); !ok {
+ g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
+ }
+}
+
+func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ lenExpr := g.arrayLenExpr(a)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(elt); !dynamic {
+ g.emit("return %d * %s\n", size, lenExpr)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Array newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit])\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
new file mode 100644
index 000000000..ef9bb903d
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
@@ -0,0 +1,282 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on primitives.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+// marshalPrimitiveScalar writes a single primitive variable to a byte
+// slice.
+func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
+func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
+ switch typ {
+ case "byte":
+ g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
+ case "int8", "uint8":
+ g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we provide
+ // suggestions for some disallowed types and reject them, then attempt
+ // to marshal any remaining types by invoking the marshal.Marshallable
+ // interface on them. If these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code will fail
+ // with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+}
+
+// emitMarshallableForPrimitiveNewtype outputs code to implement the
+// marshal.Marshallable interface for a newtype on a primitive. Primitive
+// newtypes are always packed, so we can omit the various fallbacks required for
+// non-packed structs.
+func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(nt); !dynamic {
+ g.emit("return %d\n", size)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Scalar newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit])\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+
+ })
+ g.emit("}\n\n")
+}
+
+func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) {
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+
+ eltType := g.typeName()
+ if slice.inner {
+ eltType = nt.Name
+ }
+
+ g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType)
+ g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType)
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType)
+ g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType)
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitCastSliceToByteSlice("&src", "buf", "size * count")
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitNoEscapeSliceDataPointer("&src", "val")
+
+ g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitNoEscapeSliceDataPointer("&dst", "val")
+
+ g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
new file mode 100644
index 000000000..4236e978e
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -0,0 +1,614 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// structs.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "strings"
+)
+
+func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
+ return fmt.Sprintf("%s.%s", g.r, n.Name)
+}
+
+// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
+// packed. Returns "", false if g.t has no fields that may be potentially
+// packed, otherwise returns <clause>, true, where <clause> is an expression
+// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
+func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
+ if len(g.as) == 0 {
+ return "", false
+ }
+
+ cs := make([]string, 0, len(g.as))
+ for accessor, _ := range g.as {
+ cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
+ }
+ return strings.Join(cs, " && "), true
+}
+
+// validateStruct ensures the type we're working with can be marshalled. These
+// checks are done ahead of time and in one place so we can make assumptions
+// later.
+func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) {
+ forEachStructField(st, func(f *ast.Field) {
+ if len(f.Names) == 0 {
+ g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
+ }
+ })
+
+ forEachStructField(st, func(f *ast.Field) {
+ fieldDispatcher{
+ primitive: func(_, t *ast.Ident) {
+ g.validatePrimitiveNewtype(t)
+ },
+ selector: func(_, _, _ *ast.Ident) {
+ // No validation to perform on selector fields. However this
+ // callback must still be provided.
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) {
+ g.validateArrayNewtype(n, a)
+ },
+ unhandled: func(_ *ast.Ident) {
+ g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
+ },
+ }.dispatch(f)
+ })
+}
+
+func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool {
+ packed := true
+ forEachStructField(st, func(f *ast.Field) {
+ if f.Tag != nil {
+ if f.Tag.Value == "`marshal:\"unaligned\"`" {
+ if packed {
+ debugfAt(g.f.Position(g.t.Pos()),
+ fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
+ packed = false
+ }
+ }
+ }
+ })
+ return packed
+}
+
+func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
+ thisPacked := g.isStructPacked(st)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
+ }
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
+ }
+ },
+ }.dispatch)
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ g.emit("\n}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name)
+ g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We don't have an instance of the dynamic type we can
+ // reference here (since the version in this struct is
+ // anonymous). Use a typed nil pointer to call
+ // SizeBytes() instead.
+ g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name))
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name)
+ g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name))
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("src = src[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ expr, fieldsMaybePacked := g.areFieldsPackedExpression()
+ switch {
+ case !thisPacked:
+ g.emit("return false\n")
+ case fieldsMaybePacked:
+ g.emit("return %s\n", expr)
+ default:
+ g.emit("return true\n")
+
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("return task.CopyOutBytes(addr, buf[:limit])\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit])\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n")
+ g.emit("// partially unmarshalled struct.\n")
+ g.emit("%s.UnmarshalBytes(buf)\n", g.r)
+ g.emit("return length, err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.recordUsedImport("io")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("length, err := w.Write(buf)\n")
+ g.emit("return int64(length), err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
+
+func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) {
+ thisPacked := g.isStructPacked(st)
+
+ if slice.inner {
+ abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident))
+ }
+
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+
+ g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName())
+ g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(size * count)\n")
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n\n")
+
+ g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n")
+ g.emit("limit := length/size\n")
+ g.emit("for idx := 0; idx < limit; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Handle any final partial object.\n")
+ g.emit("if length < size*count && length%size != 0 {\n")
+ g.inIndent(func() {
+ g.emit("idx := limit\n")
+ g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("return length, err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !dst[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName())
+ g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(size * count)\n")
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return task.CopyOutBytes(addr, buf)\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !src[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastSliceToByteSlice("&src", "buf", "size * count")
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return size * count, nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !src[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ g.emitNoEscapeSliceDataPointer("&src", "val")
+
+ g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return size * count, nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !dst[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ g.emitNoEscapeSliceDataPointer("&dst", "val")
+
+ g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index 8ba47eb67..631295373 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -30,6 +30,11 @@ var standardImports = []string{
"gvisor.dev/gvisor/tools/go_marshal/analysis",
}
+var sliceAPIImports = []string{
+ "encoding/binary",
+ "gvisor.dev/gvisor/pkg/usermem",
+}
+
type testGenerator struct {
sourceBuffer
@@ -58,6 +63,11 @@ func newTestGenerator(t *ast.TypeSpec) *testGenerator {
for _, i := range standardImports {
g.imports.add(i).markUsed()
}
+ // These imports are used if a type requests the slice API. Don't
+ // mark them as used by default.
+ for _, i := range sliceAPIImports {
+ g.imports.add(i)
+ }
return g
}
@@ -132,6 +142,42 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
})
}
+func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) {
+ for _, name := range []string{"binary", "usermem"} {
+ if !g.imports.markUsed(name) {
+ panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name))
+ }
+ }
+
+ g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() {
+ g.emit("var x, y, yUnsafe [8]%s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+ g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName())
+ g.emit("buf := bytes.NewBuffer(make([]byte, size))\n")
+ g.emit("buf.Reset()\n")
+ g.emit("if err := binary.Write(buf, usermem.ByteOrder, x[:]); err != nil {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n")
+ })
+ g.emit("}\n")
+ g.emit("bufUnsafe := make([]byte, size)\n")
+ g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident)
+
+ g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident)
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident)
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n\n")
+ })
+}
+
func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() {
g.inTestFunction("TestWriteToUnmarshalPreservesData", func() {
g.emit("var x, y, yUnsafe %s\n", g.typeName())
@@ -164,18 +210,22 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n")
g.inIndent(func() {
- g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)")
+ g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n")
})
g.emit("}\n")
})
}
-func (g *testGenerator) emitTests() {
+func (g *testGenerator) emitTests(slice *sliceAPI) {
g.emitTestNonZeroSize()
g.emitTestSuspectAlignment()
g.emitTestMarshalUnmarshalPreservesData()
g.emitTestWriteToUnmarshalPreservesData()
g.emitTestSizeBytesOnTypedNilPtr()
+
+ if slice != nil {
+ g.emitTestMarshalUnmarshalSlicePreservesData(slice)
+ }
}
func (g *testGenerator) write(out io.Writer) error {
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
index e2bca4e7c..d94314302 100644
--- a/tools/go_marshal/gomarshal/util.go
+++ b/tools/go_marshal/gomarshal/util.go
@@ -25,7 +25,6 @@ import (
"path"
"reflect"
"sort"
- "strconv"
"strings"
)
@@ -64,12 +63,18 @@ func kindString(e ast.Expr) string {
}
}
+func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
// fieldDispatcher is a collection of callbacks for handling different types of
// fields in a struct declaration.
type fieldDispatcher struct {
primitive func(n, t *ast.Ident)
selector func(n, tX, tSel *ast.Ident)
- array func(n, t *ast.Ident, size int)
+ array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident)
unhandled func(n *ast.Ident)
}
@@ -96,22 +101,12 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) {
case *ast.SelectorExpr:
fd.selector(name, v.X.(*ast.Ident), v.Sel)
case *ast.ArrayType:
- len := 0
- if v.Len != nil {
- // Non-literal array length is handled by generatorInterfaces.validate().
- if lenLit, ok := v.Len.(*ast.BasicLit); ok {
- var err error
- len, err = strconv.Atoi(lenLit.Value)
- if err != nil {
- panic(err)
- }
- }
- }
switch t := v.Elt.(type) {
case *ast.Ident:
- fd.array(name, t, len)
+ fd.array(name, v, t)
default:
- fd.array(name, nil, len)
+ // Should be handled with a better error message during validate.
+ panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
}
default:
fd.unhandled(name)
@@ -270,6 +265,11 @@ type importStmt struct {
aliased bool
// Indicates whether this import was referenced by generated code.
used bool
+ // AST node and file set representing the import statement, if any. These
+ // are only non-nil if the import statement originates from an input source
+ // file.
+ spec *ast.ImportSpec
+ fset *token.FileSet
}
func newImport(p string) *importStmt {
@@ -295,14 +295,27 @@ func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
name: name,
path: p,
aliased: spec.Name != nil,
+ spec: spec,
+ fset: f,
}
}
+// String implements fmt.Stringer.String. This generates a string for the import
+// statement appropriate for writing directly to generated code.
func (i *importStmt) String() string {
if i.aliased {
- return fmt.Sprintf("%s \"%s\"", i.name, i.path)
+ return fmt.Sprintf("%s %q", i.name, i.path)
}
- return fmt.Sprintf("\"%s\"", i.path)
+ return fmt.Sprintf("%q", i.path)
+}
+
+// debugString returns a debug string representing an import statement. This
+// representation is not valid golang code and is used for debugging output.
+func (i *importStmt) debugString() string {
+ if i.spec != nil && i.fset != nil {
+ return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i)
+ }
+ return fmt.Sprintf("(go-marshal import): %s", i)
}
func (i *importStmt) markUsed() {
@@ -314,40 +327,78 @@ func (i *importStmt) equivalent(other *importStmt) bool {
}
// importTable represents a collection of importStmts.
+//
+// An importTable may contain multiple import statements referencing the same
+// local name. All import statements aliasing to the same local name are
+// technically ambiguous, as if such an import name is used in the generated
+// code, it's not clear which import statement it refers to. We ignore any
+// potential collisions until actually writing the import table to the generated
+// source file. See importTable.write.
+//
+// Given the following import statements across all the files comprising a
+// package marshalled:
+//
+// "sync"
+// "pkg/sync"
+// "pkg/sentry/kernel"
+// ktime "pkg/sentry/kernel/time"
+//
+// An importTable representing them would look like this:
+//
+// importTable {
+// is: map[string][]*importStmt {
+// "sync": []*importStmt{
+// importStmt{name:"sync", path:"sync", aliased:false}
+// importStmt{name:"sync", path:"pkg/sync", aliased:false}
+// },
+// "kernel": []*importStmt{importStmt{
+// name: "kernel",
+// path: "pkg/sentry/kernel",
+// aliased: false
+// }},
+// "ktime": []*importStmt{importStmt{
+// name: "ktime",
+// path: "pkg/sentry/kernel/time",
+// aliased: true,
+// }},
+// }
+// }
+//
+// Note that the local name "sync" is assigned to two different import
+// statements. This is possible if the import statements are from different
+// source files in the same package.
+//
+// Since go-marshal generates a single output file per package regardless of the
+// number of input files, if "sync" is referenced by any generated code, it's
+// unclear which import statement "sync" refers to. While it's theoretically
+// possible to resolve this by assigning a unique local alias to each instance
+// of the sync package, go-marshal currently aborts when it encounters such an
+// ambiguity.
+//
+// TODO(b/151478251): importTable considers the final component of an import
+// path to be the package name, but this is only a convention. The actual
+// package name is determined by the package statement in the source files for
+// the package.
type importTable struct {
// Map of imports and whether they should be copied to the output.
- is map[string]*importStmt
+ is map[string][]*importStmt
}
func newImportTable() *importTable {
return &importTable{
- is: make(map[string]*importStmt),
+ is: make(map[string][]*importStmt),
}
}
-// Merges import statements from other into i. Collisions in import statements
-// result in a panic.
+// Merges import statements from other into i.
func (i *importTable) merge(other *importTable) {
- for name, im := range other.is {
- if dup, ok := i.is[name]; ok && !dup.equivalent(im) {
- panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im))
- }
-
- i.is[name] = im
+ for name, ims := range other.is {
+ i.is[name] = append(i.is[name], ims...)
}
}
func (i *importTable) addStmt(s *importStmt) *importStmt {
- if old, ok := i.is[s.name]; ok && !old.equivalent(s) {
- // A collision should always be between an import inserted by the
- // go-marshal tool and an import from the original source file (assuming
- // the original source file was valid). We could theoretically handle
- // the collision by assigning a local name to our import. However, this
- // would need to be plumbed throughout the generator. Given that
- // collisions should be rare, simply panic on collision.
- panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name))
- }
- i.is[s.name] = s
+ i.is[s.name] = append(i.is[s.name], s)
return s
}
@@ -363,16 +414,20 @@ func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *impor
// Marks the import named n as used. If no such import is in the table, returns
// false.
func (i *importTable) markUsed(n string) bool {
- if n, ok := i.is[n]; ok {
- n.markUsed()
+ if ns, ok := i.is[n]; ok {
+ for _, n := range ns {
+ n.markUsed()
+ }
return true
}
return false
}
func (i *importTable) clear() {
- for _, i := range i.is {
- i.used = false
+ for _, is := range i.is {
+ for _, i := range is {
+ i.used = false
+ }
}
}
@@ -383,9 +438,42 @@ func (i *importTable) write(out io.Writer) error {
}
imports := make([]string, 0, len(i.is))
- for _, i := range i.is {
- if i.used {
- imports = append(imports, i.String())
+ for name, is := range i.is {
+ var lastUsed *importStmt
+ var ambiguous bool
+
+ for _, i := range is {
+ if i.used {
+ if lastUsed != nil {
+ if !i.equivalent(lastUsed) {
+ ambiguous = true
+ }
+ }
+ lastUsed = i
+ }
+ }
+
+ if ambiguous {
+ // We have two or more import statements across the different source
+ // files that share a local name, and at least one of these imports
+ // are used by the generated code. This ambiguity can't be resolved
+ // by go-marshal and requires the user intervention. Dump a list of
+ // the colliding import statements and let the user modify the input
+ // files as appropriate.
+ var b strings.Builder
+ fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name)
+ fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name)
+ // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't
+ // be true. Therefore the slicing below is safe.
+ for _, i := range is[:len(is)-1] {
+ fmt.Fprintf(&b, " %v\n", i.debugString())
+ }
+ fmt.Fprintf(&b, " %v", is[len(is)-1].debugString())
+ panic(b.String())
+ }
+
+ if lastUsed != nil {
+ imports = append(imports, lastUsed.String())
}
}
sort.Strings(imports)
diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go
index f129788e0..cb2166252 100644
--- a/tools/go_marshal/marshal/marshal.go
+++ b/tools/go_marshal/marshal/marshal.go
@@ -42,7 +42,11 @@ type Task interface {
CopyInBytes(addr usermem.Addr, b []byte) (int, error)
}
-// Marshallable represents a type that can be marshalled to and from memory.
+// Marshallable represents operations on a type that can be marshalled to and
+// from memory.
+//
+// go-marshal automatically generates implementations for this interface for
+// types marked as '+marshal'.
type Marshallable interface {
io.WriterTo
@@ -54,12 +58,18 @@ type Marshallable interface {
// likely make use of the type of these fields).
SizeBytes() int
- // MarshalBytes serializes a copy of a type to dst. dst must be at least
- // SizeBytes() long.
+ // MarshalBytes serializes a copy of a type to dst. dst may be smaller than
+ // SizeBytes(), which results in a part of the struct being marshalled. Note
+ // that this may have unexpected results for non-packed types, as implicit
+ // padding needs to be taken into account when reasoning about how much of
+ // the type is serialized.
MarshalBytes(dst []byte)
- // UnmarshalBytes deserializes a type from src. src must be at least
- // SizeBytes() long.
+ // UnmarshalBytes deserializes a type from src. src may be smaller than
+ // SizeBytes(), which results in a partially deserialized struct. Note that
+ // this may have unexpected results for non-packed types, as implicit
+ // padding needs to be taken into account when reasoning about how much of
+ // the type is deserialized.
UnmarshalBytes(src []byte)
// Packed returns true if the marshalled size of the type is the same as the
@@ -67,13 +77,20 @@ type Marshallable interface {
// starting at unaligned addresses (should always be true by default for ABI
// structs, verified by automatically generated tests when using
// go_marshal), and has no fields marked `marshal:"unaligned"`.
+ //
+ // Packed must return the same result for all possible values of the type
+ // implementing it. Violating this constraint implies the type doesn't have
+ // a static memory layout, and will lead to memory corruption.
+ // Go-marshal-generated code reuses the result of Packed for multiple values
+ // of the same type.
Packed() bool
// MarshalUnsafe serializes a type by bulk copying its in-memory
// representation to the dst buffer. This is only safe to do when the type
// has no implicit padding, see Marshallable.Packed. When Packed would
// return false, MarshalUnsafe should fall back to the safer but slower
- // MarshalBytes.
+ // MarshalBytes. dst may be smaller than SizeBytes(), see comment for
+ // MarshalBytes for implications.
MarshalUnsafe(dst []byte)
// UnmarshalUnsafe deserializes a type by directly copying to the underlying
@@ -82,7 +99,8 @@ type Marshallable interface {
// This allows much faster unmarshalling of types which have no implicit
// padding, see Marshallable.Packed. When Packed would return false,
// UnmarshalUnsafe should fall back to the safer but slower unmarshal
- // mechanism implemented in UnmarshalBytes.
+ // mechanism implemented in UnmarshalBytes. src may be smaller than
+ // SizeBytes(), see comment for UnmarshalBytes for implications.
UnmarshalUnsafe(src []byte)
// CopyIn deserializes a Marshallable type from a task's memory. This may
@@ -91,12 +109,79 @@ type Marshallable interface {
// marshalled does not escape. The implementation should avoid creating
// extra copies in memory by directly deserializing to the object's
// underlying memory.
- CopyIn(task Task, addr usermem.Addr) error
+ //
+ // If the copy-in from the task memory is only partially successful, CopyIn
+ // should still attempt to deserialize as much data as possible. See comment
+ // for UnmarshalBytes.
+ CopyIn(task Task, addr usermem.Addr) (int, error)
// CopyOut serializes a Marshallable type to a task's memory. This may only
// be called from a task goroutine. This is more efficient than calling
// MarshalUnsafe on Marshallable.Packed types, as the type being serialized
// does not escape. The implementation should avoid creating extra copies in
// memory by directly serializing from the object's underlying memory.
- CopyOut(task Task, addr usermem.Addr) error
+ //
+ // The copy-out to the task memory may be partially successful, in which
+ // case CopyOut returns how much data was serialized. See comment for
+ // MarshalBytes for implications.
+ CopyOut(task Task, addr usermem.Addr) (int, error)
+
+ // CopyOutN is like CopyOut, but explicitly requests a partial
+ // copy-out. Note that this may yield unexpected results for non-packed
+ // types and the caller may only want to allow this for packed types. See
+ // comment on MarshalBytes.
+ //
+ // The limit must be less than or equal to SizeBytes().
+ CopyOutN(task Task, addr usermem.Addr, limit int) (int, error)
}
+
+// go-marshal generates additional functions for a type based on additional
+// clauses to the +marshal directive. They are documented below.
+//
+// Slice API
+// =========
+//
+// Adding a "slice" clause to the +marshal directive for structs or newtypes on
+// primitives like this:
+//
+// // +marshal slice:FooSlice
+// type Foo struct { ... }
+//
+// Generates four additional functions for marshalling slices of Foos like this:
+//
+// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, buf for a []Foo. It's
+// // more efficient that repeatedly calling calling Foo.MarshalUnsafe over a
+// // []Foo in a loop.
+// func MarshalUnsafeFooSlice(src []Foo, dst []byte) (int, error) { ... }
+//
+// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, buf for a []Foo. It's
+// // more efficient that repeatedly calling calling Foo.UnmarshalUnsafe over a
+// // []Foo in a loop.
+// func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... }
+//
+// // CopyFooSliceIn copies in a slice of Foo objects from the task's memory.
+// func CopyFooSliceIn(task marshal.Task, addr usermem.Addr, dst []Foo) (int, error) { ... }
+//
+// // CopyFooSliceIn copies out a slice of Foo objects to the task's memory.
+// func CopyFooSliceOut(task marshal.Task, addr usermem.Addr, src []Foo) (int, error) { ... }
+//
+// The name of the functions are of the format "Copy%sIn" and "Copy%sOut", where
+// %s is the first argument to the slice clause. This directive is not supported
+// for newtypes on arrays.
+//
+// The slice clause also takes an optional second argument, which must be the
+// value "inner":
+//
+// // +marshal slice:Int32Slice:inner
+// type Int32 int32
+//
+// This is only valid on newtypes on primitives, and causes the generated
+// functions to accept slices of the inner type instead:
+//
+// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []int32) (int, error) { ... }
+//
+// Without "inner", they would instead be:
+//
+// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []Int32) (int, error) { ... }
+//
+// This may help avoid a cast depending on how the generated functions are used.
diff --git a/tools/go_marshal/primitive/BUILD b/tools/go_marshal/primitive/BUILD
new file mode 100644
index 000000000..cc08ba63a
--- /dev/null
+++ b/tools/go_marshal/primitive/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "primitive",
+ srcs = [
+ "primitive.go",
+ ],
+ marshal = True,
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ ],
+)
diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go
new file mode 100644
index 000000000..ebcf130ae
--- /dev/null
+++ b/tools/go_marshal/primitive/primitive.go
@@ -0,0 +1,175 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package primitive defines marshal.Marshallable implementations for primitive
+// types.
+package primitive
+
+import (
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// Int16 is a marshal.Marshallable implementation for int16.
+//
+// +marshal slice:Int16Slice:inner
+type Int16 int16
+
+// Uint16 is a marshal.Marshallable implementation for uint16.
+//
+// +marshal slice:Uint16Slice:inner
+type Uint16 uint16
+
+// Int32 is a marshal.Marshallable implementation for int32.
+//
+// +marshal slice:Int32Slice:inner
+type Int32 int32
+
+// Uint32 is a marshal.Marshallable implementation for uint32.
+//
+// +marshal slice:Uint32Slice:inner
+type Uint32 uint32
+
+// Int64 is a marshal.Marshallable implementation for int64.
+//
+// +marshal slice:Int64Slice:inner
+type Int64 int64
+
+// Uint64 is a marshal.Marshallable implementation for uint64.
+//
+// +marshal slice:Uint64Slice:inner
+type Uint64 uint64
+
+// Below, we define some convenience functions for marshalling primitive types
+// using the newtypes above, without requiring superfluous casts.
+
+// 16-bit integers
+
+// CopyInt16In is a convenient wrapper for copying in an int16 from the task's
+// memory.
+func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error) {
+ var buf Int16
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int16(buf)
+ return n, nil
+}
+
+// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's
+// memory.
+func CopyInt16Out(task marshal.Task, addr usermem.Addr, src int16) (int, error) {
+ srcP := Int16(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's
+// memory.
+func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error) {
+ var buf Uint16
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint16(buf)
+ return n, nil
+}
+
+// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's
+// memory.
+func CopyUint16Out(task marshal.Task, addr usermem.Addr, src uint16) (int, error) {
+ srcP := Uint16(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// 32-bit integers
+
+// CopyInt32In is a convenient wrapper for copying in an int32 from the task's
+// memory.
+func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error) {
+ var buf Int32
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int32(buf)
+ return n, nil
+}
+
+// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's
+// memory.
+func CopyInt32Out(task marshal.Task, addr usermem.Addr, src int32) (int, error) {
+ srcP := Int32(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's
+// memory.
+func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error) {
+ var buf Uint32
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint32(buf)
+ return n, nil
+}
+
+// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's
+// memory.
+func CopyUint32Out(task marshal.Task, addr usermem.Addr, src uint32) (int, error) {
+ srcP := Uint32(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// 64-bit integers
+
+// CopyInt64In is a convenient wrapper for copying in an int64 from the task's
+// memory.
+func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error) {
+ var buf Int64
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int64(buf)
+ return n, nil
+}
+
+// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's
+// memory.
+func CopyInt64Out(task marshal.Task, addr usermem.Addr, src int64) (int, error) {
+ srcP := Int64(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's
+// memory.
+func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error) {
+ var buf Uint64
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint64(buf)
+ return n, nil
+}
+
+// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's
+// memory.
+func CopyUint64Out(task marshal.Task, addr usermem.Addr, src uint64) (int, error) {
+ srcP := Uint64(src)
+ return srcP.CopyOut(task, addr)
+}
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
index f27c5ce52..3b839799d 100644
--- a/tools/go_marshal/test/BUILD
+++ b/tools/go_marshal/test/BUILD
@@ -39,3 +39,17 @@ go_binary(
"//tools/go_marshal/marshal",
],
)
+
+go_test(
+ name = "marshal_test",
+ size = "small",
+ srcs = ["marshal_test.go"],
+ deps = [
+ ":test",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//tools/go_marshal/analysis",
+ "//tools/go_marshal/marshal",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go
index c79defe9e..224d308c7 100644
--- a/tools/go_marshal/test/benchmark_test.go
+++ b/tools/go_marshal/test/benchmark_test.go
@@ -176,3 +176,45 @@ func BenchmarkGoMarshalUnsafe(b *testing.B) {
panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
}
}
+
+func BenchmarkBinarySlice(b *testing.B) {
+ var s1, s2 [64]test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := binary.Size(s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, &s1)
+ binary.Unmarshal(buf, usermem.ByteOrder, &s2)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+func BenchmarkGoMarshalUnsafeSlice(b *testing.B) {
+ var s1, s2 [64]test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, (*test.Stat)(nil).SizeBytes()*len(s1))
+ test.MarshalUnsafeStatSlice(s1[:], buf)
+ test.UnmarshalUnsafeStatSlice(s2[:], buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
diff --git a/tools/go_marshal/test/external/external.go b/tools/go_marshal/test/external/external.go
index 4be3722f3..26fe8e0c8 100644
--- a/tools/go_marshal/test/external/external.go
+++ b/tools/go_marshal/test/external/external.go
@@ -21,3 +21,11 @@ package external
type External struct {
j int64
}
+
+// NotPacked is an unaligned Marshallable type for use in testing.
+//
+// +marshal
+type NotPacked struct {
+ a int32
+ b byte `marshal:"unaligned"`
+}
diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go
new file mode 100644
index 000000000..16829ee45
--- /dev/null
+++ b/tools/go_marshal/test/marshal_test.go
@@ -0,0 +1,515 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package marshal_test contains manual tests for the marshal interface. These
+// are intended to test behaviour not covered by the automatically generated
+// tests.
+package marshal_test
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "reflect"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/analysis"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+var simulatedErr error = syserror.EFAULT
+
+// mockTask implements marshal.Task.
+type mockTask struct {
+ taskMem usermem.BytesIO
+}
+
+// populate fills the task memory with the contents of val.
+func (t *mockTask) populate(val interface{}) {
+ var buf bytes.Buffer
+ // Use binary.Write so we aren't testing go-marshal against its own
+ // potentially buggy implementation.
+ if err := binary.Write(&buf, usermem.ByteOrder, val); err != nil {
+ panic(err)
+ }
+ t.taskMem.Bytes = buf.Bytes()
+}
+
+func (t *mockTask) setLimit(n int) {
+ if len(t.taskMem.Bytes) < n {
+ grown := make([]byte, n)
+ copy(grown, t.taskMem.Bytes)
+ t.taskMem.Bytes = grown
+ return
+ }
+ t.taskMem.Bytes = t.taskMem.Bytes[:n]
+}
+
+// CopyScratchBuffer implements marshal.Task.CopyScratchBuffer.
+func (t *mockTask) CopyScratchBuffer(size int) []byte {
+ return make([]byte, size)
+}
+
+// CopyOutBytes implements marshal.Task.CopyOutBytes. The implementation
+// completely ignores the target address and stores a copy of b in its
+// internally buffer, overriding any previous contents.
+func (t *mockTask) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) {
+ return t.taskMem.CopyOut(nil, 0, b, usermem.IOOpts{})
+}
+
+// CopyInBytes implements marshal.Task.CopyInBytes. The implementation
+// completely ignores the source address and always fills b from the begining of
+// its internal buffer.
+func (t *mockTask) CopyInBytes(_ usermem.Addr, b []byte) (int, error) {
+ return t.taskMem.CopyIn(nil, 0, b, usermem.IOOpts{})
+}
+
+// unsafeMemory returns the underlying memory for m. The returned slice is only
+// valid for the lifetime for m. The garbage collector isn't aware that the
+// returned slice is related to m, the caller must ensure m lives long enough.
+func unsafeMemory(m marshal.Marshallable) []byte {
+ if !m.Packed() {
+ // We can't return a slice pointing to the underlying memory
+ // since the layout isn't packed. Allocate a temporary buffer
+ // and marshal instead.
+ var buf bytes.Buffer
+ if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil {
+ panic(err)
+ }
+ return buf.Bytes()
+ }
+
+ // reflect.ValueOf(m)
+ // .Elem() // Unwrap interface to inner concrete object
+ // .Addr() // Pointer value to object
+ // .Pointer() // Actual address from the pointer value
+ ptr := reflect.ValueOf(m).Elem().Addr().Pointer()
+
+ size := m.SizeBytes()
+
+ var mem []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
+ hdr.Data = ptr
+ hdr.Len = size
+ hdr.Cap = size
+
+ return mem
+}
+
+// unsafeMemorySlice returns the underlying memory for m. The returned slice is
+// only valid for the lifetime for m. The garbage collector isn't aware that the
+// returned slice is related to m, the caller must ensure m lives long enough.
+//
+// Precondition: m must be a slice.
+func unsafeMemorySlice(m interface{}, elt marshal.Marshallable) []byte {
+ kind := reflect.TypeOf(m).Kind()
+ if kind != reflect.Slice {
+ panic("unsafeMemorySlice called on non-slice")
+ }
+
+ if !elt.Packed() {
+ // We can't return a slice pointing to the underlying memory
+ // since the layout isn't packed. Allocate a temporary buffer
+ // and marshal instead.
+ var buf bytes.Buffer
+ if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil {
+ panic(err)
+ }
+ return buf.Bytes()
+ }
+
+ v := reflect.ValueOf(m)
+ length := v.Len() * elt.SizeBytes()
+
+ var mem []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
+ hdr.Data = v.Pointer() // This is a pointer to the first elem for slices.
+ hdr.Len = length
+ hdr.Cap = length
+
+ return mem
+}
+
+func isZeroes(buf []byte) bool {
+ for _, b := range buf {
+ if b != 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// compareMemory compares the first n bytes of two chuncks of memory represented
+// by expected and actual.
+func compareMemory(t *testing.T, expected, actual []byte, n int) {
+ t.Logf("Expected (%d): %v (%d) + (%d) %v\n", len(expected), expected[:n], n, len(expected)-n, expected[n:])
+ t.Logf("Actual (%d): %v (%d) + (%d) %v\n", len(actual), actual[:n], n, len(actual)-n, actual[n:])
+
+ if diff := cmp.Diff(expected[:n], actual[:n]); diff != "" {
+ t.Errorf("Memory buffers don't match:\n--- expected only\n+++ actual only\n%v", diff)
+ }
+}
+
+// limitedCopyIn populates task memory with src, then unmarshals task memory to
+// dst. The task signals an error at limit bytes during copy-in, which should
+// result in a truncated unmarshalling.
+func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) {
+ var task mockTask
+ task.populate(src)
+ task.setLimit(limit)
+
+ n, err := dst.CopyIn(&task, usermem.Addr(0))
+ if n != limit {
+ t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if err != simulatedErr {
+ t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := unsafeMemory(dst)
+ defer runtime.KeepAlive(dst)
+
+ compareMemory(t, expectedMem, actualMem, n)
+
+ // The last n bytes should be zero for actual, since actual was
+ // zero-initialized, and CopyIn shouldn't have touched those bytes. However
+ // we can only guarantee we didn't touch anything in the last n bytes if the
+ // layout is packed.
+ if dst.Packed() && !isZeroes(actualMem[n:]) {
+ t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", dst.SizeBytes()-n, actualMem)
+ }
+}
+
+// limitedCopyOut marshals src to task memory. The task signals an error at
+// limit bytes during copy-out, which should result in a truncated marshalling.
+func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) {
+ var task mockTask
+ task.setLimit(limit)
+
+ n, err := src.CopyOut(&task, usermem.Addr(0))
+ if n != limit {
+ t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if err != simulatedErr {
+ t.Errorf("CopyOut returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := task.taskMem.Bytes
+
+ compareMemory(t, expectedMem, actualMem, n)
+}
+
+// copyOutN marshals src to task memory, requesting the marshalling to be
+// limited to limit bytes.
+func copyOutN(t *testing.T, src marshal.Marshallable, limit int) {
+ var task mockTask
+ task.setLimit(limit)
+
+ n, err := src.CopyOutN(&task, usermem.Addr(0), limit)
+ if err != nil {
+ t.Errorf("CopyOut returned unexpected error: %v", err)
+ }
+ if n != limit {
+ t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := task.taskMem.Bytes
+
+ t.Logf("Expected: %v + %v\n", expectedMem[:n], expectedMem[n:])
+ t.Logf("Actual : %v + %v\n", actualMem[:n], actualMem[n:])
+
+ compareMemory(t, expectedMem, actualMem, n)
+}
+
+// TestLimitedMarshalling verifies marshalling/unmarshalling succeeds when the
+// underyling copy in/out operations partially succeed.
+func TestLimitedMarshalling(t *testing.T) {
+ types := []reflect.Type{
+ // Packed types.
+ reflect.TypeOf((*test.Type2)(nil)),
+ reflect.TypeOf((*test.Type3)(nil)),
+ reflect.TypeOf((*test.Timespec)(nil)),
+ reflect.TypeOf((*test.Stat)(nil)),
+ reflect.TypeOf((*test.InetAddr)(nil)),
+ reflect.TypeOf((*test.SignalSet)(nil)),
+ reflect.TypeOf((*test.SignalSetAlias)(nil)),
+ // Non-packed types.
+ reflect.TypeOf((*test.Type1)(nil)),
+ reflect.TypeOf((*test.Type4)(nil)),
+ reflect.TypeOf((*test.Type5)(nil)),
+ reflect.TypeOf((*test.Type6)(nil)),
+ reflect.TypeOf((*test.Type7)(nil)),
+ reflect.TypeOf((*test.Type8)(nil)),
+ }
+
+ for _, tyPtr := range types {
+ // Remove one level of pointer-indirection from the type. We get this
+ // back when we pass the type to reflect.New.
+ ty := tyPtr.Elem()
+
+ // Partial copy-in.
+ t.Run(fmt.Sprintf("PartialCopyIn_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ actual := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ limitedCopyIn(t, expected, actual, expected.SizeBytes()/2)
+ })
+
+ // Partial copy-out.
+ t.Run(fmt.Sprintf("PartialCopyOut_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ limitedCopyOut(t, expected, expected.SizeBytes()/2)
+ })
+
+ // Explicitly request partial copy-out.
+ t.Run(fmt.Sprintf("PartialCopyOutN_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ copyOutN(t, expected, expected.SizeBytes()/2)
+ })
+ }
+}
+
+// TestLimitedMarshalling verifies marshalling/unmarshalling of slices of
+// marshallable types succeed when the underyling copy in/out operations
+// partially succeed.
+func TestLimitedSliceMarshalling(t *testing.T) {
+ types := []struct {
+ arrayPtrType reflect.Type
+ copySliceIn func(task marshal.Task, addr usermem.Addr, dstSlice interface{}) (int, error)
+ copySliceOut func(task marshal.Task, addr usermem.Addr, srcSlice interface{}) (int, error)
+ unsafeMemory func(arrPtr interface{}) []byte
+ }{
+ // Packed types.
+ {
+ reflect.TypeOf((*[20]test.Stat)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[20]test.Stat)[:]
+ return test.CopyStatSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[20]test.Stat)[:]
+ return test.CopyStatSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[20]test.Stat)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[1]test.Stat)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[1]test.Stat)[:]
+ return test.CopyStatSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[1]test.Stat)[:]
+ return test.CopyStatSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[1]test.Stat)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[5]test.SignalSetAlias)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[5]test.SignalSetAlias)[:]
+ return test.CopySignalSetAliasSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[5]test.SignalSetAlias)[:]
+ return test.CopySignalSetAliasSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[5]test.SignalSetAlias)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ // Non-packed types.
+ {
+ reflect.TypeOf((*[20]test.Type1)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[20]test.Type1)[:]
+ return test.CopyType1SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[20]test.Type1)[:]
+ return test.CopyType1SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[20]test.Type1)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[1]test.Type1)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[1]test.Type1)[:]
+ return test.CopyType1SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[1]test.Type1)[:]
+ return test.CopyType1SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[1]test.Type1)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[7]test.Type8)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[7]test.Type8)[:]
+ return test.CopyType8SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[7]test.Type8)[:]
+ return test.CopyType8SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[7]test.Type8)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ }
+
+ for _, tt := range types {
+ // The body of this loop is generic over the type tt.arrayPtrType, with
+ // the help of reflection. To aid in readability, comments below show
+ // the equivalent go code assuming
+ // tt.arrayPtrType = typeof(*[20]test.Stat).
+
+ // Equivalent:
+ // var x *[20]test.Stat
+ // arrayTy := reflect.TypeOf(*x)
+ arrayTy := tt.arrayPtrType.Elem()
+
+ // Partial copy-in of slices.
+ t.Run(fmt.Sprintf("PartialCopySliceIn_%v", arrayTy), func(t *testing.T) {
+ // Equivalent:
+ // var x [20]test.Stat
+ // length := len(x)
+ length := arrayTy.Len()
+ if length < 1 {
+ panic("Test type can't be zero-length array")
+ }
+ // Equivalent:
+ // elem := new(test.Stat).(marshal.Marshallable)
+ elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable)
+
+ // Equivalent:
+ // var expected, actual interface{}
+ // expected = new([20]test.Stat)
+ // actual = new([20]test.Stat)
+ expected := reflect.New(arrayTy).Interface()
+ actual := reflect.New(arrayTy).Interface()
+
+ analysis.RandomizeValue(expected)
+
+ limit := (length * elem.SizeBytes()) / 2
+ // Also make sure the limit is partially inside one of the elements.
+ limit += elem.SizeBytes() / 2
+ analysis.RandomizeValue(expected)
+
+ var task mockTask
+ task.populate(expected)
+ task.setLimit(limit)
+
+ n, err := tt.copySliceIn(&task, usermem.Addr(0), actual)
+ if n != limit {
+ t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if n < length*elem.SizeBytes() && err != simulatedErr {
+ t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := tt.unsafeMemory(expected)
+ defer runtime.KeepAlive(expected)
+ actualMem := tt.unsafeMemory(actual)
+ defer runtime.KeepAlive(actual)
+
+ compareMemory(t, expectedMem, actualMem, n)
+
+ // The last n bytes should be zero for actual, since actual was
+ // zero-initialized, and CopyIn shouldn't have touched those bytes. However
+ // we can only guarantee we didn't touch anything in the last n bytes if the
+ // layout is packed.
+ if elem.Packed() && !isZeroes(actualMem[n:]) {
+ t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", (elem.SizeBytes()*length)-n, actualMem)
+ }
+ })
+
+ // Partial copy-out of slices.
+ t.Run(fmt.Sprintf("PartialCopySliceOut_%v", arrayTy), func(t *testing.T) {
+ // Equivalent:
+ // var x [20]test.Stat
+ // length := len(x)
+ length := arrayTy.Len()
+ if length < 1 {
+ panic("Test type can't be zero-length array")
+ }
+ // Equivalent:
+ // elem := new(test.Stat).(marshal.Marshallable)
+ elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable)
+
+ // Equivalent:
+ // var expected, actual interface{}
+ // expected = new([20]test.Stat)
+ // actual = new([20]test.Stat)
+ expected := reflect.New(arrayTy).Interface()
+
+ analysis.RandomizeValue(expected)
+
+ limit := (length * elem.SizeBytes()) / 2
+ // Also make sure the limit is partially inside one of the elements.
+ limit += elem.SizeBytes() / 2
+ analysis.RandomizeValue(expected)
+
+ var task mockTask
+ task.populate(expected)
+ task.setLimit(limit)
+
+ n, err := tt.copySliceOut(&task, usermem.Addr(0), expected)
+ if n != limit {
+ t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if n < length*elem.SizeBytes() && err != simulatedErr {
+ t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := tt.unsafeMemory(expected)
+ defer runtime.KeepAlive(expected)
+ actualMem := task.taskMem.Bytes
+
+ compareMemory(t, expectedMem, actualMem, n)
+ })
+ }
+}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index 93229dedb..f75ca1b7f 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -23,7 +23,7 @@ import (
// Type1 is a test data type.
//
-// +marshal
+// +marshal slice:Type1Slice
type Type1 struct {
a Type2
x, y int64 // Multiple field names.
@@ -75,6 +75,34 @@ type Type5 struct {
m int64
}
+// Type6 is a test data type ends mid-word.
+//
+// +marshal
+type Type6 struct {
+ a int64
+ b int64
+ // If c isn't marked unaligned, analysis fails (as it should, since
+ // the unsafe API corrupts Type7).
+ c byte `marshal:"unaligned"`
+}
+
+// Type7 is a test data type that contains a child struct that ends
+// mid-word.
+// +marshal
+type Type7 struct {
+ x Type6
+ y int64
+}
+
+// Type8 is a test data type which contains an external non-packed field.
+//
+// +marshal slice:Type8Slice
+type Type8 struct {
+ a int64
+ np ex.NotPacked
+ b int64
+}
+
// Timespec represents struct timespec in <time.h>.
//
// +marshal
@@ -85,7 +113,7 @@ type Timespec struct {
// Stat represents struct stat.
//
-// +marshal
+// +marshal slice:StatSlice
type Stat struct {
Dev uint64
Ino uint64
@@ -104,12 +132,45 @@ type Stat struct {
_ [3]int64
}
-// SignalSet is an example marshallable newtype on a primitive.
+// InetAddr is an example marshallable newtype on an array.
//
// +marshal
+type InetAddr [4]byte
+
+// SignalSet is an example marshallable newtype on a primitive.
+//
+// +marshal slice:SignalSetSlice:inner
type SignalSet uint64
// SignalSetAlias is an example newtype on another marshallable type.
//
-// +marshal
+// +marshal slice:SignalSetAliasSlice
type SignalSetAlias SignalSet
+
+const sizeA = 64
+const sizeB = 8
+
+// TestArray is a test data structure on an array with a constant length.
+//
+// +marshal
+type TestArray [sizeA]int32
+
+// TestArray2 is a newtype on an array with a simple arithmetic expression of
+// constants for the array length.
+//
+// +marshal
+type TestArray2 [sizeA * sizeB]int32
+
+// TestArray2 is a newtype on an array with a simple arithmetic expression of
+// mixed constants and literals for the array length.
+//
+// +marshal
+type TestArray3 [sizeA*sizeB + 12]int32
+
+// Type9 is a test data type containing an array with a non-literal length.
+//
+// +marshal
+type Type9 struct {
+ x int64
+ y [sizeA]int32
+}
diff --git a/tools/go_mod.sh b/tools/go_mod.sh
new file mode 100755
index 000000000..84b779d6d
--- /dev/null
+++ b/tools/go_mod.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -eo pipefail
+
+# Build the :gopath target.
+bazel build //:gopath
+declare -r gopathdir="bazel-bin/gopath/src/gvisor.dev/gvisor/"
+
+# Copy go.mod and execute the command.
+cp -a go.mod go.sum "${gopathdir}"
+(cd "${gopathdir}" && go mod "$@")
+cp -a "${gopathdir}/go.mod" "${gopathdir}/go.sum" .
+
+# Cleanup the WORKSPACE file.
+bazel run //:gazelle -- update-repos -from_file=go.mod
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index 3437aa476..309ee9c21 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -206,7 +206,7 @@ func main() {
initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name))
}
emitZeroCheck := func(name string) {
- fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name)
+ fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { m.Failf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, name, name)
}
emitLoadValue := func(name, typName string) {
fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName)
diff --git a/tools/image_build.sh b/tools/image_build.sh
deleted file mode 100755
index 9b20a740d..000000000
--- a/tools/image_build.sh
+++ /dev/null
@@ -1,98 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# This script is responsible for building a new GCP image that: 1) has nested
-# virtualization enabled, and 2) has been completely set up with the
-# image_setup.sh script. This script should be idempotent, as we memoize the
-# setup script with a hash and check for that name.
-#
-# The GCP project name should be defined via a gcloud config.
-
-set -xeo pipefail
-
-# Parameters.
-declare -r ZONE=${ZONE:-us-central1-f}
-declare -r USERNAME=${USERNAME:-test}
-declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud}
-declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts}
-
-# Random names.
-declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z)
-declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z)
-declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z)
-
-# Hashes inputs.
-declare -r SETUP_BLOB=$(echo ${ZONE} ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && sha256sum "$@")
-declare -r SETUP_HASH=$(echo ${SETUP_BLOB} | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
-declare -r IMAGE_NAME=${IMAGE_NAME:-image-}${SETUP_HASH}
-
-# Does the image already exist? Skip the build.
-declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
-if ! [[ -z "${existing}" ]]; then
- echo "${existing}"
- exit 0
-fi
-
-# Set the zone for all actions.
-gcloud config set compute/zone "${ZONE}"
-
-# Start a unique instance. Note that this instance will have a unique persistent
-# disk as it's boot disk with the same name as the instance.
-gcloud compute instances create \
- --quiet \
- --image-project "${IMAGE_PROJECT}" \
- --image-family "${IMAGE_FAMILY}" \
- --boot-disk-size "200GB" \
- "${INSTANCE_NAME}"
-function cleanup {
- gcloud compute instances delete --quiet "${INSTANCE_NAME}"
-}
-trap cleanup EXIT
-
-# Wait for the instance to become available.
-declare attempts=0
-while [[ "${attempts}" -lt 30 ]]; do
- attempts=$((${attempts}+1))
- if gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- true; then
- break
- fi
-done
-if [[ "${attempts}" -ge 30 ]]; then
- echo "too many attempts: failed"
- exit 1
-fi
-
-# Run the install scripts provided.
-for arg; do
- gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}"
-done
-
-# Stop the instance; required before creating an image.
-gcloud compute instances stop --quiet "${INSTANCE_NAME}"
-
-# Create a snapshot of the instance disk.
-gcloud compute disks snapshot \
- --quiet \
- --zone="${ZONE}" \
- --snapshot-names="${SNAPSHOT_NAME}" \
- "${INSTANCE_NAME}"
-
-# Create the disk image.
-gcloud compute images create \
- --quiet \
- --source-snapshot="${SNAPSHOT_NAME}" \
- --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \
- "${IMAGE_NAME}"
diff --git a/tools/images/BUILD b/tools/images/BUILD
index fe11f08a3..8d319e3e4 100644
--- a/tools/images/BUILD
+++ b/tools/images/BUILD
@@ -6,14 +6,9 @@ package(
licenses = ["notice"],
)
-genrule(
+sh_binary(
name = "zone",
- outs = ["zone.txt"],
- cmd = "gcloud config get-value compute/zone > $@",
- tags = [
- "local",
- "manual",
- ],
+ srcs = ["zone.sh"],
)
sh_binary(
diff --git a/tools/images/README.md b/tools/images/README.md
new file mode 100644
index 000000000..26c0f84f2
--- /dev/null
+++ b/tools/images/README.md
@@ -0,0 +1,42 @@
+# Images
+
+All commands in this directory require the `gcloud` project to be set.
+
+For example: `gcloud config set project gvisor-kokoro-testing`.
+
+Images can be generated by using the `vm_image` rule. This rule will generate a
+binary target that builds an image in an idempotent way, and can be referenced
+from other rules.
+
+For example:
+
+```
+vm_image(
+ name = "ubuntu",
+ project = "ubuntu-1604-lts",
+ family = "ubuntu-os-cloud",
+ scripts = [
+ "script.sh",
+ "other.sh",
+ ],
+)
+```
+
+These images can be built manually by executing the target. The output on
+`stdout` will be the image id (in the current project).
+
+Images are always named per the hash of all the hermetic input scripts. This
+allows images to be memoized quickly and easily.
+
+The `vm_test` rule can be used to execute a command remotely. This is still
+under development however, and will likely change over time.
+
+For example:
+
+```
+vm_test(
+ name = "mycommand",
+ image = ":ubuntu",
+ targets = [":test"],
+)
+```
diff --git a/tools/images/build.sh b/tools/images/build.sh
index be462d556..f39f723b8 100755
--- a/tools/images/build.sh
+++ b/tools/images/build.sh
@@ -19,7 +19,7 @@
# image_setup.sh script. This script should be idempotent, as we memoize the
# setup script with a hash and check for that name.
-set -xeou pipefail
+set -eou pipefail
# Parameters.
declare -r USERNAME=${USERNAME:-test}
@@ -34,10 +34,10 @@ declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z)
# Hash inputs in order to memoize the produced image.
declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
-declare -r IMAGE_NAME=${IMAGE_FAMILY:-image-}${SETUP_HASH}
+declare -r IMAGE_NAME=${IMAGE_FAMILY:-image}-${SETUP_HASH}
# Does the image already exist? Skip the build.
-declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
+declare -r existing=$(set -x; gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
if ! [[ -z "${existing}" ]]; then
echo "${existing}"
exit 0
@@ -48,54 +48,66 @@ export PATH=${PATH:-/bin:/usr/bin:/usr/local/bin}
# Start a unique instance. Note that this instance will have a unique persistent
# disk as it's boot disk with the same name as the instance.
-gcloud compute instances create \
+(set -x; gcloud compute instances create \
--quiet \
--image-project "${IMAGE_PROJECT}" \
--image-family "${IMAGE_FAMILY}" \
--boot-disk-size "200GB" \
--zone "${ZONE}" \
- "${INSTANCE_NAME}" >/dev/null
+ "${INSTANCE_NAME}" >/dev/null)
function cleanup {
- gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}"
+ (set -x; gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}")
}
trap cleanup EXIT
# Wait for the instance to become available (up to 5 minutes).
+echo -n "Waiting for ${INSTANCE_NAME}"
declare timeout=300
declare success=0
+declare internal=""
declare -r start=$(date +%s)
declare -r end=$((${start}+${timeout}))
while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do
+ echo -n "."
if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then
success=$((${success}+1))
+ elif gcloud compute ssh --internal-ip --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then
+ success=$((${success}+1))
+ internal="--internal-ip"
fi
done
+
if [[ "${success}" -eq "0" ]]; then
echo "connect timed out after ${timeout} seconds."
exit 1
+else
+ echo "done."
fi
# Run the install scripts provided.
for arg; do
- gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}" >/dev/null
+ (set -x; gcloud compute ssh ${internal} \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ sudo bash - <"${arg}" >/dev/null)
done
# Stop the instance; required before creating an image.
-gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null
+(set -x; gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null)
# Create a snapshot of the instance disk.
-gcloud compute disks snapshot \
+(set -x; gcloud compute disks snapshot \
--quiet \
--zone "${ZONE}" \
--snapshot-names="${SNAPSHOT_NAME}" \
- "${INSTANCE_NAME}" >/dev/null
+ "${INSTANCE_NAME}" >/dev/null)
# Create the disk image.
-gcloud compute images create \
+(set -x; gcloud compute images create \
--quiet \
--source-snapshot="${SNAPSHOT_NAME}" \
--licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \
- "${IMAGE_NAME}" >/dev/null
+ "${IMAGE_NAME}" >/dev/null)
# Finish up.
echo "${IMAGE_NAME}"
diff --git a/tools/images/defs.bzl b/tools/images/defs.bzl
index de365d153..2847e1847 100644
--- a/tools/images/defs.bzl
+++ b/tools/images/defs.bzl
@@ -1,76 +1,49 @@
-"""Image configuration.
-
-Images can be generated by using the vm_image rule. For example,
-
- vm_image(
- name = "ubuntu",
- project = "...",
- family = "...",
- scripts = [
- "script.sh",
- "other.sh",
- ],
- )
-
-This will always create an vm_image in the current default gcloud project. The
-rule has a text file as its output containing the image name. This will enforce
-serialization for all dependent rules.
-
-Images are always named per the hash of all the hermetic input scripts. This
-allows images to be memoized quickly and easily.
-
-The vm_test rule can be used to execute a command remotely. For example,
-
- vm_test(
- name = "mycommand",
- image = ":myimage",
- targets = [":test"],
- )
-"""
+"""Image configuration. See README.md."""
load("//tools:defs.bzl", "default_installer")
-def _vm_image_impl(ctx):
+# vm_image_builder is a rule that will construct a shell script that actually
+# generates a given VM image. Note that this does not _run_ the shell script
+# (although it can be run manually). It will be run manually during generation
+# of the vm_image target itself. This level of indirection is used so that the
+# build system itself only runs the builder once when multiple targets depend
+# on it, avoiding a set of races and conflicts.
+def _vm_image_builder_impl(ctx):
+ # Generate a binary that actually builds the image.
+ builder = ctx.actions.declare_file(ctx.label.name)
script_paths = []
for script in ctx.files.scripts:
script_paths.append(script.short_path)
+ builder_content = "\n".join([
+ "#!/bin/bash",
+ "export ZONE=$(%s)" % ctx.files.zone[0].short_path,
+ "export USERNAME=%s" % ctx.attr.username,
+ "export IMAGE_PROJECT=%s" % ctx.attr.project,
+ "export IMAGE_FAMILY=%s" % ctx.attr.family,
+ "%s %s" % (ctx.files._builder[0].short_path, " ".join(script_paths)),
+ "",
+ ])
+ ctx.actions.write(builder, builder_content, is_executable = True)
- resolved_inputs, argv, runfiles_manifests = ctx.resolve_command(
- command = "USERNAME=%s ZONE=$(cat %s) IMAGE_PROJECT=%s IMAGE_FAMILY=%s %s %s > %s" %
- (
- ctx.attr.username,
- ctx.files.zone[0].path,
- ctx.attr.project,
- ctx.attr.family,
- ctx.executable.builder.path,
- " ".join(script_paths),
- ctx.outputs.out.path,
- ),
- tools = [ctx.attr.builder] + ctx.attr.scripts,
- )
-
- ctx.actions.run_shell(
- tools = resolved_inputs,
- outputs = [ctx.outputs.out],
- progress_message = "Building image...",
- execution_requirements = {"local": "true"},
- command = argv,
- input_manifests = runfiles_manifests,
- )
+ # Note that the scripts should only be files, and should not include any
+ # indirect transitive dependencies. The build script wouldn't work.
return [DefaultInfo(
- files = depset([ctx.outputs.out]),
- runfiles = ctx.runfiles(files = [ctx.outputs.out]),
+ executable = builder,
+ runfiles = ctx.runfiles(
+ files = ctx.files.scripts + ctx.files._builder + ctx.files.zone,
+ ),
)]
-_vm_image = rule(
+vm_image_builder = rule(
attrs = {
- "builder": attr.label(
+ "_builder": attr.label(
executable = True,
default = "//tools/images:builder",
cfg = "host",
),
"username": attr.string(default = "$(whoami)"),
"zone": attr.label(
+ executable = True,
default = "//tools/images:zone",
cfg = "host",
),
@@ -78,20 +51,55 @@ _vm_image = rule(
"project": attr.string(mandatory = True),
"scripts": attr.label_list(allow_files = True),
},
- outputs = {
- "out": "%{name}.txt",
+ executable = True,
+ implementation = _vm_image_builder_impl,
+)
+
+# See vm_image_builder above.
+def _vm_image_impl(ctx):
+ # Run the builder to generate our output.
+ echo = ctx.actions.declare_file(ctx.label.name)
+ resolved_inputs, argv, runfiles_manifests = ctx.resolve_command(
+ command = "echo -ne \"#!/bin/bash\\necho $(%s)\\n\" > %s && chmod 0755 %s" % (
+ ctx.files.builder[0].path,
+ echo.path,
+ echo.path,
+ ),
+ tools = [ctx.attr.builder],
+ )
+ ctx.actions.run_shell(
+ tools = resolved_inputs,
+ outputs = [echo],
+ progress_message = "Building image...",
+ execution_requirements = {"local": "true"},
+ command = argv,
+ input_manifests = runfiles_manifests,
+ )
+
+ # Return just the echo command. All of the builder runfiles have been
+ # resolved and consumed in the generation of the trivial echo script.
+ return [DefaultInfo(executable = echo)]
+
+_vm_image = rule(
+ attrs = {
+ "builder": attr.label(
+ executable = True,
+ cfg = "host",
+ ),
},
+ executable = True,
implementation = _vm_image_impl,
)
-def vm_image(**kwargs):
- _vm_image(
- tags = [
- "local",
- "manual",
- ],
+def vm_image(name, **kwargs):
+ vm_image_builder(
+ name = name + "_builder",
**kwargs
)
+ _vm_image(
+ name = name,
+ builder = ":" + name + "_builder",
+ )
def _vm_test_impl(ctx):
runner = ctx.actions.declare_file("%s-executer" % ctx.label.name)
diff --git a/tools/images/ubuntu1604/10_core.sh b/tools/images/ubuntu1604/10_core.sh
index 46dda6bb1..cd518d6ac 100755
--- a/tools/images/ubuntu1604/10_core.sh
+++ b/tools/images/ubuntu1604/10_core.sh
@@ -17,7 +17,20 @@
set -xeo pipefail
# Install all essential build tools.
-apt-get update && apt-get -y install make git-core build-essential linux-headers-$(uname -r) pkg-config
+while true; do
+ if (apt-get update && apt-get install -y \
+ make \
+ git-core \
+ build-essential \
+ linux-headers-$(uname -r) \
+ pkg-config); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# Install a recent go toolchain.
if ! [[ -d /usr/local/go ]]; then
diff --git a/tools/images/ubuntu1604/20_bazel.sh b/tools/images/ubuntu1604/20_bazel.sh
index b33e1656c..bb7afa676 100755
--- a/tools/images/ubuntu1604/20_bazel.sh
+++ b/tools/images/ubuntu1604/20_bazel.sh
@@ -19,7 +19,17 @@ set -xeo pipefail
declare -r BAZEL_VERSION=2.0.0
# Install bazel dependencies.
-apt-get update && apt-get install -y openjdk-8-jdk-headless unzip
+while true; do
+ if (apt-get update && apt-get install -y \
+ openjdk-8-jdk-headless \
+ unzip); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# Use the release installer.
curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
diff --git a/tools/images/ubuntu1604/25_docker.sh b/tools/images/ubuntu1604/25_docker.sh
index 1d3defcd3..11eea2d72 100755
--- a/tools/images/ubuntu1604/25_docker.sh
+++ b/tools/images/ubuntu1604/25_docker.sh
@@ -15,12 +15,20 @@
# limitations under the License.
# Add dependencies.
-apt-get update && apt-get -y install \
- apt-transport-https \
- ca-certificates \
- curl \
- gnupg-agent \
- software-properties-common
+while true; do
+ if (apt-get update && apt-get install -y \
+ apt-transport-https \
+ ca-certificates \
+ curl \
+ gnupg-agent \
+ software-properties-common); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# Install the key.
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add -
@@ -32,4 +40,15 @@ add-apt-repository \
stable"
# Install docker.
-apt-get update && apt-get install -y docker-ce docker-ce-cli containerd.io
+while true; do
+ if (apt-get update && apt-get install -y \
+ docker-ce \
+ docker-ce-cli \
+ containerd.io); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
diff --git a/tools/images/ubuntu1604/30_containerd.sh b/tools/images/ubuntu1604/30_containerd.sh
index a7472bd1c..fb3699c12 100755
--- a/tools/images/ubuntu1604/30_containerd.sh
+++ b/tools/images/ubuntu1604/30_containerd.sh
@@ -34,7 +34,17 @@ install_helper() {
}
# Install dependencies for the crictl tests.
-apt-get install -y btrfs-tools libseccomp-dev
+while true; do
+ if (apt-get update && apt-get install -y \
+ btrfs-tools \
+ libseccomp-dev); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# Install containerd & cri-tools.
GOPATH=$(mktemp -d --tmpdir gopathXXXXX)
diff --git a/tools/images/ubuntu1604/40_kokoro.sh b/tools/images/ubuntu1604/40_kokoro.sh
index 5f2dfc858..06a1e6c48 100755
--- a/tools/images/ubuntu1604/40_kokoro.sh
+++ b/tools/images/ubuntu1604/40_kokoro.sh
@@ -23,7 +23,22 @@ declare -r ssh_public_keys=(
)
# Install dependencies.
-apt-get update && apt-get install -y rsync coreutils python-psutil qemu-kvm python-pip python3-pip zip
+while true; do
+ if (apt-get update && apt-get install -y \
+ rsync \
+ coreutils \
+ python-psutil \
+ qemu-kvm \
+ python-pip \
+ python3-pip \
+ zip); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# junitparser is used to merge junit xml files.
pip install junitparser
diff --git a/tools/images/zone.sh b/tools/images/zone.sh
new file mode 100755
index 000000000..79569fb19
--- /dev/null
+++ b/tools/images/zone.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+exec gcloud config get-value compute/zone
diff --git a/tools/installers/master.sh b/tools/installers/master.sh
index 52f9734a6..2c6001c6c 100755
--- a/tools/installers/master.sh
+++ b/tools/installers/master.sh
@@ -19,17 +19,16 @@ set -e
curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
+
while true; do
- if apt-get update; then
- apt-get install -y runsc
+ if (apt-get update && apt-get install -y runsc); then
break
fi
result=$?
- # Check if apt update failed to aquire the file lock.
if [[ $result -ne 100 ]]; then
exit $result
fi
done
+
runsc install
service docker restart
-
diff --git a/tools/nogo.js b/tools/nogo.js
deleted file mode 100644
index fc0a4d1f0..000000000
--- a/tools/nogo.js
+++ /dev/null
@@ -1,7 +0,0 @@
-{
- "checkunsafe": {
- "exclude_files": {
- "/external/": "not subject to constraint"
- }
- }
-}
diff --git a/tools/nogo.json b/tools/nogo.json
new file mode 100644
index 000000000..ae969409e
--- /dev/null
+++ b/tools/nogo.json
@@ -0,0 +1,39 @@
+{
+ "assign": {
+ "exclude_files": {
+ "/external/bazel_gazelle/walk/walk.go": "allowed: false positive"
+ }
+ },
+ "checkunsafe": {
+ "exclude_files": {
+ "/external/": "allowed: not subject to unsafe naming rules"
+ }
+ },
+ "nilness": {
+ "exclude_files": {
+ "/com_github_vishvananda_netlink/route_linux.go": "allowed: false positive",
+ "/external/bazel_gazelle/cmd/gazelle/.*": "allowed: false positive",
+ "/org_golang_x_tools/go/packages/golist.go": "allowed: runtime internals",
+ "/pkg/sentry/platform/kvm/kvm_test.go": "allowed: intentional",
+ "/tools/bigquery/bigquery.go": "allowed: false positive",
+ "/external/io_opencensus_go/tag/map_codec.go": "allowed: false positive"
+ }
+ },
+ "structtag": {
+ "exclude_files": {
+ "/external/": "allowed: may use arbitrary tags"
+ }
+ },
+ "unsafeptr": {
+ "exclude_files": {
+ ".*_test.go": "allowed: exclude tests",
+ "/pkg/flipcall/flipcall_unsafe.go": "allowed: special case",
+ "/pkg/gohacks/gohacks_unsafe.go": "allowed: special case",
+ "/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go": "allowed: special case",
+ "/pkg/sentry/platform/kvm/(bluepill|machine)_unsafe.go": "allowed: special case",
+ "/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go": "allowed: special case",
+ "/pkg/sentry/platform/safecopy/safecopy_unsafe.go": "allowed: special case",
+ "/pkg/sentry/vfs/mount_unsafe.go": "allowed: special case"
+ }
+ }
+}
diff --git a/tools/tag_release.sh b/tools/tag_release.sh
index f33b902d6..4dbfe420a 100755
--- a/tools/tag_release.sh
+++ b/tools/tag_release.sh
@@ -21,13 +21,19 @@
set -xeu
# Check arguments.
-if [ "$#" -ne 2 ]; then
- echo "usage: $0 <commit|revid> <release.rc>"
+if [ "$#" -ne 3 ]; then
+ echo "usage: $0 <commit|revid> <release.rc> <message-file>"
exit 1
fi
declare -r target_commit="$1"
declare -r release="$2"
+declare -r message_file="$3"
+
+if ! [[ -r "${message_file}" ]]; then
+ echo "error: message file '${message_file}' is not readable."
+ exit 1
+fi
closest_commit() {
while read line; do
@@ -64,6 +70,6 @@ fi
# Tag the given commit (annotated, to record the committer).
declare -r tag="release-${release}"
-(git tag -m "Release ${release}" -a "${tag}" "${commit}" && \
+(git tag -F "${message_file}" -a "${tag}" "${commit}" && \
git push origin tag "${tag}") || \
(git tag -d "${tag}" && false)
diff --git a/vdso/vdso.cc b/vdso/vdso.cc
index 8bb80a7a4..c2585d592 100644
--- a/vdso/vdso.cc
+++ b/vdso/vdso.cc
@@ -126,6 +126,10 @@ extern "C" int __kernel_clock_getres(clockid_t clock, struct timespec* res) {
case CLOCK_REALTIME:
case CLOCK_MONOTONIC:
case CLOCK_BOOTTIME: {
+ if (res == nullptr) {
+ return 0;
+ }
+
res->tv_sec = 0;
res->tv_nsec = 1;
break;