summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--WORKSPACE30
-rw-r--r--benchmarks/runner/__init__.py6
-rw-r--r--benchmarks/runner/commands.py8
-rw-r--r--kokoro/runtime_tests/go1.12.cfg2
-rw-r--r--kokoro/runtime_tests/java11.cfg2
-rw-r--r--kokoro/runtime_tests/nodejs12.4.0.cfg2
-rw-r--r--kokoro/runtime_tests/php7.3.6.cfg2
-rw-r--r--kokoro/runtime_tests/python3.7.3.cfg2
-rw-r--r--pkg/atomicbitops/atomicbitops_amd64.s16
-rw-r--r--pkg/atomicbitops/atomicbitops_arm64.s16
-rw-r--r--pkg/atomicbitops/atomicbitops_noasm.go8
-rw-r--r--pkg/context/context.go4
-rw-r--r--pkg/cpuid/cpuid_parse_x86_test.go2
-rw-r--r--pkg/cpuid/cpuid_x86.go2
-rw-r--r--pkg/cpuid/cpuid_x86_test.go2
-rw-r--r--pkg/eventchannel/event_test.go4
-rw-r--r--pkg/ilist/list.go13
-rw-r--r--pkg/log/glog.go6
-rw-r--r--pkg/log/json.go2
-rw-r--r--pkg/log/json_k8s.go4
-rw-r--r--pkg/log/log.go2
-rw-r--r--pkg/log/log_test.go6
-rw-r--r--pkg/p9/client.go45
-rw-r--r--pkg/p9/client_test.go7
-rw-r--r--pkg/p9/file.go8
-rw-r--r--pkg/p9/handlers.go41
-rw-r--r--pkg/p9/messages.go14
-rw-r--r--pkg/p9/messages_test.go2
-rw-r--r--pkg/p9/transport_flipcall.go2
-rw-r--r--pkg/safecopy/memcpy_amd64.s111
-rw-r--r--pkg/safecopy/safecopy.go4
-rw-r--r--pkg/safecopy/safecopy_unsafe.go6
-rw-r--r--pkg/segment/test/segment_test.go2
-rw-r--r--pkg/sentry/arch/arch.go3
-rw-r--r--pkg/sentry/arch/arch_state_x86.go2
-rw-r--r--pkg/sentry/arch/arch_x86.go2
-rw-r--r--pkg/sentry/arch/arch_x86_impl.go2
-rw-r--r--pkg/sentry/arch/signal_stack.go2
-rw-r--r--pkg/sentry/arch/stack.go3
-rw-r--r--pkg/sentry/arch/syscalls_amd64.go7
-rw-r--r--pkg/sentry/arch/syscalls_arm64.go13
-rw-r--r--pkg/sentry/contexttest/contexttest.go4
-rw-r--r--pkg/sentry/fs/dirent.go6
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_test.go4
-rw-r--r--pkg/sentry/fs/gofer/file_state.go1
-rw-r--r--pkg/sentry/fs/gofer/handles.go1
-rw-r--r--pkg/sentry/fs/gofer/inode.go5
-rw-r--r--pkg/sentry/fs/gofer/inode_state.go1
-rw-r--r--pkg/sentry/fs/gofer/session_state.go1
-rw-r--r--pkg/sentry/fs/gofer/util.go16
-rw-r--r--pkg/sentry/fs/host/inode.go3
-rw-r--r--pkg/sentry/fs/host/socket_test.go6
-rw-r--r--pkg/sentry/fs/inode.go3
-rw-r--r--pkg/sentry/fs/proc/mounts.go3
-rw-r--r--pkg/sentry/fs/proc/sys_net.go4
-rw-r--r--pkg/sentry/fs/proc/task.go28
-rw-r--r--pkg/sentry/fs/tmpfs/fs.go3
-rw-r--r--pkg/sentry/fsbridge/vfs.go28
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go4
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD12
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go7
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go30
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go118
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer_test.go64
-rw-r--r--pkg/sentry/fsimpl/gofer/p9file.go14
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go7
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go4
-rw-r--r--pkg/sentry/fsimpl/gofer/symlink.go2
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go39
-rw-r--r--pkg/sentry/fsimpl/host/host.go76
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go51
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go5
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go16
-rw-r--r--pkg/sentry/fsimpl/kernfs/symlink.go5
-rw-r--r--pkg/sentry/fsimpl/pipefs/BUILD20
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go148
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD1
-rw-r--r--pkg/sentry/fsimpl/proc/task.go16
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go125
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go73
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go92
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go10
-rw-r--r--pkg/sentry/fsimpl/sockfs/BUILD17
-rw-r--r--pkg/sentry/fsimpl/sockfs/sockfs.go102
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/benchmark_test.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go79
-rw-r--r--pkg/sentry/fsimpl/tmpfs/named_pipe.go23
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/socket_file.go (renamed from pkg/tcpip/stack/packet_buffer_state.go)27
-rw-r--r--pkg/sentry/fsimpl/tmpfs/stat_test.go12
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go170
-rw-r--r--pkg/sentry/kernel/BUILD3
-rw-r--r--pkg/sentry/kernel/fd_table.go55
-rw-r--r--pkg/sentry/kernel/kernel.go124
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go13
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go162
-rw-r--r--pkg/sentry/kernel/ptrace.go1
-rw-r--r--pkg/sentry/kernel/rseq.go2
-rw-r--r--pkg/sentry/kernel/shm/shm.go2
-rw-r--r--pkg/sentry/kernel/syscalls.go43
-rw-r--r--pkg/sentry/kernel/syscalls_state.go36
-rw-r--r--pkg/sentry/kernel/task.go9
-rw-r--r--pkg/sentry/kernel/task_context.go3
-rw-r--r--pkg/sentry/kernel/task_identity.go2
-rw-r--r--pkg/sentry/kernel/task_run.go2
-rw-r--r--pkg/sentry/kernel/task_signals.go4
-rw-r--r--pkg/sentry/kernel/task_syscall.go26
-rw-r--r--pkg/sentry/kernel/time/time.go10
-rw-r--r--pkg/sentry/mm/address_space.go6
-rw-r--r--pkg/sentry/mm/aio_context.go101
-rw-r--r--pkg/sentry/mm/aio_context_state.go2
-rw-r--r--pkg/sentry/mm/lifecycle.go1
-rw-r--r--pkg/sentry/mm/metadata.go14
-rw-r--r--pkg/sentry/mm/mm.go3
-rw-r--r--pkg/sentry/platform/kvm/BUILD1
-rw-r--r--pkg/sentry/platform/kvm/bluepill.go12
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go97
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64.go4
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go1
-rw-r--r--pkg/sentry/platform/kvm/kvm_test.go64
-rw-r--r--pkg/sentry/platform/kvm/machine.go9
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go25
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go89
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go41
-rw-r--r--pkg/sentry/platform/ring0/BUILD2
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go7
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64_unsafe.go108
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD3
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_x86.go2
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids.go5
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go32
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s45
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids_x86.go20
-rw-r--r--pkg/sentry/platform/ring0/x86.go2
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go169
-rw-r--r--pkg/sentry/socket/netstack/provider.go8
-rw-r--r--pkg/sentry/socket/socket.go89
-rw-r--r--pkg/sentry/socket/unix/BUILD4
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go48
-rw-r--r--pkg/sentry/socket/unix/unix.go89
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go348
-rw-r--r--pkg/sentry/strace/strace.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go36
-rw-r--r--pkg/sentry/syscalls/linux/sys_pipe.go14
-rw-r--r--pkg/sentry/syscalls/linux/sys_prctl.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_rlimit.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD10
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go17
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/getdents.go20
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go50
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/pipe.go63
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go14
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go16
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go1139
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go23
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sys_timerfd.go123
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/xattr.go13
-rw-r--r--pkg/sentry/vfs/BUILD2
-rw-r--r--pkg/sentry/vfs/anonfs.go4
-rw-r--r--pkg/sentry/vfs/epoll.go2
-rw-r--r--pkg/sentry/vfs/file_description.go62
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go4
-rw-r--r--pkg/sentry/vfs/filesystem.go26
-rw-r--r--pkg/sentry/vfs/memxattr/BUILD15
-rw-r--r--pkg/sentry/vfs/memxattr/xattr.go102
-rw-r--r--pkg/sentry/vfs/mount.go15
-rw-r--r--pkg/sentry/vfs/mount_test.go2
-rw-r--r--pkg/sentry/vfs/options.go19
-rw-r--r--pkg/sentry/vfs/timerfd.go142
-rw-r--r--pkg/sentry/vfs/vfs.go15
-rw-r--r--pkg/sentry/watchdog/watchdog.go13
-rw-r--r--pkg/state/state.go5
-rw-r--r--pkg/sync/BUILD10
-rw-r--r--pkg/sync/mutex_test.go (renamed from pkg/sync/tmutex_test.go)0
-rw-r--r--pkg/sync/mutex_unsafe.go (renamed from pkg/sync/tmutex_unsafe.go)0
-rw-r--r--pkg/sync/rwmutex_test.go (renamed from pkg/sync/downgradable_rwmutex_test.go)0
-rw-r--r--pkg/sync/rwmutex_unsafe.go (renamed from pkg/sync/downgradable_rwmutex_unsafe.go)0
-rw-r--r--pkg/sync/sync.go (renamed from pkg/sync/syncutil.go)0
-rw-r--r--pkg/tcpip/buffer/view.go53
-rw-r--r--pkg/tcpip/buffer/view_test.go137
-rw-r--r--pkg/tcpip/checker/checker.go100
-rw-r--r--pkg/tcpip/header/BUILD1
-rw-r--r--pkg/tcpip/header/eth_test.go2
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go41
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers_test.go41
-rw-r--r--pkg/tcpip/header/ndp_options.go407
-rw-r--r--pkg/tcpip/header/ndp_test.go694
-rw-r--r--pkg/tcpip/header/ndpoptionidentifier_string.go50
-rw-r--r--pkg/tcpip/header/udp.go5
-rw-r--r--pkg/tcpip/link/channel/channel.go52
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go164
-rw-r--r--pkg/tcpip/link/loopback/loopback.go2
-rw-r--r--pkg/tcpip/link/muxed/injectable.go2
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go2
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go166
-rw-r--r--pkg/tcpip/link/waitable/waitable.go4
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go6
-rw-r--r--pkg/tcpip/network/arp/arp.go2
-rw-r--r--pkg/tcpip/network/arp/arp_test.go3
-rw-r--r--pkg/tcpip/network/ip_test.go2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go37
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go239
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go3
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go67
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go68
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go423
-rw-r--r--pkg/tcpip/seqnum/seqnum.go5
-rw-r--r--pkg/tcpip/stack/BUILD14
-rw-r--r--pkg/tcpip/stack/forwarder_test.go8
-rw-r--r--pkg/tcpip/stack/iptables.go17
-rw-r--r--pkg/tcpip/stack/ndp.go231
-rw-r--r--pkg/tcpip/stack/ndp_test.go732
-rw-r--r--pkg/tcpip/stack/nic.go73
-rw-r--r--pkg/tcpip/stack/packet_buffer.go14
-rw-r--r--pkg/tcpip/stack/registration.go4
-rw-r--r--pkg/tcpip/stack/route.go19
-rw-r--r--pkg/tcpip/stack/stack_test.go36
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go35
-rw-r--r--pkg/tcpip/tcpip.go152
-rw-r--r--pkg/tcpip/tcpip_test.go2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go43
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go18
-rw-r--r--pkg/tcpip/transport/tcp/BUILD15
-rw-r--r--pkg/tcpip/transport/tcp/accept.go145
-rw-r--r--pkg/tcpip/transport/tcp/connect.go99
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go9
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go486
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go102
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go23
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go74
-rw-r--r--pkg/tcpip/transport/tcp/segment.go6
-rw-r--r--pkg/tcpip/transport/tcp/snd.go15
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go83
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go41
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go251
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go30
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go24
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD1
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go13
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go227
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go77
-rw-r--r--pkg/usermem/usermem.go3
-rw-r--r--pkg/usermem/usermem_x86.go2
-rw-r--r--runsc/boot/BUILD12
-rw-r--r--runsc/boot/compat.go2
-rw-r--r--runsc/boot/config.go5
-rw-r--r--runsc/boot/fds.go33
-rw-r--r--runsc/boot/filter/config.go2
-rw-r--r--runsc/boot/fs.go24
-rw-r--r--runsc/boot/loader.go76
-rw-r--r--runsc/boot/loader_amd64.go5
-rw-r--r--runsc/boot/loader_arm64.go5
-rw-r--r--runsc/boot/loader_test.go49
-rw-r--r--runsc/boot/user.go64
-rw-r--r--runsc/boot/vfs.go310
-rw-r--r--runsc/cmd/capability_test.go2
-rw-r--r--runsc/cmd/gofer.go5
-rw-r--r--runsc/container/console_test.go6
-rw-r--r--runsc/container/container.go4
-rw-r--r--runsc/container/container_test.go72
-rw-r--r--runsc/container/multi_container_test.go44
-rw-r--r--runsc/container/shared_volume_test.go4
-rw-r--r--runsc/container/test_app/test_app.go2
-rw-r--r--runsc/fsgofer/fsgofer.go4
-rw-r--r--runsc/main.go13
-rw-r--r--runsc/sandbox/sandbox.go97
-rw-r--r--runsc/specutils/specutils.go11
-rw-r--r--runsc/testutil/testutil.go7
-rwxr-xr-xscripts/benchmark.sh11
-rwxr-xr-xscripts/common.sh22
-rwxr-xr-xscripts/runtime_tests.sh26
-rw-r--r--test/packetdrill/BUILD12
-rw-r--r--test/packetdrill/linux/tcp_user_timeout.pkt39
-rw-r--r--test/packetdrill/netstack/tcp_user_timeout.pkt38
-rw-r--r--test/packetimpact/dut/posix_server.cc67
-rw-r--r--test/packetimpact/proto/posix_server.proto107
-rw-r--r--test/packetimpact/testbench/BUILD12
-rw-r--r--test/packetimpact/testbench/connections.go663
-rw-r--r--test/packetimpact/testbench/dut.go362
-rw-r--r--test/packetimpact/testbench/layers.go339
-rw-r--r--test/packetimpact/testbench/layers_test.go156
-rw-r--r--test/packetimpact/testbench/rawsockets.go44
-rw-r--r--test/packetimpact/tests/BUILD81
-rw-r--r--test/packetimpact/tests/Dockerfile14
-rw-r--r--test/packetimpact/tests/defs.bzl18
-rw-r--r--test/packetimpact/tests/fin_wait2_timeout_test.go16
-rw-r--r--test/packetimpact/tests/tcp_close_wait_ack_test.go102
-rw-r--r--test/packetimpact/tests/tcp_noaccept_close_rst_test.go37
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_test.go88
-rw-r--r--test/packetimpact/tests/tcp_should_piggyback_test.go59
-rw-r--r--test/packetimpact/tests/tcp_user_timeout_test.go100
-rw-r--r--test/packetimpact/tests/tcp_window_shrink_test.go68
-rwxr-xr-xtest/packetimpact/tests/test_runner.sh24
-rw-r--r--test/packetimpact/tests/udp_recv_multicast_test.go37
-rw-r--r--test/perf/BUILD1
-rw-r--r--test/perf/linux/getdents_benchmark.cc2
-rw-r--r--test/root/cgroup_test.go4
-rw-r--r--test/root/oom_score_adj_test.go4
-rw-r--r--test/runtimes/blacklist_test.go2
-rw-r--r--test/runtimes/runner.go2
-rw-r--r--test/syscalls/linux/BUILD10
-rw-r--r--test/syscalls/linux/aio.cc12
-rw-r--r--test/syscalls/linux/epoll.cc4
-rw-r--r--test/syscalls/linux/exec.cc10
-rw-r--r--test/syscalls/linux/exec_binary.cc164
-rw-r--r--test/syscalls/linux/file_base.h19
-rw-r--r--test/syscalls/linux/fork.cc17
-rw-r--r--test/syscalls/linux/getrandom.cc2
-rw-r--r--test/syscalls/linux/ip_socket_test_util.cc22
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h6
-rw-r--r--test/syscalls/linux/itimer.cc2
-rw-r--r--test/syscalls/linux/lseek.cc2
-rw-r--r--test/syscalls/linux/memfd.cc1
-rw-r--r--test/syscalls/linux/mkdir.cc20
-rw-r--r--test/syscalls/linux/mlock.cc4
-rw-r--r--test/syscalls/linux/mmap.cc10
-rw-r--r--test/syscalls/linux/open.cc22
-rw-r--r--test/syscalls/linux/pipe.cc2
-rw-r--r--test/syscalls/linux/poll.cc2
-rw-r--r--test/syscalls/linux/pread64.cc16
-rw-r--r--test/syscalls/linux/proc.cc33
-rw-r--r--test/syscalls/linux/proc_net.cc2
-rw-r--r--test/syscalls/linux/proc_net_unix.cc6
-rw-r--r--test/syscalls/linux/proc_pid_smaps.cc4
-rw-r--r--test/syscalls/linux/ptrace.cc33
-rw-r--r--test/syscalls/linux/pty.cc2
-rw-r--r--test/syscalls/linux/pwrite64.cc21
-rw-r--r--test/syscalls/linux/rseq/BUILD34
-rw-r--r--test/syscalls/linux/sendfile.cc51
-rw-r--r--test/syscalls/linux/sendfile_socket.cc107
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc219
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc49
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h6
-rw-r--r--test/syscalls/linux/socket_netlink_route.cc77
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.cc7
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.h4
-rw-r--r--test/syscalls/linux/socket_test_util.cc5
-rw-r--r--test/syscalls/linux/socket_unix.cc2
-rw-r--r--test/syscalls/linux/splice.cc1
-rw-r--r--test/syscalls/linux/tuntap.cc26
-rw-r--r--test/syscalls/linux/uidgid.cc12
-rw-r--r--test/syscalls/linux/utimes.cc18
-rw-r--r--test/syscalls/linux/write.cc10
-rw-r--r--test/syscalls/linux/xattr.cc8
-rw-r--r--test/util/capability_util.cc4
-rw-r--r--tools/bazeldefs/platforms.bzl7
-rw-r--r--tools/go_generics/defs.bzl1
-rw-r--r--tools/go_marshal/analysis/analysis_unsafe.go4
-rw-r--r--tools/go_marshal/defs.bzl3
-rw-r--r--tools/go_marshal/gomarshal/generator.go134
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go113
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go104
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go173
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go372
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go52
-rw-r--r--tools/go_marshal/gomarshal/util.go175
-rw-r--r--tools/go_marshal/marshal/marshal.go103
-rw-r--r--tools/go_marshal/primitive/BUILD18
-rw-r--r--tools/go_marshal/primitive/primitive.go175
-rw-r--r--tools/go_marshal/test/BUILD14
-rw-r--r--tools/go_marshal/test/benchmark_test.go42
-rw-r--r--tools/go_marshal/test/external/external.go8
-rw-r--r--tools/go_marshal/test/marshal_test.go515
-rw-r--r--tools/go_marshal/test/test.go64
-rw-r--r--tools/go_stateify/main.go2
-rwxr-xr-xtools/image_build.sh98
-rw-r--r--tools/images/BUILD9
-rw-r--r--tools/images/README.md42
-rwxr-xr-xtools/images/build.sh35
-rw-r--r--tools/images/defs.bzl136
-rwxr-xr-xtools/images/zone.sh17
-rw-r--r--tools/nogo.json59
386 files changed, 14147 insertions, 4632 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 4d2b4a72f..c40e03ad2 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -4,10 +4,10 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
# Load go bazel rules and gazelle.
http_archive(
name = "io_bazel_rules_go",
- sha256 = "94f90feaa65c9cdc840cd21f67d967870b5943d684966a47569da8073e42063d",
+ sha256 = "db2b2d35293f405430f553bc7a865a8749a8ef60c30287e90d2b278c32771afe",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.0/rules_go-v0.22.0.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/v0.22.0/rules_go-v0.22.0.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz",
],
)
@@ -25,7 +25,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_depe
go_rules_dependencies()
go_register_toolchains(
- go_version = "1.14",
+ go_version = "1.14.2",
nogo = "@//:nogo",
)
@@ -99,11 +99,11 @@ pip_install()
# See releases at https://releases.bazel.build/bazel-toolchains.html
http_archive(
name = "bazel_toolchains",
- sha256 = "b5a8039df7119d618402472f3adff8a1bd0ae9d5e253f53fcc4c47122e91a3d2",
- strip_prefix = "bazel-toolchains-2.1.1",
+ sha256 = "239a1a673861eabf988e9804f45da3b94da28d1aff05c373b013193c315d9d9e",
+ strip_prefix = "bazel-toolchains-3.0.1",
urls = [
- "https://github.com/bazelbuild/bazel-toolchains/releases/download/2.1.1/bazel-toolchains-2.1.1.tar.gz",
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/2.1.1.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz",
],
)
@@ -400,6 +400,20 @@ go_repository(
version = "v0.20.0",
)
+go_repository(
+ name = "org_uber_go_atomic",
+ importpath = "go.uber.org/atomic",
+ version = "v1.6.0",
+ sum = "h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=",
+)
+
+go_repository(
+ name = "org_uber_go_multierr",
+ importpath = "go.uber.org/multierr",
+ version = "v1.5.0",
+ sum = "h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A=",
+)
+
# BigQuery Dependencies for Benchmarks
go_repository(
name = "com_google_cloud_go",
diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py
index ca785a148..fc59cf505 100644
--- a/benchmarks/runner/__init__.py
+++ b/benchmarks/runner/__init__.py
@@ -19,6 +19,7 @@ import logging
import pkgutil
import pydoc
import re
+import subprocess
import sys
import types
from typing import List
@@ -125,9 +126,8 @@ def run_gcp(ctx, image_file: str, zone_file: str, internal: bool,
"""Runs all benchmarks on GCP instances."""
# Resolve all files.
- image = open(image_file).read().rstrip()
- zone = open(zone_file).read().rstrip()
-
+ image = subprocess.check_output([image_file]).rstrip()
+ zone = subprocess.check_output([zone_file]).rstrip()
key_file = harness.make_key()
producer = gcloud_producer.GCloudProducer(
diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py
index 194804527..e8289f6c5 100644
--- a/benchmarks/runner/commands.py
+++ b/benchmarks/runner/commands.py
@@ -101,15 +101,15 @@ class GCPCommand(RunCommand):
image_file = click.core.Option(
("--image_file",),
- help="The file containing the image for VMs.",
+ help="The binary that emits the GCP image.",
default=os.path.join(
- os.path.dirname(__file__), "../../tools/images/ubuntu1604.txt"),
+ os.path.dirname(__file__), "../../tools/images/ubuntu1604"),
)
zone_file = click.core.Option(
("--zone_file",),
- help="The file containing the GCP zone.",
+ help="The binary that emits the GCP zone.",
default=os.path.join(
- os.path.dirname(__file__), "../../tools/images/zone.txt"),
+ os.path.dirname(__file__), "../../tools/images/zone"),
)
internal = click.core.Option(
("--internal/--no-internal",),
diff --git a/kokoro/runtime_tests/go1.12.cfg b/kokoro/runtime_tests/go1.12.cfg
index fd4911e88..04bfe2868 100644
--- a/kokoro/runtime_tests/go1.12.cfg
+++ b/kokoro/runtime_tests/go1.12.cfg
@@ -1,4 +1,4 @@
-build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh"
+build_file: "github/github/scripts/runtime_tests.sh"
env_vars {
key: "RUNTIME_TEST_NAME"
diff --git a/kokoro/runtime_tests/java11.cfg b/kokoro/runtime_tests/java11.cfg
index 7f8611a08..c82855cd2 100644
--- a/kokoro/runtime_tests/java11.cfg
+++ b/kokoro/runtime_tests/java11.cfg
@@ -1,4 +1,4 @@
-build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh"
+build_file: "github/github/scripts/runtime_tests.sh"
env_vars {
key: "RUNTIME_TEST_NAME"
diff --git a/kokoro/runtime_tests/nodejs12.4.0.cfg b/kokoro/runtime_tests/nodejs12.4.0.cfg
index c67ad5567..5512db5df 100644
--- a/kokoro/runtime_tests/nodejs12.4.0.cfg
+++ b/kokoro/runtime_tests/nodejs12.4.0.cfg
@@ -1,4 +1,4 @@
-build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh"
+build_file: "github/github/scripts/runtime_tests.sh"
env_vars {
key: "RUNTIME_TEST_NAME"
diff --git a/kokoro/runtime_tests/php7.3.6.cfg b/kokoro/runtime_tests/php7.3.6.cfg
index f266c5e26..bc9ac92aa 100644
--- a/kokoro/runtime_tests/php7.3.6.cfg
+++ b/kokoro/runtime_tests/php7.3.6.cfg
@@ -1,4 +1,4 @@
-build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh"
+build_file: "github/github/scripts/runtime_tests.sh"
env_vars {
key: "RUNTIME_TEST_NAME"
diff --git a/kokoro/runtime_tests/python3.7.3.cfg b/kokoro/runtime_tests/python3.7.3.cfg
index 574add152..12eb13860 100644
--- a/kokoro/runtime_tests/python3.7.3.cfg
+++ b/kokoro/runtime_tests/python3.7.3.cfg
@@ -1,4 +1,4 @@
-build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh"
+build_file: "github/github/scripts/runtime_tests.sh"
env_vars {
key: "RUNTIME_TEST_NAME"
diff --git a/pkg/atomicbitops/atomicbitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s
index f0edd4de7..54c887ee5 100644
--- a/pkg/atomicbitops/atomicbitops_amd64.s
+++ b/pkg/atomicbitops/atomicbitops_amd64.s
@@ -16,28 +16,28 @@
#include "textflag.h"
-TEXT ·AndUint32(SB),NOSPLIT,$0-12
+TEXT ·AndUint32(SB),$0-12
MOVQ addr+0(FP), BP
MOVL val+8(FP), AX
LOCK
ANDL AX, 0(BP)
RET
-TEXT ·OrUint32(SB),NOSPLIT,$0-12
+TEXT ·OrUint32(SB),$0-12
MOVQ addr+0(FP), BP
MOVL val+8(FP), AX
LOCK
ORL AX, 0(BP)
RET
-TEXT ·XorUint32(SB),NOSPLIT,$0-12
+TEXT ·XorUint32(SB),$0-12
MOVQ addr+0(FP), BP
MOVL val+8(FP), AX
LOCK
XORL AX, 0(BP)
RET
-TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20
+TEXT ·CompareAndSwapUint32(SB),$0-20
MOVQ addr+0(FP), DI
MOVL old+8(FP), AX
MOVL new+12(FP), DX
@@ -46,28 +46,28 @@ TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20
MOVL AX, ret+16(FP)
RET
-TEXT ·AndUint64(SB),NOSPLIT,$0-16
+TEXT ·AndUint64(SB),$0-16
MOVQ addr+0(FP), BP
MOVQ val+8(FP), AX
LOCK
ANDQ AX, 0(BP)
RET
-TEXT ·OrUint64(SB),NOSPLIT,$0-16
+TEXT ·OrUint64(SB),$0-16
MOVQ addr+0(FP), BP
MOVQ val+8(FP), AX
LOCK
ORQ AX, 0(BP)
RET
-TEXT ·XorUint64(SB),NOSPLIT,$0-16
+TEXT ·XorUint64(SB),$0-16
MOVQ addr+0(FP), BP
MOVQ val+8(FP), AX
LOCK
XORQ AX, 0(BP)
RET
-TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32
+TEXT ·CompareAndSwapUint64(SB),$0-32
MOVQ addr+0(FP), DI
MOVQ old+8(FP), AX
MOVQ new+16(FP), DX
diff --git a/pkg/atomicbitops/atomicbitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s
index 644a6bca5..5c780851b 100644
--- a/pkg/atomicbitops/atomicbitops_arm64.s
+++ b/pkg/atomicbitops/atomicbitops_arm64.s
@@ -16,7 +16,7 @@
#include "textflag.h"
-TEXT ·AndUint32(SB),NOSPLIT,$0-12
+TEXT ·AndUint32(SB),$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -26,7 +26,7 @@ again:
CBNZ R3, again
RET
-TEXT ·OrUint32(SB),NOSPLIT,$0-12
+TEXT ·OrUint32(SB),$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -36,7 +36,7 @@ again:
CBNZ R3, again
RET
-TEXT ·XorUint32(SB),NOSPLIT,$0-12
+TEXT ·XorUint32(SB),$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -46,7 +46,7 @@ again:
CBNZ R3, again
RET
-TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20
+TEXT ·CompareAndSwapUint32(SB),$0-20
MOVD addr+0(FP), R0
MOVW old+8(FP), R1
MOVW new+12(FP), R2
@@ -60,7 +60,7 @@ done:
MOVW R3, prev+16(FP)
RET
-TEXT ·AndUint64(SB),NOSPLIT,$0-16
+TEXT ·AndUint64(SB),$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -70,7 +70,7 @@ again:
CBNZ R3, again
RET
-TEXT ·OrUint64(SB),NOSPLIT,$0-16
+TEXT ·OrUint64(SB),$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -80,7 +80,7 @@ again:
CBNZ R3, again
RET
-TEXT ·XorUint64(SB),NOSPLIT,$0-16
+TEXT ·XorUint64(SB),$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -90,7 +90,7 @@ again:
CBNZ R3, again
RET
-TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32
+TEXT ·CompareAndSwapUint64(SB),$0-32
MOVD addr+0(FP), R0
MOVD old+8(FP), R1
MOVD new+16(FP), R2
diff --git a/pkg/atomicbitops/atomicbitops_noasm.go b/pkg/atomicbitops/atomicbitops_noasm.go
index 4e9c27b98..3b2898256 100644
--- a/pkg/atomicbitops/atomicbitops_noasm.go
+++ b/pkg/atomicbitops/atomicbitops_noasm.go
@@ -20,7 +20,6 @@ import (
"sync/atomic"
)
-//go:nosplit
func AndUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -31,7 +30,6 @@ func AndUint32(addr *uint32, val uint32) {
}
}
-//go:nosplit
func OrUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -42,7 +40,6 @@ func OrUint32(addr *uint32, val uint32) {
}
}
-//go:nosplit
func XorUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -53,7 +50,6 @@ func XorUint32(addr *uint32, val uint32) {
}
}
-//go:nosplit
func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
for {
prev = atomic.LoadUint32(addr)
@@ -66,7 +62,6 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
}
}
-//go:nosplit
func AndUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -77,7 +72,6 @@ func AndUint64(addr *uint64, val uint64) {
}
}
-//go:nosplit
func OrUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -88,7 +82,6 @@ func OrUint64(addr *uint64, val uint64) {
}
}
-//go:nosplit
func XorUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -99,7 +92,6 @@ func XorUint64(addr *uint64, val uint64) {
}
}
-//go:nosplit
func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
for {
prev = atomic.LoadUint64(addr)
diff --git a/pkg/context/context.go b/pkg/context/context.go
index 23e009ef3..5319b6d8d 100644
--- a/pkg/context/context.go
+++ b/pkg/context/context.go
@@ -127,10 +127,6 @@ func (logContext) Value(key interface{}) interface{} {
var bgContext = &logContext{Logger: log.Log()}
// Background returns an empty context using the default logger.
-//
-// Users should be wary of using a Background context. Please tag any use with
-// FIXME(b/38173783) and a note to remove this use.
-//
// Generally, one should use the Task as their context when available, or avoid
// having to use a context in places where a Task is unavailable.
//
diff --git a/pkg/cpuid/cpuid_parse_x86_test.go b/pkg/cpuid/cpuid_parse_x86_test.go
index d48418e69..c9bd40e1b 100644
--- a/pkg/cpuid/cpuid_parse_x86_test.go
+++ b/pkg/cpuid/cpuid_parse_x86_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package cpuid
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
index 9abf6914d..562f8f405 100644
--- a/pkg/cpuid/cpuid_x86.go
+++ b/pkg/cpuid/cpuid_x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package cpuid
diff --git a/pkg/cpuid/cpuid_x86_test.go b/pkg/cpuid/cpuid_x86_test.go
index 0fe20c213..bacf345c8 100644
--- a/pkg/cpuid/cpuid_x86_test.go
+++ b/pkg/cpuid/cpuid_x86_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package cpuid
diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go
index 7f41b4a27..43750360b 100644
--- a/pkg/eventchannel/event_test.go
+++ b/pkg/eventchannel/event_test.go
@@ -78,7 +78,7 @@ func TestMultiEmitter(t *testing.T) {
for _, name := range names {
m := testMessage{name: name}
if _, err := me.Emit(m); err != nil {
- t.Fatal("me.Emit(%v) failed: %v", m, err)
+ t.Fatalf("me.Emit(%v) failed: %v", m, err)
}
}
@@ -96,7 +96,7 @@ func TestMultiEmitter(t *testing.T) {
// Close multiEmitter.
if err := me.Close(); err != nil {
- t.Fatal("me.Close() failed: %v", err)
+ t.Fatalf("me.Close() failed: %v", err)
}
// All testEmitters should be closed.
diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go
index 8f93e4d6d..0d07da3b1 100644
--- a/pkg/ilist/list.go
+++ b/pkg/ilist/list.go
@@ -86,12 +86,21 @@ func (l *List) Back() Element {
return l.tail
}
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+func (l *List) Len() (count int) {
+ for e := l.Front(); e != nil; e = e.Next() {
+ count++
+ }
+ return count
+}
+
// PushFront inserts the element e at the front of list l.
func (l *List) PushFront(e Element) {
linker := ElementMapper{}.linkerFor(e)
linker.SetNext(l.head)
linker.SetPrev(nil)
-
if l.head != nil {
ElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
@@ -106,7 +115,6 @@ func (l *List) PushBack(e Element) {
linker := ElementMapper{}.linkerFor(e)
linker.SetNext(nil)
linker.SetPrev(l.tail)
-
if l.tail != nil {
ElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
@@ -127,7 +135,6 @@ func (l *List) PushBackList(m *List) {
l.tail = m.tail
}
-
m.head = nil
m.tail = nil
}
diff --git a/pkg/log/glog.go b/pkg/log/glog.go
index b4f7bb5a4..f57c4427b 100644
--- a/pkg/log/glog.go
+++ b/pkg/log/glog.go
@@ -25,7 +25,7 @@ import (
// GoogleEmitter is a wrapper that emits logs in a format compatible with
// package github.com/golang/glog.
type GoogleEmitter struct {
- Writer
+ *Writer
}
// pid is used for the threadid component of the header.
@@ -46,7 +46,7 @@ var pid = os.Getpid()
// line The line number
// msg The user-supplied message
//
-func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) {
+func (g GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) {
// Log level.
prefix := byte('?')
switch level {
@@ -81,5 +81,5 @@ func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format
message := fmt.Sprintf(format, args...)
// Emit the formatted result.
- fmt.Fprintf(&g.Writer, "%c%02d%02d %02d:%02d:%02d.%06d % 7d %s:%d] %s\n", prefix, int(month), day, hour, minute, second, microsecond, pid, file, line, message)
+ fmt.Fprintf(g.Writer, "%c%02d%02d %02d:%02d:%02d.%06d % 7d %s:%d] %s\n", prefix, int(month), day, hour, minute, second, microsecond, pid, file, line, message)
}
diff --git a/pkg/log/json.go b/pkg/log/json.go
index 0943db1cc..bdf9d691e 100644
--- a/pkg/log/json.go
+++ b/pkg/log/json.go
@@ -58,7 +58,7 @@ func (lv *Level) UnmarshalJSON(b []byte) error {
// JSONEmitter logs messages in json format.
type JSONEmitter struct {
- Writer
+ *Writer
}
// Emit implements Emitter.Emit.
diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go
index 6c6fc8b6f..5883e95e1 100644
--- a/pkg/log/json_k8s.go
+++ b/pkg/log/json_k8s.go
@@ -29,11 +29,11 @@ type k8sJSONLog struct {
// K8sJSONEmitter logs messages in json format that is compatible with
// Kubernetes fluent configuration.
type K8sJSONEmitter struct {
- Writer
+ *Writer
}
// Emit implements Emitter.Emit.
-func (e *K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := k8sJSONLog{
Log: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/log.go b/pkg/log/log.go
index a794da1aa..37e0605ad 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -374,5 +374,5 @@ func CopyStandardLogTo(l Level) error {
func init() {
// Store the initial value for the log.
- log.Store(&BasicLogger{Level: Info, Emitter: &GoogleEmitter{Writer{Next: os.Stderr}}})
+ log.Store(&BasicLogger{Level: Info, Emitter: GoogleEmitter{&Writer{Next: os.Stderr}}})
}
diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go
index 402cc29ae..9ff18559b 100644
--- a/pkg/log/log_test.go
+++ b/pkg/log/log_test.go
@@ -52,7 +52,7 @@ func TestDropMessages(t *testing.T) {
t.Fatalf("Write should have failed")
}
- fmt.Printf("writer: %+v\n", w)
+ fmt.Printf("writer: %#v\n", &w)
tw.fail = false
if _, err := w.Write([]byte("line 2\n")); err != nil {
@@ -76,7 +76,7 @@ func TestDropMessages(t *testing.T) {
func TestCaller(t *testing.T) {
tw := &testWriter{}
- e := &GoogleEmitter{Writer: Writer{Next: tw}}
+ e := GoogleEmitter{Writer: &Writer{Next: tw}}
bl := &BasicLogger{
Emitter: e,
Level: Debug,
@@ -94,7 +94,7 @@ func BenchmarkGoogleLogging(b *testing.B) {
tw := &testWriter{
limit: 1, // Only record one message.
}
- e := &GoogleEmitter{Writer: Writer{Next: tw}}
+ e := GoogleEmitter{Writer: &Writer{Next: tw}}
bl := &BasicLogger{
Emitter: e,
Level: Debug,
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index a6f493b82..71e944c30 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -174,7 +174,7 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
// our sendRecv function to use that functionality. Otherwise,
// we stick to sendRecvLegacy.
rversion := Rversion{}
- err := c.sendRecvLegacy(&Tversion{
+ _, err := c.sendRecvLegacy(&Tversion{
Version: versionString(requested),
MSize: messageSize,
}, &rversion)
@@ -219,11 +219,11 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
c.sendRecv = c.sendRecvChannel
} else {
// Channel setup failed; fallback.
- c.sendRecv = c.sendRecvLegacy
+ c.sendRecv = c.sendRecvLegacySyscallErr
}
} else {
// No channels available: use the legacy mechanism.
- c.sendRecv = c.sendRecvLegacy
+ c.sendRecv = c.sendRecvLegacySyscallErr
}
// Ensure that the socket and channels are closed when the socket is shut
@@ -305,7 +305,7 @@ func (c *Client) openChannel(id int) error {
)
// Open the data channel.
- if err := c.sendRecvLegacy(&Tchannel{
+ if _, err := c.sendRecvLegacy(&Tchannel{
ID: uint32(id),
Control: 0,
}, &rchannel0); err != nil {
@@ -319,7 +319,7 @@ func (c *Client) openChannel(id int) error {
defer rchannel0.FilePayload().Close()
// Open the channel for file descriptors.
- if err := c.sendRecvLegacy(&Tchannel{
+ if _, err := c.sendRecvLegacy(&Tchannel{
ID: uint32(id),
Control: 1,
}, &rchannel1); err != nil {
@@ -431,13 +431,28 @@ func (c *Client) waitAndRecv(done chan error) error {
}
}
+// sendRecvLegacySyscallErr is a wrapper for sendRecvLegacy that converts all
+// non-syscall errors (those where no response was received) to EIO.
+func (c *Client) sendRecvLegacySyscallErr(t message, r message) error {
+ received, err := c.sendRecvLegacy(t, r)
+ if !received {
+ log.Warningf("p9.Client.sendRecvLegacySyscallErr: %v", err)
+ return syscall.EIO
+ }
+ return err
+}
+
// sendRecvLegacy performs a roundtrip message exchange.
//
+// sendRecvLegacy returns true if a message was received. This allows us to
+// differentiate between failed receives and successful receives where the
+// response was an error message.
+//
// This is called by internal functions.
-func (c *Client) sendRecvLegacy(t message, r message) error {
+func (c *Client) sendRecvLegacy(t message, r message) (bool, error) {
tag, ok := c.tagPool.Get()
if !ok {
- return ErrOutOfTags
+ return false, ErrOutOfTags
}
defer c.tagPool.Put(tag)
@@ -457,12 +472,12 @@ func (c *Client) sendRecvLegacy(t message, r message) error {
err := send(c.socket, Tag(tag), t)
c.sendMu.Unlock()
if err != nil {
- return err
+ return false, err
}
// Co-ordinate with other receivers.
if err := c.waitAndRecv(resp.done); err != nil {
- return err
+ return false, err
}
// Is it an error message?
@@ -470,14 +485,14 @@ func (c *Client) sendRecvLegacy(t message, r message) error {
// For convenience, we transform these directly
// into errors. Handlers need not handle this case.
if rlerr, ok := resp.r.(*Rlerror); ok {
- return syscall.Errno(rlerr.Error)
+ return true, syscall.Errno(rlerr.Error)
}
// At this point, we know it matches.
//
// Per recv call above, we will only allow a type
// match (and give our r) or an instance of Rlerror.
- return nil
+ return true, nil
}
// sendRecvChannel uses channels to send a message.
@@ -486,7 +501,7 @@ func (c *Client) sendRecvChannel(t message, r message) error {
c.channelsMu.Lock()
if len(c.availableChannels) == 0 {
c.channelsMu.Unlock()
- return c.sendRecvLegacy(t, r)
+ return c.sendRecvLegacySyscallErr(t, r)
}
idx := len(c.availableChannels) - 1
ch := c.availableChannels[idx]
@@ -526,7 +541,11 @@ func (c *Client) sendRecvChannel(t message, r message) error {
}
// Parse the server's response.
- _, retErr := ch.recv(r, rsz)
+ resp, retErr := ch.recv(r, rsz)
+ if resp == nil {
+ log.Warningf("p9.Client.sendRecvChannel: p9.channel.recv: %v", retErr)
+ retErr = syscall.EIO
+ }
// Release the channel.
c.channelsMu.Lock()
diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go
index 29a0afadf..c757583e0 100644
--- a/pkg/p9/client_test.go
+++ b/pkg/p9/client_test.go
@@ -96,7 +96,12 @@ func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) e
}
func BenchmarkSendRecvLegacy(b *testing.B) {
- benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy })
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error {
+ return func(t message, r message) error {
+ _, err := c.sendRecvLegacy(t, r)
+ return err
+ }
+ })
}
func BenchmarkSendRecvChannel(b *testing.B) {
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index d4ffbc8e3..cab35896f 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -97,12 +97,12 @@ type File interface {
// free to ignore the hint entirely (i.e. the value returned may be larger
// than size). All size checking is done independently at the syscall layer.
//
- // TODO(b/127675828): Determine concurrency guarantees once implemented.
+ // On the server, GetXattr has a read concurrency guarantee.
GetXattr(name string, size uint64) (string, error)
// SetXattr sets extended attributes on this node.
//
- // TODO(b/127675828): Determine concurrency guarantees once implemented.
+ // On the server, SetXattr has a write concurrency guarantee.
SetXattr(name, value string, flags uint32) error
// ListXattr lists the names of the extended attributes on this node.
@@ -113,12 +113,12 @@ type File interface {
// free to ignore the hint entirely (i.e. the value returned may be larger
// than size). All size checking is done independently at the syscall layer.
//
- // TODO(b/148303075): Determine concurrency guarantees once implemented.
+ // On the server, ListXattr has a read concurrency guarantee.
ListXattr(size uint64) (map[string]struct{}, error)
// RemoveXattr removes extended attributes on this node.
//
- // TODO(b/148303075): Determine concurrency guarantees once implemented.
+ // On the server, RemoveXattr has a write concurrency guarantee.
RemoveXattr(name string) error
// Allocate allows the caller to directly manipulate the allocated disk space
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 2ac45eb80..1db5797dd 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -48,6 +48,8 @@ func ExtractErrno(err error) syscall.Errno {
return ExtractErrno(e.Err)
case *os.SyscallError:
return ExtractErrno(e.Err)
+ case *os.LinkError:
+ return ExtractErrno(e.Err)
}
// Default case.
@@ -920,8 +922,15 @@ func (t *Tgetxattr) handle(cs *connState) message {
}
defer ref.DecRef()
- val, err := ref.file.GetXattr(t.Name, t.Size)
- if err != nil {
+ var val string
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow getxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ val, err = ref.file.GetXattr(t.Name, t.Size)
+ return err
+ }); err != nil {
return newErr(err)
}
return &Rgetxattr{Value: val}
@@ -935,7 +944,13 @@ func (t *Tsetxattr) handle(cs *connState) message {
}
defer ref.DecRef()
- if err := ref.file.SetXattr(t.Name, t.Value, t.Flags); err != nil {
+ if err := ref.safelyWrite(func() error {
+ // Don't allow setxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ return ref.file.SetXattr(t.Name, t.Value, t.Flags)
+ }); err != nil {
return newErr(err)
}
return &Rsetxattr{}
@@ -949,10 +964,18 @@ func (t *Tlistxattr) handle(cs *connState) message {
}
defer ref.DecRef()
- xattrs, err := ref.file.ListXattr(t.Size)
- if err != nil {
+ var xattrs map[string]struct{}
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow listxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ xattrs, err = ref.file.ListXattr(t.Size)
+ return err
+ }); err != nil {
return newErr(err)
}
+
xattrList := make([]string, 0, len(xattrs))
for x := range xattrs {
xattrList = append(xattrList, x)
@@ -968,7 +991,13 @@ func (t *Tremovexattr) handle(cs *connState) message {
}
defer ref.DecRef()
- if err := ref.file.RemoveXattr(t.Name); err != nil {
+ if err := ref.safelyWrite(func() error {
+ // Don't allow removexattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ return ref.file.RemoveXattr(t.Name)
+ }); err != nil {
return newErr(err)
}
return &Rremovexattr{}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index 3863ad1f5..57b89ad7d 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -1926,19 +1926,17 @@ func (r *Rreaddir) decode(b *buffer) {
// encode implements encoder.encode.
func (r *Rreaddir) encode(b *buffer) {
entriesBuf := buffer{}
+ payloadSize := 0
for _, d := range r.Entries {
d.encode(&entriesBuf)
- if len(entriesBuf.data) >= int(r.Count) {
+ if len(entriesBuf.data) > int(r.Count) {
break
}
+ payloadSize = len(entriesBuf.data)
}
- if len(entriesBuf.data) < int(r.Count) {
- r.Count = uint32(len(entriesBuf.data))
- r.payload = entriesBuf.data
- } else {
- r.payload = entriesBuf.data[:r.Count]
- }
- b.Write32(uint32(r.Count))
+ r.Count = uint32(payloadSize)
+ r.payload = entriesBuf.data[:payloadSize]
+ b.Write32(r.Count)
}
// Type implements message.Type.
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
index c20324404..7facc9f5e 100644
--- a/pkg/p9/messages_test.go
+++ b/pkg/p9/messages_test.go
@@ -216,7 +216,7 @@ func TestEncodeDecode(t *testing.T) {
},
&Rreaddir{
// Count must be sufficient to encode a dirent.
- Count: 0x18,
+ Count: 0x1a,
Entries: []Dirent{{QID: QID{Type: 2}}},
},
&Tfsync{
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
index a0d274f3b..38038abdf 100644
--- a/pkg/p9/transport_flipcall.go
+++ b/pkg/p9/transport_flipcall.go
@@ -236,7 +236,7 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) {
// Convert errors appropriately; see above.
if rlerr, ok := r.(*Rlerror); ok {
- return nil, syscall.Errno(rlerr.Error)
+ return r, syscall.Errno(rlerr.Error)
}
return r, nil
diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s
index 129691d68..00b46c18f 100644
--- a/pkg/safecopy/memcpy_amd64.s
+++ b/pkg/safecopy/memcpy_amd64.s
@@ -55,15 +55,9 @@ TEXT ·memcpy(SB), NOSPLIT, $0-36
MOVQ from+8(FP), SI
MOVQ n+16(FP), BX
- // REP instructions have a high startup cost, so we handle small sizes
- // with some straightline code. The REP MOVSQ instruction is really fast
- // for large sizes. The cutover is approximately 2K.
tail:
- // move_129through256 or smaller work whether or not the source and the
- // destination memory regions overlap because they load all data into
- // registers before writing it back. move_256through2048 on the other
- // hand can be used only when the memory regions don't overlap or the copy
- // direction is forward.
+ // BSR+branch table make almost all memmove/memclr benchmarks worse. Not
+ // worth doing.
TESTQ BX, BX
JEQ move_0
CMPQ BX, $2
@@ -83,31 +77,45 @@ tail:
JBE move_65through128
CMPQ BX, $256
JBE move_129through256
- // TODO: use branch table and BSR to make this just a single dispatch
-/*
- * forward copy loop
- */
- CMPQ BX, $2048
- JLS move_256through2048
-
- // Check alignment
- MOVL SI, AX
- ORL DI, AX
- TESTL $7, AX
- JEQ fwdBy8
-
- // Do 1 byte at a time
- MOVQ BX, CX
- REP; MOVSB
- RET
-
-fwdBy8:
- // Do 8 bytes at a time
- MOVQ BX, CX
- SHRQ $3, CX
- ANDQ $7, BX
- REP; MOVSQ
+move_257plus:
+ SUBQ $256, BX
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU 16(SI), X1
+ MOVOU X1, 16(DI)
+ MOVOU 32(SI), X2
+ MOVOU X2, 32(DI)
+ MOVOU 48(SI), X3
+ MOVOU X3, 48(DI)
+ MOVOU 64(SI), X4
+ MOVOU X4, 64(DI)
+ MOVOU 80(SI), X5
+ MOVOU X5, 80(DI)
+ MOVOU 96(SI), X6
+ MOVOU X6, 96(DI)
+ MOVOU 112(SI), X7
+ MOVOU X7, 112(DI)
+ MOVOU 128(SI), X8
+ MOVOU X8, 128(DI)
+ MOVOU 144(SI), X9
+ MOVOU X9, 144(DI)
+ MOVOU 160(SI), X10
+ MOVOU X10, 160(DI)
+ MOVOU 176(SI), X11
+ MOVOU X11, 176(DI)
+ MOVOU 192(SI), X12
+ MOVOU X12, 192(DI)
+ MOVOU 208(SI), X13
+ MOVOU X13, 208(DI)
+ MOVOU 224(SI), X14
+ MOVOU X14, 224(DI)
+ MOVOU 240(SI), X15
+ MOVOU X15, 240(DI)
+ CMPQ BX, $256
+ LEAQ 256(SI), SI
+ LEAQ 256(DI), DI
+ JGE move_257plus
JMP tail
move_1or2:
@@ -209,42 +217,3 @@ move_129through256:
MOVOU -16(SI)(BX*1), X15
MOVOU X15, -16(DI)(BX*1)
RET
-move_256through2048:
- SUBQ $256, BX
- MOVOU (SI), X0
- MOVOU X0, (DI)
- MOVOU 16(SI), X1
- MOVOU X1, 16(DI)
- MOVOU 32(SI), X2
- MOVOU X2, 32(DI)
- MOVOU 48(SI), X3
- MOVOU X3, 48(DI)
- MOVOU 64(SI), X4
- MOVOU X4, 64(DI)
- MOVOU 80(SI), X5
- MOVOU X5, 80(DI)
- MOVOU 96(SI), X6
- MOVOU X6, 96(DI)
- MOVOU 112(SI), X7
- MOVOU X7, 112(DI)
- MOVOU 128(SI), X8
- MOVOU X8, 128(DI)
- MOVOU 144(SI), X9
- MOVOU X9, 144(DI)
- MOVOU 160(SI), X10
- MOVOU X10, 160(DI)
- MOVOU 176(SI), X11
- MOVOU X11, 176(DI)
- MOVOU 192(SI), X12
- MOVOU X12, 192(DI)
- MOVOU 208(SI), X13
- MOVOU X13, 208(DI)
- MOVOU 224(SI), X14
- MOVOU X14, 224(DI)
- MOVOU 240(SI), X15
- MOVOU X15, 240(DI)
- CMPQ BX, $256
- LEAQ 256(SI), SI
- LEAQ 256(DI), DI
- JGE move_256through2048
- JMP tail
diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go
index 521f1a82d..2fb7e5809 100644
--- a/pkg/safecopy/safecopy.go
+++ b/pkg/safecopy/safecopy.go
@@ -127,10 +127,10 @@ func initializeAddresses() {
func init() {
initializeAddresses()
- if err := ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler, 0); err != nil {
+ if err := ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err))
}
- if err := ReplaceSignalHandler(syscall.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler, 0); err != nil {
+ if err := ReplaceSignalHandler(syscall.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err))
}
syserror.AddErrorUnwrapper(func(e error) (syscall.Errno, bool) {
diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
index b15b920fe..41dd567f3 100644
--- a/pkg/safecopy/safecopy_unsafe.go
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -324,13 +324,11 @@ func errorFromFaultSignal(addr uintptr, sig int32) error {
//
// It stores the value of the previously set handler in previous.
//
-// The extraMask parameter is OR'ed into the existing signal handler mask.
-//
// This function will be called on initialization in order to install safecopy
// handlers for appropriate signals. These handlers will call the previous
// handler however, and if this is function is being used externally then the
// same courtesy is expected.
-func ReplaceSignalHandler(sig syscall.Signal, handler uintptr, previous *uintptr, extraMask uint64) error {
+func ReplaceSignalHandler(sig syscall.Signal, handler uintptr, previous *uintptr) error {
var sa struct {
handler uintptr
flags uint64
@@ -350,10 +348,10 @@ func ReplaceSignalHandler(sig syscall.Signal, handler uintptr, previous *uintptr
if sa.handler == 0 {
return fmt.Errorf("previous handler for signal %x isn't set", sig)
}
+
*previous = sa.handler
// Install our own handler.
- sa.mask |= extraMask
sa.handler = handler
if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 {
return e
diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go
index f19a005f3..97b16c158 100644
--- a/pkg/segment/test/segment_test.go
+++ b/pkg/segment/test/segment_test.go
@@ -63,7 +63,7 @@ func checkSet(s *Set, expectedSegments int) error {
return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
}
if got, want := seg.Value(), seg.Start()+valueOffset; got != want {
- return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start, got, want)
+ return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start(), got, want)
}
prev = next
havePrev = true
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index 1d11cc472..a903d031c 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -88,6 +88,9 @@ type Context interface {
// SyscallNo returns the syscall number.
SyscallNo() uintptr
+ // SyscallSaveOrig saves the original register value.
+ SyscallSaveOrig()
+
// SyscallArgs returns the syscall arguments in an array.
SyscallArgs() SyscallArguments
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index e35c9214a..aa31169e0 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 i386
+// +build amd64 386
package arch
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index 88b40a9d1..7fc4c0473 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 i386
+// +build amd64 386
package arch
diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go
index 04ac283c6..3edf40764 100644
--- a/pkg/sentry/arch/arch_x86_impl.go
+++ b/pkg/sentry/arch/arch_x86_impl.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 i386
+// +build amd64 386
package arch
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
index 1a6056171..e58f055c7 100644
--- a/pkg/sentry/arch/signal_stack.go
+++ b/pkg/sentry/arch/signal_stack.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64 arm64
+// +build 386 amd64 arm64
package arch
diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go
index 09bceabc9..1108fa0bd 100644
--- a/pkg/sentry/arch/stack.go
+++ b/pkg/sentry/arch/stack.go
@@ -97,7 +97,6 @@ func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) {
if c < 0 {
return 0, fmt.Errorf("bad binary.Size for %T", v)
}
- // TODO(b/38173783): Use a real context.Context.
n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{})
if err != nil || c != n {
return 0, err
@@ -121,11 +120,9 @@ func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) {
var err error
if isVaddr {
value := s.Arch.Native(uintptr(0))
- // TODO(b/38173783): Use a real context.Context.
n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{})
*vaddr = usermem.Addr(s.Arch.Value(value))
} else {
- // TODO(b/38173783): Use a real context.Context.
n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{})
}
if err != nil {
diff --git a/pkg/sentry/arch/syscalls_amd64.go b/pkg/sentry/arch/syscalls_amd64.go
index 8b4f23007..3859f41ee 100644
--- a/pkg/sentry/arch/syscalls_amd64.go
+++ b/pkg/sentry/arch/syscalls_amd64.go
@@ -18,6 +18,13 @@ package arch
const restartSyscallNr = uintptr(219)
+// SyscallSaveOrig saves the value of the register that is clobbered in the
+// syscall handler (doSyscall()).
+//
+// Noop on x86.
+func (c *context64) SyscallSaveOrig() {
+}
+
// SyscallNo returns the syscall number according to the 64-bit convention.
func (c *context64) SyscallNo() uintptr {
return uintptr(c.Regs.Orig_rax)
diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go
index dc13b6124..92d062513 100644
--- a/pkg/sentry/arch/syscalls_arm64.go
+++ b/pkg/sentry/arch/syscalls_arm64.go
@@ -18,6 +18,17 @@ package arch
const restartSyscallNr = uintptr(128)
+// SyscallSaveOrig saves the value of register R0, which is clobbered in the
+// syscall handler (doSyscall()).
+//
+// In Linux, at the entry of the syscall handler (el0_svc_common()), the value
+// of R0 is saved to pt_regs.orig_x0 in kernel code. Currently, orig_x0 is not
+// accessible to user-space applications, so we have to do the same operation
+// in the sentry code to save the R0 value into the app context.
+func (c *context64) SyscallSaveOrig() {
+ c.OrigR0 = c.Regs.Regs[0]
+}
+
// SyscallNo returns the syscall number according to the 64-bit convention.
func (c *context64) SyscallNo() uintptr {
return uintptr(c.Regs.Regs[8])
@@ -40,7 +51,7 @@ func (c *context64) SyscallNo() uintptr {
// R30: the link register.
func (c *context64) SyscallArgs() SyscallArguments {
return SyscallArguments{
- SyscallArgument{Value: uintptr(c.Regs.Regs[0])},
+ SyscallArgument{Value: uintptr(c.OrigR0)},
SyscallArgument{Value: uintptr(c.Regs.Regs[1])},
SyscallArgument{Value: uintptr(c.Regs.Regs[2])},
SyscallArgument{Value: uintptr(c.Regs.Regs[3])},
diff --git a/pkg/sentry/contexttest/contexttest.go b/pkg/sentry/contexttest/contexttest.go
index 031fc64ec..8e5658c7a 100644
--- a/pkg/sentry/contexttest/contexttest.go
+++ b/pkg/sentry/contexttest/contexttest.go
@@ -97,7 +97,7 @@ type hostClock struct {
}
// Now implements ktime.Clock.Now.
-func (hostClock) Now() ktime.Time {
+func (*hostClock) Now() ktime.Time {
return ktime.FromNanoseconds(time.Now().UnixNano())
}
@@ -127,7 +127,7 @@ func (t *TestContext) Value(key interface{}) interface{} {
case uniqueid.CtxInotifyCookie:
return atomic.AddUint32(&lastInotifyCookie, 1)
case ktime.CtxRealtimeClock:
- return hostClock{}
+ return &hostClock{}
default:
if val, ok := t.otherValues[key]; ok {
return val
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index 0266a5287..65be12175 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -312,9 +312,9 @@ func (d *Dirent) SyncAll(ctx context.Context) {
// There is nothing to sync for a read-only filesystem.
if !d.Inode.MountSource.Flags.ReadOnly {
- // FIXME(b/34856369): This should be a mount traversal, not a
- // Dirent traversal, because some Inodes that need to be synced
- // may no longer be reachable by name (after sys_unlink).
+ // NOTE(b/34856369): This should be a mount traversal, not a Dirent
+ // traversal, because some Inodes that need to be synced may no longer
+ // be reachable by name (after sys_unlink).
//
// Write out metadata, dirty page cached pages, and sync disk/remote
// caches.
diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go
index 5aff0cc95..a0082ecca 100644
--- a/pkg/sentry/fs/fdpipe/pipe_test.go
+++ b/pkg/sentry/fs/fdpipe/pipe_test.go
@@ -119,7 +119,7 @@ func TestNewPipe(t *testing.T) {
continue
}
if flags := p.flags; test.flags != flags {
- t.Errorf("%s: got file flags %s, want %s", test.desc, flags, test.flags)
+ t.Errorf("%s: got file flags %v, want %v", test.desc, flags, test.flags)
continue
}
if len(test.readAheadBuffer) != len(p.readAheadBuffer) {
@@ -136,7 +136,7 @@ func TestNewPipe(t *testing.T) {
continue
}
if !fdnotifier.HasFD(int32(f.FD())) {
- t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD)
+ t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD())
}
}
}
diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go
index ff96b28ba..edd6576aa 100644
--- a/pkg/sentry/fs/gofer/file_state.go
+++ b/pkg/sentry/fs/gofer/file_state.go
@@ -34,7 +34,6 @@ func (f *fileOperations) afterLoad() {
flags := f.flags
flags.Truncate = false
- // TODO(b/38173783): Context is not plumbed to save/restore.
f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), flags, f.inodeOperations.cachingInodeOps)
if err != nil {
return fmt.Errorf("failed to re-open handle: %v", err)
diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go
index 9f7c3e89f..fc14249be 100644
--- a/pkg/sentry/fs/gofer/handles.go
+++ b/pkg/sentry/fs/gofer/handles.go
@@ -57,7 +57,6 @@ func (h *handles) DecRef() {
}
}
}
- // FIXME(b/38173783): Context is not plumbed here.
if err := h.File.close(context.Background()); err != nil {
log.Warningf("error closing p9 file: %v", err)
}
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 1c934981b..a016c896e 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -273,7 +273,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle
// operations on the old will see the new data. Then, make the new handle take
// ownereship of the old FD and mark the old readHandle to not close the FD
// when done.
- if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), 0); err != nil {
+ if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), syscall.O_CLOEXEC); err != nil {
return err
}
@@ -710,13 +710,10 @@ func init() {
}
// AddLink implements InodeOperations.AddLink, but is currently a noop.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (*inodeOperations) AddLink() {}
// DropLink implements InodeOperations.DropLink, but is currently a noop.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (*inodeOperations) DropLink() {}
// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go
index 238f7804c..a3402e343 100644
--- a/pkg/sentry/fs/gofer/inode_state.go
+++ b/pkg/sentry/fs/gofer/inode_state.go
@@ -123,7 +123,6 @@ func (i *inodeFileState) afterLoad() {
// beforeSave.
return fmt.Errorf("failed to find path for inode number %d. Device %s contains %s", i.sattr.InodeID, i.s.connID, fs.InodeMappings(i.s.inodeMappings))
}
- // TODO(b/38173783): Context is not plumbed to save/restore.
ctx := &dummyClockContext{context.Background()}
_, i.file, err = i.s.attach.walk(ctx, splitAbsolutePath(name))
diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go
index 111da59f9..2d398b753 100644
--- a/pkg/sentry/fs/gofer/session_state.go
+++ b/pkg/sentry/fs/gofer/session_state.go
@@ -104,7 +104,6 @@ func (s *session) afterLoad() {
// If private unix sockets are enabled, create and fill the session's endpoint
// maps.
if opts.privateunixsocket {
- // TODO(b/38173783): Context is not plumbed to save/restore.
ctx := &dummyClockContext{context.Background()}
if err = s.restoreEndpointMaps(ctx); err != nil {
diff --git a/pkg/sentry/fs/gofer/util.go b/pkg/sentry/fs/gofer/util.go
index 2d8d3a2ea..47a6c69bf 100644
--- a/pkg/sentry/fs/gofer/util.go
+++ b/pkg/sentry/fs/gofer/util.go
@@ -20,17 +20,29 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
)
func utimes(ctx context.Context, file contextFile, ts fs.TimeSpec) error {
if ts.ATimeOmit && ts.MTimeOmit {
return nil
}
+
+ // Replace requests to use the "system time" with the current time to
+ // ensure that timestamps remain consistent with the remote
+ // filesystem.
+ now := ktime.NowFromContext(ctx)
+ if ts.ATimeSetSystemTime {
+ ts.ATime = now
+ }
+ if ts.MTimeSetSystemTime {
+ ts.MTime = now
+ }
mask := p9.SetAttrMask{
ATime: !ts.ATimeOmit,
- ATimeNotSystemTime: !ts.ATimeSetSystemTime,
+ ATimeNotSystemTime: true,
MTime: !ts.MTimeOmit,
- MTimeNotSystemTime: !ts.MTimeSetSystemTime,
+ MTimeNotSystemTime: true,
}
as, ans := ts.ATime.Unix()
ms, mns := ts.MTime.Unix()
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 1da3c0a17..62f1246aa 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -397,15 +397,12 @@ func (i *inodeOperations) StatFS(context.Context) (fs.Info, error) {
}
// AddLink implements fs.InodeOperations.AddLink.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) AddLink() {}
// DropLink implements fs.InodeOperations.DropLink.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) DropLink() {}
// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
// readdirAll returns all of the directory entries in i.
diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go
index eb4afe520..affdbcacb 100644
--- a/pkg/sentry/fs/host/socket_test.go
+++ b/pkg/sentry/fs/host/socket_test.go
@@ -199,14 +199,14 @@ func TestListen(t *testing.T) {
}
func TestPasscred(t *testing.T) {
- e := ConnectedEndpoint{}
+ e := &ConnectedEndpoint{}
if got, want := e.Passcred(), false; got != want {
t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want)
}
}
func TestGetLocalAddress(t *testing.T) {
- e := ConnectedEndpoint{path: "foo"}
+ e := &ConnectedEndpoint{path: "foo"}
want := tcpip.FullAddress{Addr: tcpip.Address("foo")}
if got, err := e.GetLocalAddress(); err != nil || got != want {
t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil)
@@ -214,7 +214,7 @@ func TestGetLocalAddress(t *testing.T) {
}
func TestQueuedSize(t *testing.T) {
- e := ConnectedEndpoint{}
+ e := &ConnectedEndpoint{}
tests := []struct {
name string
f func() int64
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index 55fb71c16..a34fbc946 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -102,7 +102,6 @@ func (i *Inode) DecRef() {
// destroy releases the Inode and releases the msrc reference taken.
func (i *Inode) destroy() {
- // FIXME(b/38173783): Context is not plumbed here.
ctx := context.Background()
if err := i.WriteOut(ctx); err != nil {
// FIXME(b/65209558): Mark as warning again once noatime is
@@ -397,8 +396,6 @@ func (i *Inode) Getlink(ctx context.Context) (*Dirent, error) {
// AddLink calls i.InodeOperations.AddLink.
func (i *Inode) AddLink() {
if i.overlay != nil {
- // FIXME(b/63117438): Remove this from InodeOperations altogether.
- //
// This interface is only used by ramfs to update metadata of
// children. These filesystems should _never_ have overlay
// Inodes cached as children. So explicitly disallow this
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
index 94deb553b..1fc9c703c 100644
--- a/pkg/sentry/fs/proc/mounts.go
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -170,7 +170,8 @@ func superBlockOpts(mountPath string, msrc *fs.MountSource) string {
// NOTE(b/147673608): If the mount is a cgroup, we also need to include
// the cgroup name in the options. For now we just read that from the
// path.
- // TODO(gvisor.dev/issues/190): Once gVisor has full cgroup support, we
+ //
+ // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
// should get this value from the cgroup itself, and not rely on the
// path.
if msrc.FilesystemType == "cgroup" {
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index d4c4b533d..702fdd392 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -80,7 +80,7 @@ func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir
}
// Truncate implements fs.InodeOperations.Truncate.
-func (tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error {
+func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error {
return nil
}
@@ -196,7 +196,7 @@ func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *f
}
// Truncate implements fs.InodeOperations.Truncate.
-func (tcpSack) Truncate(context.Context, *fs.Inode, int64) error {
+func (*tcpSack) Truncate(context.Context, *fs.Inode, int64) error {
return nil
}
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index d6c5dd2c1..4d42eac83 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -57,6 +57,16 @@ func getTaskMM(t *kernel.Task) (*mm.MemoryManager, error) {
return m, nil
}
+func checkTaskState(t *kernel.Task) error {
+ switch t.ExitState() {
+ case kernel.TaskExitZombie:
+ return syserror.EACCES
+ case kernel.TaskExitDead:
+ return syserror.ESRCH
+ }
+ return nil
+}
+
// taskDir represents a task-level directory.
//
// +stateify savable
@@ -254,11 +264,12 @@ func newExe(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
}
func (e *exe) executable() (file fsbridge.File, err error) {
+ if err := checkTaskState(e.t); err != nil {
+ return nil, err
+ }
e.t.WithMuLocked(func(t *kernel.Task) {
mm := t.MemoryManager()
if mm == nil {
- // TODO(b/34851096): Check shouldn't allow Readlink once the
- // Task is zombied.
err = syserror.EACCES
return
}
@@ -268,7 +279,7 @@ func (e *exe) executable() (file fsbridge.File, err error) {
// (with locks held).
file = mm.Executable()
if file == nil {
- err = syserror.ENOENT
+ err = syserror.ESRCH
}
})
return
@@ -313,11 +324,22 @@ func newNamespaceSymlink(t *kernel.Task, msrc *fs.MountSource, name string) *fs.
return newProcInode(t, n, msrc, fs.Symlink, t)
}
+// Readlink reads the symlink value.
+func (n *namespaceSymlink) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if err := checkTaskState(n.t); err != nil {
+ return "", err
+ }
+ return n.Symlink.Readlink(ctx, inode)
+}
+
// Getlink implements fs.InodeOperations.Getlink.
func (n *namespaceSymlink) Getlink(ctx context.Context, inode *fs.Inode) (*fs.Dirent, error) {
if !kernel.ContextCanTrace(ctx, n.t, false) {
return nil, syserror.EACCES
}
+ if err := checkTaskState(n.t); err != nil {
+ return nil, err
+ }
// Create a new regular file to fake the namespace file.
iops := fsutil.NewNoReadWriteFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0777), linux.PROC_SUPER_MAGIC)
diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go
index d5be56c3f..bc117ca6a 100644
--- a/pkg/sentry/fs/tmpfs/fs.go
+++ b/pkg/sentry/fs/tmpfs/fs.go
@@ -44,9 +44,6 @@ const (
// lookup.
cacheRevalidate = "revalidate"
- // TODO(edahlgren/mpratt): support a tmpfs size limit.
- // size = "size"
-
// Permissions that exceed modeMask will be rejected.
modeMask = 01777
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
index 79b808359..89168220a 100644
--- a/pkg/sentry/fsbridge/vfs.go
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -26,22 +26,22 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// fsFile implements File interface over vfs.FileDescription.
+// VFSFile implements File interface over vfs.FileDescription.
//
// +stateify savable
-type vfsFile struct {
+type VFSFile struct {
file *vfs.FileDescription
}
-var _ File = (*vfsFile)(nil)
+var _ File = (*VFSFile)(nil)
// NewVFSFile creates a new File over a vfs.FileDescription.
func NewVFSFile(file *vfs.FileDescription) File {
- return &vfsFile{file: file}
+ return &VFSFile{file: file}
}
// PathnameWithDeleted implements File.
-func (f *vfsFile) PathnameWithDeleted(ctx context.Context) string {
+func (f *VFSFile) PathnameWithDeleted(ctx context.Context) string {
root := vfs.RootFromContext(ctx)
defer root.DecRef()
@@ -51,7 +51,7 @@ func (f *vfsFile) PathnameWithDeleted(ctx context.Context) string {
}
// ReadFull implements File.
-func (f *vfsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+func (f *VFSFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
var total int64
for dst.NumBytes() > 0 {
n, err := f.file.PRead(ctx, dst, offset+total, vfs.ReadOptions{})
@@ -67,12 +67,12 @@ func (f *vfsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset i
}
// ConfigureMMap implements File.
-func (f *vfsFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+func (f *VFSFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
return f.file.ConfigureMMap(ctx, opts)
}
// Type implements File.
-func (f *vfsFile) Type(ctx context.Context) (linux.FileMode, error) {
+func (f *VFSFile) Type(ctx context.Context) (linux.FileMode, error) {
stat, err := f.file.Stat(ctx, vfs.StatOptions{})
if err != nil {
return 0, err
@@ -81,15 +81,21 @@ func (f *vfsFile) Type(ctx context.Context) (linux.FileMode, error) {
}
// IncRef implements File.
-func (f *vfsFile) IncRef() {
+func (f *VFSFile) IncRef() {
f.file.IncRef()
}
// DecRef implements File.
-func (f *vfsFile) DecRef() {
+func (f *VFSFile) DecRef() {
f.file.DecRef()
}
+// FileDescription returns the FileDescription represented by f. It does not
+// take an additional reference on the returned FileDescription.
+func (f *VFSFile) FileDescription() *vfs.FileDescription {
+ return f.file
+}
+
// fsLookup implements Lookup interface using fs.File.
//
// +stateify savable
@@ -132,5 +138,5 @@ func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.Open
if err != nil {
return nil, err
}
- return &vfsFile{file: fd}, nil
+ return &VFSFile{file: fd}, nil
}
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 48eaccdbc..afea58f65 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -476,7 +476,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
}
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
_, _, err := fs.walk(rp, false)
if err != nil {
return nil, err
@@ -485,7 +485,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([
}
// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
_, _, err := fs.walk(rp, false)
if err != nil {
return "", err
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index d15a36709..99d1e3f8f 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
licenses(["notice"])
@@ -54,3 +54,13 @@ go_library(
"//pkg/usermem",
],
)
+
+go_test(
+ name = "gofer_test",
+ srcs = ["gofer_test.go"],
+ library = ":gofer",
+ deps = [
+ "//pkg/p9",
+ "//pkg/sentry/contexttest",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 5dbfc6250..49d9f859b 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -56,14 +56,19 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
fd.mu.Lock()
defer fd.mu.Unlock()
+ d := fd.dentry()
if fd.dirents == nil {
- ds, err := fd.dentry().getDirents(ctx)
+ ds, err := d.getDirents(ctx)
if err != nil {
return err
}
fd.dirents = ds
}
+ if d.fs.opts.interop != InteropModeShared {
+ d.touchAtime(fd.vfsfd.Mount())
+ }
+
for fd.off < int64(len(fd.dirents)) {
if err := cb.Handle(fd.dirents[fd.off]); err != nil {
return err
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 269624362..cd744bf5e 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -356,7 +356,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if err := create(parent, name); err != nil {
return err
}
- parent.touchCMtime(ctx)
+ if fs.opts.interop != InteropModeShared {
+ parent.touchCMtime()
+ }
delete(parent.negativeChildren, name)
parent.dirents = nil
return nil
@@ -435,14 +437,19 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
flags := uint32(0)
if dir {
if child != nil && !child.isDir() {
+ vfsObj.AbortDeleteDentry(childVFSD)
return syserror.ENOTDIR
}
flags = linux.AT_REMOVEDIR
} else {
if child != nil && child.isDir() {
+ vfsObj.AbortDeleteDentry(childVFSD)
return syserror.EISDIR
}
if rp.MustBeDir() {
+ if childVFSD != nil {
+ vfsObj.AbortDeleteDentry(childVFSD)
+ }
return syserror.ENOTDIR
}
}
@@ -454,7 +461,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
return err
}
if fs.opts.interop != InteropModeShared {
- parent.touchCMtime(ctx)
+ parent.touchCMtime()
if dir {
parent.decLinks()
}
@@ -802,7 +809,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
d.IncRef() // reference held by child on its parent d
d.vfsd.InsertChild(&child.vfsd, name)
if d.fs.opts.interop != InteropModeShared {
- d.touchCMtime(ctx)
delete(d.negativeChildren, name)
d.dirents = nil
}
@@ -834,6 +840,9 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
childVFSFD = &fd.vfsfd
}
+ if d.fs.opts.interop != InteropModeShared {
+ d.touchCMtime()
+ }
return childVFSFD, nil
}
@@ -975,6 +984,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
oldParent.decLinks()
newParent.incLinks()
}
+ oldParent.touchCMtime()
+ newParent.touchCMtime()
+ renamed.touchCtime()
}
vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, &newParent.vfsd, newName, replacedVFSD)
return nil
@@ -1068,7 +1080,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
}
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(&ds)
@@ -1076,11 +1088,11 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([
if err != nil {
return nil, err
}
- return d.listxattr(ctx)
+ return d.listxattr(ctx, rp.Credentials(), size)
}
// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(&ds)
@@ -1088,7 +1100,7 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, nam
if err != nil {
return "", err
}
- return d.getxattr(ctx, name)
+ return d.getxattr(ctx, rp.Credentials(), &opts)
}
// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
@@ -1100,7 +1112,7 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
if err != nil {
return err
}
- return d.setxattr(ctx, &opts)
+ return d.setxattr(ctx, rp.Credentials(), &opts)
}
// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
@@ -1112,7 +1124,7 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath,
if err != nil {
return err
}
- return d.removexattr(ctx, name)
+ return d.removexattr(ctx, rp.Credentials(), name)
}
// PrependPath implements vfs.FilesystemImpl.PrependPath.
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 8e41b6b1c..2485cdb53 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -34,6 +34,7 @@ package gofer
import (
"fmt"
"strconv"
+ "strings"
"sync"
"sync/atomic"
"syscall"
@@ -44,6 +45,7 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -72,6 +74,9 @@ type filesystem struct {
// client is the client used by this filesystem. client is immutable.
client *p9.Client
+ // clock is a realtime clock used to set timestamps in file operations.
+ clock ktime.Clock
+
// uid and gid are the effective KUID and KGID of the filesystem's creator,
// and are used as the owner and group for files that don't specify one.
// uid and gid are immutable.
@@ -376,6 +381,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
uid: creds.EffectiveKUID,
gid: creds.EffectiveKGID,
client: client,
+ clock: ktime.RealtimeClockFromContext(ctx),
dentries: make(map[*dentry]struct{}),
specialFileFDs: make(map[*specialFileFD]struct{}),
}
@@ -439,7 +445,8 @@ type dentry struct {
// refs is the reference count. Each dentry holds a reference on its
// parent, even if disowned. refs is accessed using atomic memory
- // operations.
+ // operations. When refs reaches 0, the dentry may be added to the cache or
+ // destroyed. If refs==-1 the dentry has already been destroyed.
refs int64
// fs is the owning filesystem. fs is immutable.
@@ -779,10 +786,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
// data, so there's no cache to truncate either.)
return nil
}
- now, haveNow := nowFromContext(ctx)
- if !haveNow {
- ctx.Warningf("gofer.dentry.setStat: current time not available")
- }
+ now := d.fs.clock.Now().Nanoseconds()
if stat.Mask&linux.STATX_MODE != 0 {
atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
}
@@ -794,25 +798,19 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
}
if setLocalAtime {
if stat.Atime.Nsec == linux.UTIME_NOW {
- if haveNow {
- atomic.StoreInt64(&d.atime, now)
- }
+ atomic.StoreInt64(&d.atime, now)
} else {
atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime))
}
}
if setLocalMtime {
if stat.Mtime.Nsec == linux.UTIME_NOW {
- if haveNow {
- atomic.StoreInt64(&d.mtime, now)
- }
+ atomic.StoreInt64(&d.mtime, now)
} else {
atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime))
}
}
- if haveNow {
- atomic.StoreInt64(&d.ctime, now)
- }
+ atomic.StoreInt64(&d.ctime, now)
if stat.Mask&linux.STATX_SIZE != 0 {
d.dataMu.Lock()
oldSize := d.size
@@ -864,7 +862,7 @@ func (d *dentry) IncRef() {
func (d *dentry) TryIncRef() bool {
for {
refs := atomic.LoadInt64(&d.refs)
- if refs == 0 {
+ if refs <= 0 {
return false
}
if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
@@ -887,13 +885,20 @@ func (d *dentry) DecRef() {
// checkCachingLocked should be called after d's reference count becomes 0 or it
// becomes disowned.
//
+// It may be called on a destroyed dentry. For example,
+// renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times
+// for the same dentry when the dentry is visited more than once in the same
+// operation. One of the calls may destroy the dentry, so subsequent calls will
+// do nothing.
+//
// Preconditions: d.fs.renameMu must be locked for writing.
func (d *dentry) checkCachingLocked() {
// Dentries with a non-zero reference count must be retained. (The only way
// to obtain a reference on a dentry with zero references is via path
// resolution, which requires renameMu, so if d.refs is zero then it will
// remain zero while we hold renameMu for writing.)
- if atomic.LoadInt64(&d.refs) != 0 {
+ refs := atomic.LoadInt64(&d.refs)
+ if refs > 0 {
if d.cached {
d.fs.cachedDentries.Remove(d)
d.fs.cachedDentriesLen--
@@ -901,6 +906,10 @@ func (d *dentry) checkCachingLocked() {
}
return
}
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ return
+ }
// Non-child dentries with zero references are no longer reachable by path
// resolution and should be dropped immediately.
if d.vfsd.Parent() == nil || d.vfsd.IsDisowned() {
@@ -953,9 +962,22 @@ func (d *dentry) checkCachingLocked() {
}
}
+// destroyLocked destroys the dentry. It may flush dirty pages from cache,
+// close the p9 file, and remove the reference on the parent dentry.
+//
// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. d is
// not a child dentry.
func (d *dentry) destroyLocked() {
+ switch atomic.LoadInt64(&d.refs) {
+ case 0:
+ // Mark the dentry destroyed.
+ atomic.StoreInt64(&d.refs, -1)
+ case -1:
+ panic("dentry.destroyLocked() called on already destroyed dentry")
+ default:
+ panic("dentry.destroyLocked() called with references on the dentry")
+ }
+
ctx := context.Background()
d.handleMu.Lock()
if !d.handle.file.isNil() {
@@ -975,7 +997,10 @@ func (d *dentry) destroyLocked() {
d.handle.close(ctx)
}
d.handleMu.Unlock()
- d.file.close(ctx)
+ if !d.file.isNil() {
+ d.file.close(ctx)
+ d.file = p9file{}
+ }
// Remove d from the set of all dentries.
d.fs.syncMu.Lock()
delete(d.fs.dentries, d)
@@ -1000,21 +1025,50 @@ func (d *dentry) setDeleted() {
atomic.StoreUint32(&d.deleted, 1)
}
-func (d *dentry) listxattr(ctx context.Context) ([]string, error) {
- return nil, syserror.ENOTSUP
+// We only support xattrs prefixed with "user." (see b/148380782). Currently,
+// there is no need to expose any other xattrs through a gofer.
+func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
+ xattrMap, err := d.file.listXattr(ctx, size)
+ if err != nil {
+ return nil, err
+ }
+ xattrs := make([]string, 0, len(xattrMap))
+ for x := range xattrMap {
+ if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) {
+ xattrs = append(xattrs, x)
+ }
+ }
+ return xattrs, nil
}
-func (d *dentry) getxattr(ctx context.Context, name string) (string, error) {
- // TODO(jamieliu): add vfs.GetxattrOptions.Size
- return d.file.getXattr(ctx, name, linux.XATTR_SIZE_MAX)
+func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) {
+ if err := d.checkPermissions(creds, vfs.MayRead); err != nil {
+ return "", err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return "", syserror.EOPNOTSUPP
+ }
+ return d.file.getXattr(ctx, opts.Name, opts.Size)
}
-func (d *dentry) setxattr(ctx context.Context, opts *vfs.SetxattrOptions) error {
+func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetxattrOptions) error {
+ if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags)
}
-func (d *dentry) removexattr(ctx context.Context, name string) error {
- return syserror.ENOTSUP
+func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name string) error {
+ if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ return d.file.removeXattr(ctx, name)
}
// Preconditions: d.isRegularFile() || d.isDirectory().
@@ -1065,7 +1119,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// description, but this doesn't matter since they refer to the
// same file (unless d.fs.opts.overlayfsStaleRead is true,
// which we handle separately).
- if err := syscall.Dup3(int(h.fd), int(d.handle.fd), 0); err != nil {
+ if err := syscall.Dup3(int(h.fd), int(d.handle.fd), syscall.O_CLOEXEC); err != nil {
d.handleMu.Unlock()
ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err)
h.close(ctx)
@@ -1165,21 +1219,21 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
}
// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
-func (fd *fileDescription) Listxattr(ctx context.Context) ([]string, error) {
- return fd.dentry().listxattr(ctx)
+func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ return fd.dentry().listxattr(ctx, auth.CredentialsFromContext(ctx), size)
}
// Getxattr implements vfs.FileDescriptionImpl.Getxattr.
-func (fd *fileDescription) Getxattr(ctx context.Context, name string) (string, error) {
- return fd.dentry().getxattr(ctx, name)
+func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) {
+ return fd.dentry().getxattr(ctx, auth.CredentialsFromContext(ctx), &opts)
}
// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
- return fd.dentry().setxattr(ctx, &opts)
+ return fd.dentry().setxattr(ctx, auth.CredentialsFromContext(ctx), &opts)
}
// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
- return fd.dentry().removexattr(ctx, name)
+ return fd.dentry().removexattr(ctx, auth.CredentialsFromContext(ctx), name)
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
new file mode 100644
index 000000000..82bc239db
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -0,0 +1,64 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync/atomic"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+)
+
+func TestDestroyIdempotent(t *testing.T) {
+ fs := filesystem{
+ dentries: make(map[*dentry]struct{}),
+ opts: filesystemOptions{
+ // Test relies on no dentry being held in the cache.
+ maxCachedDentries: 0,
+ },
+ }
+
+ ctx := contexttest.Context(t)
+ attr := &p9.Attr{
+ Mode: p9.ModeRegular,
+ }
+ mask := p9.AttrMask{
+ Mode: true,
+ Size: true,
+ }
+ parent, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr)
+ if err != nil {
+ t.Fatalf("fs.newDentry(): %v", err)
+ }
+
+ child, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr)
+ if err != nil {
+ t.Fatalf("fs.newDentry(): %v", err)
+ }
+ parent.IncRef() // reference held by child on its parent.
+ parent.vfsd.InsertChild(&child.vfsd, "child")
+
+ child.checkCachingLocked()
+ if got := atomic.LoadInt64(&child.refs); got != -1 {
+ t.Fatalf("child.refs=%d, want: -1", got)
+ }
+ // Parent will also be destroyed when child reference is removed.
+ if got := atomic.LoadInt64(&parent.refs); got != -1 {
+ t.Fatalf("parent.refs=%d, want: -1", got)
+ }
+ child.checkCachingLocked()
+ child.checkCachingLocked()
+}
diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go
index 755ac2985..87f0b877f 100644
--- a/pkg/sentry/fsimpl/gofer/p9file.go
+++ b/pkg/sentry/fsimpl/gofer/p9file.go
@@ -85,6 +85,13 @@ func (f p9file) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAt
return err
}
+func (f p9file) listXattr(ctx context.Context, size uint64) (map[string]struct{}, error) {
+ ctx.UninterruptibleSleepStart(false)
+ xattrs, err := f.file.ListXattr(size)
+ ctx.UninterruptibleSleepFinish(false)
+ return xattrs, err
+}
+
func (f p9file) getXattr(ctx context.Context, name string, size uint64) (string, error) {
ctx.UninterruptibleSleepStart(false)
val, err := f.file.GetXattr(name, size)
@@ -99,6 +106,13 @@ func (f p9file) setXattr(ctx context.Context, name, value string, flags uint32)
return err
}
+func (f p9file) removeXattr(ctx context.Context, name string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.RemoveXattr(name)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
func (f p9file) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error {
ctx.UninterruptibleSleepStart(false)
err := f.file.Allocate(mode, offset, length)
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 3593eb1d5..857f7c74e 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -104,7 +104,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
putDentryReadWriter(rw)
if d.fs.opts.interop != InteropModeShared {
// Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed().
- d.touchAtime(ctx, fd.vfsfd.Mount())
+ d.touchAtime(fd.vfsfd.Mount())
}
return n, err
}
@@ -139,10 +139,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Compare Linux's mm/filemap.c:__generic_file_write_iter() =>
// file_update_time(). This is d.touchCMtime(), but without locking
// d.metadataMu (recursively).
- if now, ok := nowFromContext(ctx); ok {
- atomic.StoreInt64(&d.mtime, now)
- atomic.StoreInt64(&d.ctime, now)
- }
+ d.touchCMtimeLocked()
}
if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
// Write dirty cached pages that will be touched by the write back to
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 274f7346f..507e0e276 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -76,7 +76,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
// hold here since specialFileFD doesn't client-cache data. Just buffer the
// read instead.
if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
- d.touchAtime(ctx, fd.vfsfd.Mount())
+ d.touchAtime(fd.vfsfd.Mount())
}
buf := make([]byte, dst.NumBytes())
n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
@@ -117,7 +117,7 @@ func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Do a buffered write. See rationale in PRead.
if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
- d.touchCMtime(ctx)
+ d.touchCMtime()
}
buf := make([]byte, src.NumBytes())
// Don't do partial writes if we get a partial read from src.
diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go
index adf43be60..2ec819f86 100644
--- a/pkg/sentry/fsimpl/gofer/symlink.go
+++ b/pkg/sentry/fsimpl/gofer/symlink.go
@@ -27,7 +27,7 @@ func (d *dentry) isSymlink() bool {
// Precondition: d.isSymlink().
func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
if d.fs.opts.interop != InteropModeShared {
- d.touchAtime(ctx, mnt)
+ d.touchAtime(mnt)
d.dataMu.Lock()
if d.haveTarget {
target := d.target
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
index 7598ec6a8..2608e7e1d 100644
--- a/pkg/sentry/fsimpl/gofer/time.go
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -18,8 +18,6 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/context"
- ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -38,23 +36,12 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp {
}
}
-func nowFromContext(ctx context.Context) (int64, bool) {
- if clock := ktime.RealtimeClockFromContext(ctx); clock != nil {
- return clock.Now().Nanoseconds(), true
- }
- return 0, false
-}
-
// Preconditions: fs.interop != InteropModeShared.
-func (d *dentry) touchAtime(ctx context.Context, mnt *vfs.Mount) {
+func (d *dentry) touchAtime(mnt *vfs.Mount) {
if err := mnt.CheckBeginWrite(); err != nil {
return
}
- now, ok := nowFromContext(ctx)
- if !ok {
- mnt.EndWrite()
- return
- }
+ now := d.fs.clock.Now().Nanoseconds()
d.metadataMu.Lock()
atomic.StoreInt64(&d.atime, now)
d.metadataMu.Unlock()
@@ -63,13 +50,25 @@ func (d *dentry) touchAtime(ctx context.Context, mnt *vfs.Mount) {
// Preconditions: fs.interop != InteropModeShared. The caller has successfully
// called vfs.Mount.CheckBeginWrite().
-func (d *dentry) touchCMtime(ctx context.Context) {
- now, ok := nowFromContext(ctx)
- if !ok {
- return
- }
+func (d *dentry) touchCtime() {
+ now := d.fs.clock.Now().Nanoseconds()
+ d.metadataMu.Lock()
+ atomic.StoreInt64(&d.ctime, now)
+ d.metadataMu.Unlock()
+}
+
+// Preconditions: fs.interop != InteropModeShared. The caller has successfully
+// called vfs.Mount.CheckBeginWrite().
+func (d *dentry) touchCMtime() {
+ now := d.fs.clock.Now().Nanoseconds()
d.metadataMu.Lock()
atomic.StoreInt64(&d.mtime, now)
atomic.StoreInt64(&d.ctime, now)
d.metadataMu.Unlock()
}
+
+func (d *dentry) touchCMtimeLocked() {
+ now := d.fs.clock.Now().Nanoseconds()
+ atomic.StoreInt64(&d.mtime, now)
+ atomic.StoreInt64(&d.ctime, now)
+}
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 7d9dcd4c9..fe14476f1 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -74,31 +74,33 @@ func ImportFD(mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, err
}
// Retrieve metadata.
- var s syscall.Stat_t
- if err := syscall.Fstat(hostFD, &s); err != nil {
+ var s unix.Stat_t
+ if err := unix.Fstat(hostFD, &s); err != nil {
return nil, err
}
fileMode := linux.FileMode(s.Mode)
fileType := fileMode.FileType()
- // Pipes, character devices, and sockets.
- isStream := fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK
+
+ // Determine if hostFD is seekable. If not, this syscall will return ESPIPE
+ // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character
+ // devices.
+ _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR)
+ seekable := err != syserror.ESPIPE
i := &inode{
hostFD: hostFD,
- isStream: isStream,
+ seekable: seekable,
isTTY: isTTY,
canMap: canMap(uint32(fileType)),
ino: fs.NextIno(),
- mode: fileMode,
- // For simplicity, set offset to 0. Technically, we should
- // only set to 0 on files that are not seekable (sockets, pipes, etc.),
- // and use the offset from the host fd otherwise.
+ // For simplicity, set offset to 0. Technically, we should use the existing
+ // offset on the host if the file is seekable.
offset: 0,
}
- // These files can't be memory mapped, assert this.
- if i.isStream && i.canMap {
+ // Non-seekable files can't be memory mapped, assert this.
+ if !i.seekable && i.canMap {
panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped")
}
@@ -124,12 +126,12 @@ type inode struct {
// This field is initialized at creation time and is immutable.
hostFD int
- // isStream is true if the host fd points to a file representing a stream,
+ // seekable is false if the host fd points to a file representing a stream,
// e.g. a socket or a pipe. Such files are not seekable and can return
// EWOULDBLOCK for I/O operations.
//
// This field is initialized at creation time and is immutable.
- isStream bool
+ seekable bool
// isTTY is true if this file represents a TTY.
//
@@ -146,20 +148,6 @@ type inode struct {
// This field is initialized at creation time and is immutable.
ino uint64
- // modeMu protects mode.
- modeMu sync.Mutex
-
- // mode is a cached version of the file mode on the host. Note that it may
- // become out of date if the mode is changed on the host, e.g. with chmod.
- //
- // Generally, it is better to retrieve the mode from the host through an
- // fstat syscall. We only use this value in inode.Mode(), which cannot
- // return an error, if the syscall to host fails.
- //
- // FIXME(b/152294168): Plumb error into Inode.Mode() return value so we
- // can get rid of this.
- mode linux.FileMode
-
// offsetMu protects offset.
offsetMu sync.Mutex
@@ -192,10 +180,11 @@ func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, a
// Mode implements kernfs.Inode.
func (i *inode) Mode() linux.FileMode {
mode, _, _, err := i.getPermissions()
+ // Retrieving the mode from the host fd using fstat(2) should not fail.
+ // If the syscall does not succeed, something is fundamentally wrong.
if err != nil {
- return i.mode
+ panic(fmt.Sprintf("failed to retrieve mode from host fd %d: %v", i.hostFD, err))
}
-
return linux.FileMode(mode)
}
@@ -205,11 +194,6 @@ func (i *inode) getPermissions() (linux.FileMode, auth.KUID, auth.KGID, error) {
if err := syscall.Fstat(i.hostFD, &s); err != nil {
return 0, 0, 0, err
}
-
- // Update cached mode.
- i.modeMu.Lock()
- i.mode = linux.FileMode(s.Mode)
- i.modeMu.Unlock()
return linux.FileMode(s.Mode), auth.KUID(s.Uid), auth.KGID(s.Gid), nil
}
@@ -289,12 +273,6 @@ func (i *inode) Stat(_ *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, erro
ls.Ino = i.ino
}
- // Update cached mode.
- if (mask&linux.STATX_TYPE != 0) && (mask&linux.STATX_MODE != 0) {
- i.modeMu.Lock()
- i.mode = linux.FileMode(s.Mode)
- i.modeMu.Unlock()
- }
return ls, nil
}
@@ -361,9 +339,6 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
if err := syscall.Fchmod(i.hostFD, uint32(s.Mode)); err != nil {
return err
}
- i.modeMu.Lock()
- i.mode = linux.FileMode(s.Mode)
- i.modeMu.Unlock()
}
if m&linux.STATX_SIZE != 0 {
if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
@@ -481,8 +456,7 @@ func (f *fileDescription) Release() {
// PRead implements FileDescriptionImpl.
func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
i := f.inode
- // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null.
- if i.isStream {
+ if !i.seekable {
return 0, syserror.ESPIPE
}
@@ -492,8 +466,7 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off
// Read implements FileDescriptionImpl.
func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
i := f.inode
- // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null.
- if i.isStream {
+ if !i.seekable {
n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags)
if isBlockError(err) {
// If we got any data at all, return it as a "completed" partial read
@@ -538,8 +511,7 @@ func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, off
// PWrite implements FileDescriptionImpl.
func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
i := f.inode
- // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null.
- if i.isStream {
+ if !i.seekable {
return 0, syserror.ESPIPE
}
@@ -549,8 +521,7 @@ func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, of
// Write implements FileDescriptionImpl.
func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
i := f.inode
- // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null.
- if i.isStream {
+ if !i.seekable {
n, err := writeToHostFD(ctx, i.hostFD, src, -1, opts.Flags)
if isBlockError(err) {
err = syserror.ErrWouldBlock
@@ -593,8 +564,7 @@ func writeToHostFD(ctx context.Context, hostFD int, src usermem.IOSequence, offs
// allow directory fds to be imported at all.
func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (int64, error) {
i := f.inode
- // TODO(b/34716638): Some char devices do support seeking, e.g. /dev/null.
- if i.isStream {
+ if !i.seekable {
return 0, syserror.ESPIPE
}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index a429fa23d..baf81b4db 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -63,6 +63,9 @@ afterSymlink:
rp.Advance()
return nextVFSD, nil
}
+ if len(name) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
d.dirMu.Lock()
nextVFSD, err := rp.ResolveChild(vfsd, name)
if err != nil {
@@ -76,16 +79,22 @@ afterSymlink:
}
// Resolve any symlink at current path component.
if rp.ShouldFollowSymlink() && next.isSymlink() {
- // TODO: VFS2 needs something extra for /proc/[pid]/fd/ "magic symlinks".
- target, err := next.inode.Readlink(ctx)
+ targetVD, targetPathname, err := next.inode.Getlink(ctx)
if err != nil {
return nil, err
}
- if err := rp.HandleSymlink(target); err != nil {
- return nil, err
+ if targetVD.Ok() {
+ err := rp.HandleJump(targetVD)
+ targetVD.DecRef()
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ if err := rp.HandleSymlink(targetPathname); err != nil {
+ return nil, err
+ }
}
goto afterSymlink
-
}
rp.Advance()
return &next.vfsd, nil
@@ -191,6 +200,9 @@ func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *v
if pc == "." || pc == ".." {
return "", syserror.EEXIST
}
+ if len(pc) > linux.NAME_MAX {
+ return "", syserror.ENAMETOOLONG
+ }
childVFSD, err := rp.ResolveChild(parentVFSD, pc)
if err != nil {
return "", err
@@ -433,6 +445,9 @@ afterTrailingSymlink:
if pc == "." || pc == ".." {
return nil, syserror.EISDIR
}
+ if len(pc) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
// Determine whether or not we need to create a file.
childVFSD, err := rp.ResolveChild(parentVFSD, pc)
if err != nil {
@@ -461,19 +476,25 @@ afterTrailingSymlink:
}
childDentry := childVFSD.Impl().(*Dentry)
childInode := childDentry.inode
- if rp.ShouldFollowSymlink() {
- if childDentry.isSymlink() {
- target, err := childInode.Readlink(ctx)
+ if rp.ShouldFollowSymlink() && childDentry.isSymlink() {
+ targetVD, targetPathname, err := childInode.Getlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if targetVD.Ok() {
+ err := rp.HandleJump(targetVD)
+ targetVD.DecRef()
if err != nil {
return nil, err
}
- if err := rp.HandleSymlink(target); err != nil {
+ } else {
+ if err := rp.HandleSymlink(targetPathname); err != nil {
return nil, err
}
- // rp.Final() may no longer be true since we now need to resolve the
- // symlink target.
- goto afterTrailingSymlink
}
+ // rp.Final() may no longer be true since we now need to resolve the
+ // symlink target.
+ goto afterTrailingSymlink
}
if err := childInode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
@@ -661,7 +682,7 @@ func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
if err != nil {
return linux.Statfs{}, err
}
- // TODO: actually implement statfs
+ // TODO(gvisor.dev/issue/1193): actually implement statfs.
return linux.Statfs{}, syserror.ENOSYS
}
@@ -742,7 +763,7 @@ func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
}
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
fs.mu.RLock()
_, _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
@@ -755,7 +776,7 @@ func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([
}
// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
fs.mu.RLock()
_, _, err := fs.walkExistingLocked(ctx, rp)
fs.mu.RUnlock()
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 5c84b10c9..65f09af5d 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -181,6 +181,11 @@ func (InodeNotSymlink) Readlink(context.Context) (string, error) {
return "", syserror.EINVAL
}
+// Getlink implements Inode.Getlink.
+func (InodeNotSymlink) Getlink(context.Context) (vfs.VirtualDentry, string, error) {
+ return vfs.VirtualDentry{}, "", syserror.EINVAL
+}
+
// InodeAttrs partially implements the Inode interface, specifically the
// inodeMetadata sub interface. InodeAttrs provides functionality related to
// inode attributes.
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 2cefef020..ad76b9f64 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -414,7 +414,21 @@ type inodeDynamicLookup interface {
}
type inodeSymlink interface {
- // Readlink resolves the target of a symbolic link. If an inode is not a
+ // Readlink returns the target of a symbolic link. If an inode is not a
// symlink, the implementation should return EINVAL.
Readlink(ctx context.Context) (string, error)
+
+ // Getlink returns the target of a symbolic link, as used by path
+ // resolution:
+ //
+ // - If the inode is a "magic link" (a link whose target is most accurately
+ // represented as a VirtualDentry), Getlink returns (ok VirtualDentry, "",
+ // nil). A reference is taken on the returned VirtualDentry.
+ //
+ // - If the inode is an ordinary symlink, Getlink returns (zero-value
+ // VirtualDentry, symlink target, nil).
+ //
+ // - If the inode is not a symlink, Getlink returns (zero-value
+ // VirtualDentry, "", EINVAL).
+ Getlink(ctx context.Context) (vfs.VirtualDentry, string, error)
}
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
index 5918d3309..018aa503c 100644
--- a/pkg/sentry/fsimpl/kernfs/symlink.go
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -55,6 +55,11 @@ func (s *StaticSymlink) Readlink(_ context.Context) (string, error) {
return s.target, nil
}
+// Getlink implements Inode.Getlink.
+func (s *StaticSymlink) Getlink(_ context.Context) (vfs.VirtualDentry, string, error) {
+ return vfs.VirtualDentry{}, s.target, nil
+}
+
// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
func (*StaticSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD
new file mode 100644
index 000000000..0d411606f
--- /dev/null
+++ b/pkg/sentry/fsimpl/pipefs/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "pipefs",
+ srcs = ["pipefs.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
new file mode 100644
index 000000000..faf3179bc
--- /dev/null
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -0,0 +1,148 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipefs provides the filesystem implementation backing
+// Kernel.PipeMount.
+package pipefs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type filesystemType struct{}
+
+// Name implements vfs.FilesystemType.Name.
+func (filesystemType) Name() string {
+ return "pipefs"
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (filesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ panic("pipefs.filesystemType.GetFilesystem should never be called")
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+
+ // TODO(gvisor.dev/issue/1193):
+ //
+ // - kernfs does not provide a way to implement statfs, from which we
+ // should indicate PIPEFS_MAGIC.
+ //
+ // - kernfs does not provide a way to override names for
+ // vfs.FilesystemImpl.PrependPath(); pipefs inodes should use synthetic
+ // name fmt.Sprintf("pipe:[%d]", inode.ino).
+}
+
+// NewFilesystem sets up and returns a new vfs.Filesystem implemented by
+// pipefs.
+func NewFilesystem(vfsObj *vfs.VirtualFilesystem) *vfs.Filesystem {
+ fs := &filesystem{}
+ fs.Init(vfsObj, filesystemType{})
+ return fs.VFSFilesystem()
+}
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+ kernfs.InodeNoopRefCount
+
+ pipe *pipe.VFSPipe
+
+ ino uint64
+ uid auth.KUID
+ gid auth.KGID
+ // We use the creation timestamp for all of atime, mtime, and ctime.
+ ctime ktime.Time
+}
+
+func newInode(ctx context.Context, fs *kernfs.Filesystem) *inode {
+ creds := auth.CredentialsFromContext(ctx)
+ return &inode{
+ pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ ino: fs.NextIno(),
+ uid: creds.EffectiveKUID,
+ gid: creds.EffectiveKGID,
+ ctime: ktime.NowFromContext(ctx),
+ }
+}
+
+const pipeMode = 0600 | linux.S_IFIFO
+
+// CheckPermissions implements kernfs.Inode.CheckPermissions.
+func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, pipeMode, i.uid, i.gid)
+}
+
+// Mode implements kernfs.Inode.Mode.
+func (i *inode) Mode() linux.FileMode {
+ return pipeMode
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds())
+ return linux.Statx{
+ Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS,
+ Blksize: usermem.PageSize,
+ Nlink: 1,
+ UID: uint32(i.uid),
+ GID: uint32(i.gid),
+ Mode: pipeMode,
+ Ino: i.ino,
+ Size: 0,
+ Blocks: 0,
+ Atime: ts,
+ Ctime: ts,
+ Mtime: ts,
+ // TODO(gvisor.dev/issue/1197): Device number.
+ }, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ return syserror.EPERM
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // FIXME(b/38173783): kernfs does not plumb Context here.
+ return i.pipe.Open(context.Background(), rp.Mount(), vfsd, opts.Flags)
+}
+
+// NewConnectedPipeFDs returns a pair of FileDescriptions representing the read
+// and write ends of a newly-created pipe, as for pipe(2) and pipe2(2).
+//
+// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem().
+func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+ fs := mnt.Filesystem().Impl().(*kernfs.Filesystem)
+ inode := newInode(ctx, fs)
+ var d kernfs.Dentry
+ d.Init(inode)
+ defer d.DecRef()
+ return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags)
+}
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 8156984eb..17c1342b5 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -22,7 +22,6 @@ go_library(
"//pkg/log",
"//pkg/refs",
"//pkg/safemem",
- "//pkg/sentry/fs",
"//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/inet",
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index aee2a4392..888afc0fd 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -214,22 +214,6 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData {
return &ioData{ioUsage: t}
}
-func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry {
- // Namespace symlinks should contain the namespace name and the inode number
- // for the namespace instance, so for example user:[123456]. We currently fake
- // the inode number by sticking the symlink inode in its place.
- target := fmt.Sprintf("%s:[%d]", ns, ino)
-
- inode := &kernfs.StaticSymlink{}
- // Note: credentials are overridden by taskOwnedInode.
- inode.Init(task.Credentials(), ino, target)
-
- taskInode := &taskOwnedInode{Inode: inode, owner: task}
- d := &kernfs.Dentry{}
- d.Init(taskInode)
- return d
-}
-
// newCgroupData creates inode that shows cgroup information.
// From man 7 cgroups: "For each cgroup hierarchy of which the process is a
// member, there is one entry containing three colon-separated fields:
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 76bfc5307..046265eca 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -30,34 +30,35 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-type fdDir struct {
- inoGen InoGenerator
- task *kernel.Task
-
- // When produceSymlinks is set, dirents produces for the FDs are reported
- // as symlink. Otherwise, they are reported as regular files.
- produceSymlink bool
-}
-
-func (i *fdDir) lookup(name string) (*vfs.FileDescription, kernel.FDFlags, error) {
- fd, err := strconv.ParseUint(name, 10, 64)
- if err != nil {
- return nil, kernel.FDFlags{}, syserror.ENOENT
- }
-
+func getTaskFD(t *kernel.Task, fd int32) (*vfs.FileDescription, kernel.FDFlags) {
var (
file *vfs.FileDescription
flags kernel.FDFlags
)
- i.task.WithMuLocked(func(t *kernel.Task) {
- if fdTable := t.FDTable(); fdTable != nil {
- file, flags = fdTable.GetVFS2(int32(fd))
+ t.WithMuLocked(func(t *kernel.Task) {
+ if fdt := t.FDTable(); fdt != nil {
+ file, flags = fdt.GetVFS2(fd)
}
})
+ return file, flags
+}
+
+func taskFDExists(t *kernel.Task, fd int32) bool {
+ file, _ := getTaskFD(t, fd)
if file == nil {
- return nil, kernel.FDFlags{}, syserror.ENOENT
+ return false
}
- return file, flags, nil
+ file.DecRef()
+ return true
+}
+
+type fdDir struct {
+ inoGen InoGenerator
+ task *kernel.Task
+
+ // When produceSymlinks is set, dirents produces for the FDs are reported
+ // as symlink. Otherwise, they are reported as regular files.
+ produceSymlink bool
}
// IterDirents implements kernfs.inodeDynamicLookup.
@@ -128,11 +129,15 @@ func newFDDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry {
// Lookup implements kernfs.inodeDynamicLookup.
func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
- file, _, err := i.lookup(name)
+ fdInt, err := strconv.ParseInt(name, 10, 32)
if err != nil {
- return nil, err
+ return nil, syserror.ENOENT
+ }
+ fd := int32(fdInt)
+ if !taskFDExists(i.task, fd) {
+ return nil, syserror.ENOENT
}
- taskDentry := newFDSymlink(i.task.Credentials(), file, i.inoGen.NextIno())
+ taskDentry := newFDSymlink(i.task, fd, i.inoGen.NextIno())
return taskDentry.VFSDentry(), nil
}
@@ -169,19 +174,22 @@ func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentia
//
// +stateify savable
type fdSymlink struct {
- refs.AtomicRefCount
kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
kernfs.InodeSymlink
- file *vfs.FileDescription
+ task *kernel.Task
+ fd int32
}
var _ kernfs.Inode = (*fdSymlink)(nil)
-func newFDSymlink(creds *auth.Credentials, file *vfs.FileDescription, ino uint64) *kernfs.Dentry {
- file.IncRef()
- inode := &fdSymlink{file: file}
- inode.Init(creds, ino, linux.ModeSymlink|0777)
+func newFDSymlink(task *kernel.Task, fd int32, ino uint64) *kernfs.Dentry {
+ inode := &fdSymlink{
+ task: task,
+ fd: fd,
+ }
+ inode.Init(task.Credentials(), ino, linux.ModeSymlink|0777)
d := &kernfs.Dentry{}
d.Init(inode)
@@ -189,21 +197,25 @@ func newFDSymlink(creds *auth.Credentials, file *vfs.FileDescription, ino uint64
}
func (s *fdSymlink) Readlink(ctx context.Context) (string, error) {
+ file, _ := getTaskFD(s.task, s.fd)
+ if file == nil {
+ return "", syserror.ENOENT
+ }
+ defer file.DecRef()
root := vfs.RootFromContext(ctx)
defer root.DecRef()
-
- vfsObj := s.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem()
- return vfsObj.PathnameWithDeleted(ctx, root, s.file.VirtualDentry())
-}
-
-func (s *fdSymlink) DecRef() {
- s.AtomicRefCount.DecRefWithDestructor(func() {
- s.Destroy()
- })
+ return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry())
}
-func (s *fdSymlink) Destroy() {
- s.file.DecRef()
+func (s *fdSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) {
+ file, _ := getTaskFD(s.task, s.fd)
+ if file == nil {
+ return vfs.VirtualDentry{}, "", syserror.ENOENT
+ }
+ defer file.DecRef()
+ vd := file.VirtualDentry()
+ vd.IncRef()
+ return vd, "", nil
}
// fdInfoDirInode represents the inode for /proc/[pid]/fdinfo directory.
@@ -238,12 +250,18 @@ func newFDInfoDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry {
// Lookup implements kernfs.inodeDynamicLookup.
func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
- file, flags, err := i.lookup(name)
+ fdInt, err := strconv.ParseInt(name, 10, 32)
if err != nil {
- return nil, err
+ return nil, syserror.ENOENT
+ }
+ fd := int32(fdInt)
+ if !taskFDExists(i.task, fd) {
+ return nil, syserror.ENOENT
+ }
+ data := &fdInfoData{
+ task: i.task,
+ fd: fd,
}
-
- data := &fdInfoData{file: file, flags: flags}
dentry := newTaskOwnedFile(i.task, i.inoGen.NextIno(), 0444, data)
return dentry.VFSDentry(), nil
}
@@ -262,26 +280,23 @@ type fdInfoData struct {
kernfs.DynamicBytesFile
refs.AtomicRefCount
- file *vfs.FileDescription
- flags kernel.FDFlags
+ task *kernel.Task
+ fd int32
}
var _ dynamicInode = (*fdInfoData)(nil)
-func (d *fdInfoData) DecRef() {
- d.AtomicRefCount.DecRefWithDestructor(d.destroy)
-}
-
-func (d *fdInfoData) destroy() {
- d.file.DecRef()
-}
-
// Generate implements vfs.DynamicBytesSource.Generate.
func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ file, descriptorFlags := getTaskFD(d.task, d.fd)
+ if file == nil {
+ return syserror.ENOENT
+ }
+ defer file.DecRef()
// TODO(b/121266871): Include pos, locks, and other data. For now we only
// have flags.
// See https://www.kernel.org/doc/Documentation/filesystems/proc.txt
- flags := uint(d.file.StatusFlags()) | d.flags.ToLinuxFileFlags()
+ flags := uint(file.StatusFlags()) | descriptorFlags.ToLinuxFileFlags()
fmt.Fprintf(buf, "flags:\t0%o\n", flags)
return nil
}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index df0d1bcc5..2c6f8bdfc 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -64,6 +64,16 @@ func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) {
return m, nil
}
+func checkTaskState(t *kernel.Task) error {
+ switch t.ExitState() {
+ case kernel.TaskExitZombie:
+ return syserror.EACCES
+ case kernel.TaskExitDead:
+ return syserror.ESRCH
+ }
+ return nil
+}
+
type bufferWriter struct {
buf *bytes.Buffer
}
@@ -610,12 +620,31 @@ func (s *exeSymlink) Readlink(ctx context.Context) (string, error) {
return exec.PathnameWithDeleted(ctx), nil
}
+// Getlink implements kernfs.Inode.Getlink.
+func (s *exeSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) {
+ if !kernel.ContextCanTrace(ctx, s.task, false) {
+ return vfs.VirtualDentry{}, "", syserror.EACCES
+ }
+
+ exec, err := s.executable()
+ if err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+ defer exec.DecRef()
+
+ vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry()
+ vd.IncRef()
+ return vd, "", nil
+}
+
func (s *exeSymlink) executable() (file fsbridge.File, err error) {
+ if err := checkTaskState(s.task); err != nil {
+ return nil, err
+ }
+
s.task.WithMuLocked(func(t *kernel.Task) {
mm := t.MemoryManager()
if mm == nil {
- // TODO(b/34851096): Check shouldn't allow Readlink once the
- // Task is zombied.
err = syserror.EACCES
return
}
@@ -625,7 +654,7 @@ func (s *exeSymlink) executable() (file fsbridge.File, err error) {
// (with locks held).
file = mm.Executable()
if file == nil {
- err = syserror.ENOENT
+ err = syserror.ESRCH
}
})
return
@@ -692,3 +721,41 @@ func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
i.task.Kernel().VFS().GenerateProcMounts(ctx, rootDir, buf)
return nil
}
+
+type namespaceSymlink struct {
+ kernfs.StaticSymlink
+
+ task *kernel.Task
+}
+
+func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry {
+ // Namespace symlinks should contain the namespace name and the inode number
+ // for the namespace instance, so for example user:[123456]. We currently fake
+ // the inode number by sticking the symlink inode in its place.
+ target := fmt.Sprintf("%s:[%d]", ns, ino)
+
+ inode := &namespaceSymlink{task: task}
+ // Note: credentials are overridden by taskOwnedInode.
+ inode.Init(task.Credentials(), ino, target)
+
+ taskInode := &taskOwnedInode{Inode: inode, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(taskInode)
+ return d
+}
+
+// Readlink implements Inode.
+func (s *namespaceSymlink) Readlink(ctx context.Context) (string, error) {
+ if err := checkTaskState(s.task); err != nil {
+ return "", err
+ }
+ return s.StaticSymlink.Readlink(ctx)
+}
+
+// Getlink implements Inode.Getlink.
+func (s *namespaceSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) {
+ if err := checkTaskState(s.task); err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+ return s.StaticSymlink.Getlink(ctx)
+}
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index 373a7b17d..6595fcee6 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -32,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/usermem"
@@ -206,22 +206,21 @@ var _ dynamicInode = (*netUnixData)(nil)
func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
buf.WriteString("Num RefCount Protocol Flags Type St Inode Path\n")
for _, se := range n.kernel.ListSockets() {
- s := se.Sock.Get()
- if s == nil {
- log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock)
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
continue
}
- sfile := s.(*fs.File)
- if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX {
+ if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX {
s.DecRef()
// Not a unix socket.
continue
}
- sops := sfile.FileOperations.(*unix.SocketOperations)
+ sops := s.Impl().(*unix.SocketVFS2)
addr, err := sops.Endpoint().GetLocalAddress()
if err != nil {
- log.Warningf("Failed to retrieve socket name from %+v: %v", sfile, err)
+ log.Warningf("Failed to retrieve socket name from %+v: %v", s, err)
addr.Addr = "<unknown>"
}
@@ -234,6 +233,15 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
}
+ // Get inode number.
+ var ino uint64
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_INO})
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve ino for socket file: %v", statErr)
+ } else {
+ ino = stat.Ino
+ }
+
// In the socket entry below, the value for the 'Num' field requires
// some consideration. Linux prints the address to the struct
// unix_sock representing a socket in the kernel, but may redact the
@@ -252,14 +260,14 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// the definition of this struct changes over time.
//
// For now, we always redact this pointer.
- fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %5d",
+ fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %8d",
(*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct.
- sfile.ReadRefs()-1, // RefCount, don't count our own ref.
+ s.Refs()-1, // RefCount, don't count our own ref.
0, // Protocol, always 0 for UDS.
sockFlags, // Flags.
sops.Endpoint().Type(), // Type.
sops.State(), // State.
- sfile.InodeID(), // Inode.
+ ino, // Inode.
)
// Path
@@ -341,15 +349,14 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
t := kernel.TaskFromContext(ctx)
for _, se := range k.ListSockets() {
- s := se.Sock.Get()
- if s == nil {
- log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
continue
}
- sfile := s.(*fs.File)
- sops, ok := sfile.FileOperations.(socket.Socket)
+ sops, ok := s.Impl().(socket.SocketVFS2)
if !ok {
- panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
}
if fa, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) {
s.DecRef()
@@ -398,14 +405,15 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
// Unimplemented.
fmt.Fprintf(buf, "%08X ", 0)
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO})
+
// Field: uid.
- uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
- if err != nil {
- log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
+ if statErr != nil || stat.Mask&linux.STATX_UID == 0 {
+ log.Warningf("Failed to retrieve uid for socket file: %v", statErr)
fmt.Fprintf(buf, "%5d ", 0)
} else {
creds := auth.CredentialsFromContext(ctx)
- fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
+ fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow()))
}
// Field: timeout; number of unanswered 0-window probes.
@@ -413,11 +421,16 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel,
fmt.Fprintf(buf, "%8d ", 0)
// Field: inode.
- fmt.Fprintf(buf, "%8d ", sfile.InodeID())
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve inode for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%8d ", 0)
+ } else {
+ fmt.Fprintf(buf, "%8d ", stat.Ino)
+ }
// Field: refcount. Don't count the ref we obtain while dereferencing
// the weakref to this socket.
- fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1)
+ fmt.Fprintf(buf, "%d ", s.Refs()-1)
// Field: Socket struct address. Redacted due to the same reason as
// the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
@@ -499,15 +512,14 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
t := kernel.TaskFromContext(ctx)
for _, se := range d.kernel.ListSockets() {
- s := se.Sock.Get()
- if s == nil {
- log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
continue
}
- sfile := s.(*fs.File)
- sops, ok := sfile.FileOperations.(socket.Socket)
+ sops, ok := s.Impl().(socket.SocketVFS2)
if !ok {
- panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
}
if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
s.DecRef()
@@ -551,25 +563,31 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// Field: retrnsmt. Always 0 for UDP.
fmt.Fprintf(buf, "%08X ", 0)
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO})
+
// Field: uid.
- uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
- if err != nil {
- log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
+ if statErr != nil || stat.Mask&linux.STATX_UID == 0 {
+ log.Warningf("Failed to retrieve uid for socket file: %v", statErr)
fmt.Fprintf(buf, "%5d ", 0)
} else {
creds := auth.CredentialsFromContext(ctx)
- fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
+ fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow()))
}
// Field: timeout. Always 0 for UDP.
fmt.Fprintf(buf, "%8d ", 0)
// Field: inode.
- fmt.Fprintf(buf, "%8d ", sfile.InodeID())
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve inode for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%8d ", 0)
+ } else {
+ fmt.Fprintf(buf, "%8d ", stat.Ino)
+ }
// Field: ref; reference count on the socket inode. Don't count the ref
// we obtain while dereferencing the weakref to this socket.
- fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1)
+ fmt.Fprintf(buf, "%d ", s.Refs()-1)
// Field: Socket struct address. Redacted due to the same reason as
// the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
@@ -670,9 +688,9 @@ func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error {
if line.prefix == "Tcp" {
tcp := stat.(*inet.StatSNMPTCP)
// "Tcp" needs special processing because MaxConn is signed. RFC 2012.
- fmt.Sprintf("%s: %s %d %s\n", line.prefix, sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:]))
+ fmt.Fprintf(buf, "%s: %s %d %s\n", line.prefix, sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:]))
} else {
- fmt.Sprintf("%s: %s\n", line.prefix, sprintSlice(toSlice(stat)))
+ fmt.Fprintf(buf, "%s: %s\n", line.prefix, sprintSlice(toSlice(stat)))
}
}
return nil
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 882c1981e..4621e2de0 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -63,6 +63,11 @@ func (s *selfSymlink) Readlink(ctx context.Context) (string, error) {
return strconv.FormatUint(uint64(tgid), 10), nil
}
+func (s *selfSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) {
+ target, err := s.Readlink(ctx)
+ return vfs.VirtualDentry{}, target, err
+}
+
// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
func (*selfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
@@ -101,6 +106,11 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) {
return fmt.Sprintf("%d/task/%d", tgid, tid), nil
}
+func (s *threadSelfSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) {
+ target, err := s.Readlink(ctx)
+ return vfs.VirtualDentry{}, target, err
+}
+
// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
return syserror.EPERM
diff --git a/pkg/sentry/fsimpl/sockfs/BUILD b/pkg/sentry/fsimpl/sockfs/BUILD
new file mode 100644
index 000000000..52084ddb5
--- /dev/null
+++ b/pkg/sentry/fsimpl/sockfs/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "sockfs",
+ srcs = ["sockfs.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
new file mode 100644
index 000000000..3f7ad1d65
--- /dev/null
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sockfs provides a filesystem implementation for anonymous sockets.
+package sockfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// NewFilesystem creates a new sockfs filesystem.
+//
+// Note that there should only ever be one instance of sockfs.Filesystem,
+// backing a global socket mount.
+func NewFilesystem(vfsObj *vfs.VirtualFilesystem) *vfs.Filesystem {
+ fs, _, err := filesystemType{}.GetFilesystem(nil, vfsObj, nil, "", vfs.GetFilesystemOptions{})
+ if err != nil {
+ panic("failed to create sockfs filesystem")
+ }
+ return fs
+}
+
+// filesystemType implements vfs.FilesystemType.
+type filesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (fsType filesystemType) GetFilesystem(_ context.Context, vfsObj *vfs.VirtualFilesystem, _ *auth.Credentials, _ string, _ vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ fs := &filesystem{}
+ fs.Init(vfsObj, fsType)
+ return fs.VFSFilesystem(), nil, nil
+}
+
+// Name implements FilesystemType.Name.
+//
+// Note that registering sockfs is unnecessary, except for the fact that it
+// will not show up under /proc/filesystems as a result. This is a very minor
+// discrepancy from Linux.
+func (filesystemType) Name() string {
+ return "sockfs"
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+}
+
+// inode implements kernfs.Inode.
+//
+// TODO(gvisor.dev/issue/1476): Add device numbers to this inode (which are
+// not included in InodeAttrs) to store the numbers of the appropriate
+// socket device. Override InodeAttrs.Stat() accordingly.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return nil, syserror.ENXIO
+}
+
+// InitSocket initializes a socket FileDescription, with a corresponding
+// Dentry in mnt.
+//
+// fd should be the FileDescription associated with socketImpl, i.e. its first
+// field. mnt should be the global socket mount, Kernel.socketMount.
+func InitSocket(socketImpl vfs.FileDescriptionImpl, fd *vfs.FileDescription, mnt *vfs.Mount, creds *auth.Credentials) error {
+ fsimpl := mnt.Filesystem().Impl()
+ fs := fsimpl.(*kernfs.Filesystem)
+
+ // File mode matches net/socket.c:sock_alloc.
+ filemode := linux.FileMode(linux.S_IFSOCK | 0600)
+ i := &inode{}
+ i.Init(creds, fs.NextIno(), filemode)
+
+ d := &kernfs.Dentry{}
+ d.Init(i)
+
+ opts := &vfs.FileDescriptionOptions{UseDentryMetadata: true}
+ if err := fd.Init(socketImpl, linux.O_RDWR, mnt, d.VFSDentry(), opts); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 6ea35affb..4e6cd3491 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -24,6 +24,7 @@ go_library(
"filesystem.go",
"named_pipe.go",
"regular_file.go",
+ "socket_file.go",
"symlink.go",
"tmpfs.go",
],
@@ -50,6 +51,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
"//pkg/sentry/vfs/lock",
+ "//pkg/sentry/vfs/memxattr",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
index 383133e44..651912169 100644
--- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -168,7 +168,7 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) {
}
}
-func BenchmarkVFS2MemfsStat(b *testing.B) {
+func BenchmarkVFS2TmpfsStat(b *testing.B) {
for _, depth := range depths {
b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
ctx := contexttest.Context(b)
@@ -362,7 +362,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
}
}
-func BenchmarkVFS2MemfsMountStat(b *testing.B) {
+func BenchmarkVFS2TmpfsMountStat(b *testing.B) {
for _, depth := range depths {
b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
ctx := contexttest.Context(b)
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index 37c75ab64..45712c9b9 100644
--- a/pkg/sentry/fsimpl/tmpfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -68,6 +68,8 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
fs.mu.Lock()
defer fs.mu.Unlock()
+ fd.inode().touchAtime(fd.vfsfd.Mount())
+
if fd.off == 0 {
if err := cb.Handle(vfs.Dirent{
Name: ".",
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index e678ecc37..452c4e2e0 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -46,6 +46,9 @@ func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) {
return nil, err
}
afterSymlink:
+ if len(rp.Component()) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
nextVFSD, err := rp.ResolveComponent(&d.vfsd)
if err != nil {
return nil, err
@@ -57,7 +60,7 @@ afterSymlink:
}
next := nextVFSD.Impl().(*dentry)
if symlink, ok := next.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
- // TODO(gvisor.dev/issues/1197): Symlink traversals updates
+ // TODO(gvisor.dev/issue/1197): Symlink traversals updates
// access time.
if err := rp.HandleSymlink(symlink.target); err != nil {
return nil, err
@@ -133,6 +136,9 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa
if name == "." || name == ".." {
return syserror.EEXIST
}
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
// Call parent.vfsd.Child() instead of stepLocked() or rp.ResolveChild(),
// because if the child exists we want to return EEXIST immediately instead
// of attempting symlink/mount traversal.
@@ -142,7 +148,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa
if !dir && rp.MustBeDir() {
return syserror.ENOENT
}
- // In memfs, the only way to cause a dentry to be disowned is by removing
+ // In tmpfs, the only way to cause a dentry to be disowned is by removing
// it from the filesystem, so this check is equivalent to checking if
// parent has been removed.
if parent.vfsd.IsDisowned() {
@@ -153,7 +159,11 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa
return err
}
defer mnt.EndWrite()
- return create(parent, name)
+ if err := create(parent, name); err != nil {
+ return err
+ }
+ parent.inode.touchCMtime()
+ return nil
}
// AccessAt implements vfs.Filesystem.Impl.AccessAt.
@@ -251,8 +261,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
case linux.S_IFCHR:
childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor)
case linux.S_IFSOCK:
- // Not yet supported.
- return syserror.EPERM
+ childInode = fs.newSocketFile(rp.Credentials(), opts.Mode, opts.Endpoint)
default:
return syserror.EINVAL
}
@@ -328,7 +337,12 @@ afterTrailingSymlink:
child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode))
parent.vfsd.InsertChild(&child.vfsd, name)
parent.inode.impl.(*directory).childList.PushBack(child)
- return child.open(ctx, rp, &opts, true)
+ fd, err := child.open(ctx, rp, &opts, true)
+ if err != nil {
+ return nil, err
+ }
+ parent.inode.touchCMtime()
+ return fd, nil
}
if err != nil {
return nil, err
@@ -378,9 +392,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
// Can't open symlinks without O_PATH (which is unimplemented).
return nil, syserror.ELOOP
case *namedPipe:
- return newNamedPipeFD(ctx, impl, rp, &d.vfsd, opts.Flags)
+ return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags)
case *deviceFile:
return rp.VirtualFilesystem().OpenDeviceSpecialFile(ctx, rp.Mount(), &d.vfsd, impl.kind, impl.major, impl.minor, opts)
+ case *socketFile:
+ return nil, syserror.ENXIO
default:
panic(fmt.Sprintf("unknown inode type: %T", d.inode.impl))
}
@@ -398,6 +414,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
if !ok {
return "", syserror.EINVAL
}
+ symlink.inode.touchAtime(rp.Mount())
return symlink.target, nil
}
@@ -515,7 +532,10 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
oldParent.inode.decLinksLocked()
newParent.inode.incLinksLocked()
}
- // TODO(gvisor.dev/issues/1197): Update timestamps and parent directory
+ oldParent.inode.touchCMtime()
+ newParent.inode.touchCMtime()
+ renamed.inode.touchCtime()
+ // TODO(gvisor.dev/issue/1197): Update timestamps and parent directory
// sizes.
vfsObj.CommitRenameReplaceDentry(renamedVFSD, &newParent.vfsd, newName, replacedVFSD)
return nil
@@ -565,6 +585,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
parent.inode.decLinksLocked() // from child's ".."
child.inode.decLinksLocked()
vfsObj.CommitDeleteDentry(childVFSD)
+ parent.inode.touchCMtime()
return nil
}
@@ -600,7 +621,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
if err != nil {
return linux.Statfs{}, err
}
- // TODO(gvisor.dev/issues/1197): Actually implement statfs.
+ // TODO(gvisor.dev/issue/1197): Actually implement statfs.
return linux.Statfs{}, syserror.ENOSYS
}
@@ -654,62 +675,68 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
parent.inode.impl.(*directory).childList.Remove(child)
child.inode.decLinksLocked()
vfsObj.CommitDeleteDentry(childVFSD)
+ parent.inode.touchCMtime()
return nil
}
// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
-//
-// TODO(gvisor.dev/issue/1476): Implement BoundEndpointAt.
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
- return nil, syserror.ECONNREFUSED
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(rp)
+ if err != nil {
+ return nil, err
+ }
+ switch impl := d.inode.impl.(type) {
+ case *socketFile:
+ return impl.ep, nil
+ default:
+ return nil, syserror.ECONNREFUSED
+ }
}
// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- _, err := resolveLocked(rp)
+ d, err := resolveLocked(rp)
if err != nil {
return nil, err
}
- // TODO(b/127675828): support extended attributes
- return nil, syserror.ENOTSUP
+ return d.inode.listxattr(size)
}
// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
-func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- _, err := resolveLocked(rp)
+ d, err := resolveLocked(rp)
if err != nil {
return "", err
}
- // TODO(b/127675828): support extended attributes
- return "", syserror.ENOTSUP
+ return d.inode.getxattr(rp.Credentials(), &opts)
}
// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
fs.mu.RLock()
defer fs.mu.RUnlock()
- _, err := resolveLocked(rp)
+ d, err := resolveLocked(rp)
if err != nil {
return err
}
- // TODO(b/127675828): support extended attributes
- return syserror.ENOTSUP
+ return d.inode.setxattr(rp.Credentials(), &opts)
}
// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
fs.mu.RLock()
defer fs.mu.RUnlock()
- _, err := resolveLocked(rp)
+ d, err := resolveLocked(rp)
if err != nil {
return err
}
- // TODO(b/127675828): support extended attributes
- return syserror.ENOTSUP
+ return d.inode.removexattr(rp.Credentials(), name)
}
// PrependPath implements vfs.FilesystemImpl.PrependPath.
diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index 2c5c739df..8d77b3fa8 100644
--- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -16,10 +16,8 @@ package tmpfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -33,27 +31,8 @@ type namedPipe struct {
// * fs.mu must be locked.
// * rp.Mount().CheckBeginWrite() has been called successfully.
func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode {
- file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)}
+ file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)}
file.inode.init(file, fs, creds, linux.S_IFIFO|mode)
file.inode.nlink = 1 // Only the parent has a link.
return &file.inode
}
-
-// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented
-// entirely via struct embedding.
-type namedPipeFD struct {
- fileDescription
-
- *pipe.VFSPipeFD
-}
-
-func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
- var err error
- var fd namedPipeFD
- fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, vfsd, &fd.vfsfd, flags)
- if err != nil {
- return nil, err
- }
- fd.vfsfd.Init(&fd, flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{})
- return &fd.vfsfd, nil
-}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 26cd65605..57e5e28ec 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -286,7 +286,8 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
rw := getRegularFileReadWriter(f, offset)
n, err := dst.CopyOutFrom(ctx, rw)
putRegularFileReadWriter(rw)
- return int64(n), err
+ fd.inode().touchAtime(fd.vfsfd.Mount())
+ return n, err
}
// Read implements vfs.FileDescriptionImpl.Read.
@@ -323,6 +324,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
+ fd.inode().touchCMtimeLocked()
f.inode.mu.Unlock()
putRegularFileReadWriter(rw)
return n, err
diff --git a/pkg/tcpip/stack/packet_buffer_state.go b/pkg/sentry/fsimpl/tmpfs/socket_file.go
index 0c6b7924c..25c2321af 100644
--- a/pkg/tcpip/stack/packet_buffer_state.go
+++ b/pkg/sentry/fsimpl/tmpfs/socket_file.go
@@ -12,16 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package stack
+package tmpfs
-import "gvisor.dev/gvisor/pkg/tcpip/buffer"
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+)
-// beforeSave is invoked by stateify.
-func (pk *PacketBuffer) beforeSave() {
- // Non-Data fields may be slices of the Data field. This causes
- // problems for SR, so during save we make each header independent.
- pk.Header = pk.Header.DeepCopy()
- pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...)
- pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...)
- pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...)
+// socketFile is a socket (=S_IFSOCK) tmpfs file.
+type socketFile struct {
+ inode inode
+ ep transport.BoundEndpoint
+}
+
+func (fs *filesystem) newSocketFile(creds *auth.Credentials, mode linux.FileMode, ep transport.BoundEndpoint) *inode {
+ file := &socketFile{ep: ep}
+ file.inode.init(file, fs, creds, mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go
index ebe035dee..d4f59ee5b 100644
--- a/pkg/sentry/fsimpl/tmpfs/stat_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/stat_test.go
@@ -29,7 +29,7 @@ func TestStatAfterCreate(t *testing.T) {
mode := linux.FileMode(0644)
// Run with different file types.
- // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets.
+ // TODO(gvisor.dev/issue/1197): Also test symlinks and sockets.
for _, typ := range []string{"file", "dir", "pipe"} {
t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
var (
@@ -140,7 +140,7 @@ func TestSetStatAtime(t *testing.T) {
Mask: 0,
Atime: linux.NsecToStatxTimestamp(100),
}}); err != nil {
- t.Errorf("SetStat atime without mask failed: %v")
+ t.Errorf("SetStat atime without mask failed: %v", err)
}
// Atime should be unchanged.
if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
@@ -155,7 +155,7 @@ func TestSetStatAtime(t *testing.T) {
Atime: linux.NsecToStatxTimestamp(100),
}
if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
- t.Errorf("SetStat atime with mask failed: %v")
+ t.Errorf("SetStat atime with mask failed: %v", err)
}
if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
t.Errorf("Stat got error: %v", err)
@@ -169,7 +169,7 @@ func TestSetStat(t *testing.T) {
mode := linux.FileMode(0644)
// Run with different file types.
- // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets.
+ // TODO(gvisor.dev/issue/1197): Also test symlinks and sockets.
for _, typ := range []string{"file", "dir", "pipe"} {
t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
var (
@@ -205,7 +205,7 @@ func TestSetStat(t *testing.T) {
Mask: 0,
Atime: linux.NsecToStatxTimestamp(100),
}}); err != nil {
- t.Errorf("SetStat atime without mask failed: %v")
+ t.Errorf("SetStat atime without mask failed: %v", err)
}
// Atime should be unchanged.
if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
@@ -220,7 +220,7 @@ func TestSetStat(t *testing.T) {
Atime: linux.NsecToStatxTimestamp(100),
}
if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
- t.Errorf("SetStat atime with mask failed: %v")
+ t.Errorf("SetStat atime with mask failed: %v", err)
}
if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
t.Errorf("Stat got error: %v", err)
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index afd9f8533..82c709b43 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -27,6 +27,7 @@ package tmpfs
import (
"fmt"
"math"
+ "strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -37,6 +38,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/vfs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/memxattr"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -77,6 +79,11 @@ type FilesystemOpts struct {
// RootSymlinkTarget is the target of the root symlink. Only valid if
// RootFileType == S_IFLNK.
RootSymlinkTarget string
+
+ // FilesystemType allows setting a different FilesystemType for this
+ // tmpfs filesystem. This allows tmpfs to "impersonate" other
+ // filesystems, like ramdiskfs and cgroupfs.
+ FilesystemType vfs.FilesystemType
}
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
@@ -91,15 +98,22 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
clock: clock,
}
- fs.vfsfs.Init(vfsObj, &fstype, &fs)
-
- typ := uint16(linux.S_IFDIR)
+ rootFileType := uint16(linux.S_IFDIR)
+ newFSType := vfs.FilesystemType(&fstype)
tmpfsOpts, ok := opts.InternalData.(FilesystemOpts)
- if ok && tmpfsOpts.RootFileType != 0 {
- typ = tmpfsOpts.RootFileType
+ if ok {
+ if tmpfsOpts.RootFileType != 0 {
+ rootFileType = tmpfsOpts.RootFileType
+ }
+ if tmpfsOpts.FilesystemType != nil {
+ newFSType = tmpfsOpts.FilesystemType
+ }
}
+
+ fs.vfsfs.Init(vfsObj, newFSType, &fs)
+
var root *inode
- switch typ {
+ switch rootFileType {
case linux.S_IFREG:
root = fs.newRegularFile(creds, 0777)
case linux.S_IFLNK:
@@ -107,7 +121,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
case linux.S_IFDIR:
root = fs.newDirectory(creds, 01777)
default:
- return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", typ)
+ return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType)
}
return &fs.vfsfs, &fs.newDentry(root).vfsd, nil
}
@@ -174,6 +188,11 @@ type inode struct {
// filesystem.RmdirAt() drops the reference.
refs int64
+ // xattrs implements extended attributes.
+ //
+ // TODO(b/148380782): Support xattrs other than user.*
+ xattrs memxattr.SimpleExtendedAttributes
+
// Inode metadata. Writing multiple fields atomically requires holding
// mu, otherwise atomic operations can be used.
mu sync.Mutex
@@ -228,7 +247,7 @@ func (i *inode) incLinksLocked() {
panic("tmpfs.inode.incLinksLocked() called with no existing links")
}
if i.nlink == maxLinks {
- panic("memfs.inode.incLinksLocked() called with maximum link count")
+ panic("tmpfs.inode.incLinksLocked() called with maximum link count")
}
atomic.AddUint32(&i.nlink, 1)
}
@@ -303,7 +322,7 @@ func (i *inode) statTo(stat *linux.Statx) {
stat.Atime = linux.NsecToStatxTimestamp(i.atime)
stat.Ctime = linux.NsecToStatxTimestamp(i.ctime)
stat.Mtime = linux.NsecToStatxTimestamp(i.mtime)
- // TODO(gvisor.dev/issues/1197): Device number.
+ // TODO(gvisor.dev/issue/1197): Device number.
switch impl := i.impl.(type) {
case *regularFile:
stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
@@ -319,7 +338,7 @@ func (i *inode) statTo(stat *linux.Statx) {
case *deviceFile:
stat.RdevMajor = impl.major
stat.RdevMinor = impl.minor
- case *directory, *namedPipe:
+ case *socketFile, *directory, *namedPipe:
// Nothing to do.
default:
panic(fmt.Sprintf("unknown inode type: %T", i.impl))
@@ -338,6 +357,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu
return err
}
i.mu.Lock()
+ defer i.mu.Unlock()
var (
needsMtimeBump bool
needsCtimeBump bool
@@ -373,29 +393,41 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu
return syserror.EINVAL
}
}
+ now := i.clock.Now().Nanoseconds()
if mask&linux.STATX_ATIME != 0 {
- atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped())
+ if stat.Atime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.atime, now)
+ } else {
+ atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped())
+ }
needsCtimeBump = true
}
if mask&linux.STATX_MTIME != 0 {
- atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped())
+ if stat.Mtime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.mtime, now)
+ } else {
+ atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped())
+ }
needsCtimeBump = true
// Ignore the mtime bump, since we just set it ourselves.
needsMtimeBump = false
}
if mask&linux.STATX_CTIME != 0 {
- atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped())
+ if stat.Ctime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.ctime, now)
+ } else {
+ atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped())
+ }
// Ignore the ctime bump, since we just set it ourselves.
needsCtimeBump = false
}
- now := i.clock.Now().Nanoseconds()
if needsMtimeBump {
atomic.StoreInt64(&i.mtime, now)
}
if needsCtimeBump {
atomic.StoreInt64(&i.ctime, now)
}
- i.mu.Unlock()
+
return nil
}
@@ -454,6 +486,8 @@ func (i *inode) direntType() uint8 {
return linux.DT_DIR
case *symlink:
return linux.DT_LNK
+ case *socketFile:
+ return linux.DT_SOCK
case *deviceFile:
switch impl.kind {
case vfs.BlockDevice:
@@ -472,6 +506,92 @@ func (i *inode) isDir() bool {
return linux.FileMode(i.mode).FileType() == linux.S_IFDIR
}
+func (i *inode) touchAtime(mnt *vfs.Mount) {
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return
+ }
+ now := i.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.atime, now)
+ i.mu.Unlock()
+ mnt.EndWrite()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite().
+func (i *inode) touchCtime() {
+ now := i.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.ctime, now)
+ i.mu.Unlock()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite().
+func (i *inode) touchCMtime() {
+ now := i.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.mtime, now)
+ atomic.StoreInt64(&i.ctime, now)
+ i.mu.Unlock()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite() and holds
+// inode.mu.
+func (i *inode) touchCMtimeLocked() {
+ now := i.clock.Now().Nanoseconds()
+ atomic.StoreInt64(&i.mtime, now)
+ atomic.StoreInt64(&i.ctime, now)
+}
+
+func (i *inode) listxattr(size uint64) ([]string, error) {
+ return i.xattrs.Listxattr(size)
+}
+
+func (i *inode) getxattr(creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) {
+ if err := i.checkPermissions(creds, vfs.MayRead); err != nil {
+ return "", err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return "", syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return "", syserror.ENODATA
+ }
+ return i.xattrs.Getxattr(opts)
+}
+
+func (i *inode) setxattr(creds *auth.Credentials, opts *vfs.SetxattrOptions) error {
+ if err := i.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return i.xattrs.Setxattr(opts)
+}
+
+func (i *inode) removexattr(creds *auth.Credentials, name string) error {
+ if err := i.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return i.xattrs.Removexattr(name)
+}
+
+// Extended attributes in the user.* namespace are only supported for regular
+// files and directories.
+func (i *inode) userXattrSupported() bool {
+ filetype := linux.S_IFMT & atomic.LoadUint32(&i.mode)
+ return filetype == linux.S_IFREG || filetype == linux.S_IFDIR
+}
+
// fileDescription is embedded by tmpfs implementations of
// vfs.FileDescriptionImpl.
type fileDescription struct {
@@ -499,3 +619,23 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
creds := auth.CredentialsFromContext(ctx)
return fd.inode().setStat(ctx, creds, &opts.Stat)
}
+
+// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
+func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ return fd.inode().listxattr(size)
+}
+
+// Getxattr implements vfs.FileDescriptionImpl.Getxattr.
+func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) {
+ return fd.inode().getxattr(auth.CredentialsFromContext(ctx), &opts)
+}
+
+// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
+func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
+ return fd.inode().setxattr(auth.CredentialsFromContext(ctx), &opts)
+}
+
+// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
+func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
+ return fd.inode().removexattr(auth.CredentialsFromContext(ctx), name)
+}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index beba29a09..e47af66d6 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -169,6 +169,9 @@ go_library(
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
"//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/fsimpl/pipefs",
+ "//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/hostcpu",
"//pkg/sentry/inet",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index d09d97825..ed40b5303 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -307,6 +307,61 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return fds, nil
}
+// NewFDsVFS2 allocates new FDs guaranteed to be the lowest number available
+// greater than or equal to the fd parameter. All files will share the set
+// flags. Success is guaranteed to be all or none.
+func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return nil, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if fd >= end {
+ return nil, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // Start the search for an available fd no lower than f.next.
+ if fd < f.next {
+ fd = f.next
+ }
+
+ // Install all entries.
+ for i := fd; i < end && len(fds) < len(files); i++ {
+ if d, _, _ := f.getVFS2(i); d == nil {
+ f.setVFS2(i, files[len(fds)], flags) // Set the descriptor.
+ fds = append(fds, i) // Record the file descriptor.
+ }
+ }
+
+ // Failure? Unwind existing FDs.
+ if len(fds) < len(files) {
+ for _, i := range fds {
+ f.setVFS2(i, nil, FDFlags{}) // Zap entry.
+ }
+ return nil, syscall.EMFILE
+ }
+
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fds[len(fds)-1] + 1
+ }
+
+ return fds, nil
+}
+
// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
// the given file description. If it succeeds, it takes a reference on file.
func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 6feda8fa1..fef60e636 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -50,6 +50,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/timerfd"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -225,6 +227,11 @@ type Kernel struct {
// by extMu.
nextSocketEntry uint64
+ // socketMount is a disconnected vfs.Mount, not included in k.vfs,
+ // representing a sockfs.filesystem. socketMount is used to back
+ // VirtualDentries representing anonymous sockets.
+ socketMount *vfs.Mount
+
// deviceRegistry is used to save/restore device.SimpleDevices.
deviceRegistry struct{} `state:".(*device.Registry)"`
@@ -248,6 +255,10 @@ type Kernel struct {
// VFS keeps the filesystem state used across the kernel.
vfs vfs.VirtualFilesystem
+ // pipeMount is the Mount used for pipes created by the pipe() and pipe2()
+ // syscalls (as opposed to named pipes created by mknod()).
+ pipeMount *vfs.Mount
+
// If set to true, report address space activation waits as if the task is in
// external wait so that the watchdog doesn't report the task stuck.
SleepForAddressSpaceActivation bool
@@ -348,6 +359,29 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic}
k.futexes = futex.NewManager()
k.netlinkPorts = port.New()
+
+ if VFS2Enabled {
+ if err := k.vfs.Init(); err != nil {
+ return fmt.Errorf("failed to initialize VFS: %v", err)
+ }
+
+ pipeFilesystem := pipefs.NewFilesystem(&k.vfs)
+ defer pipeFilesystem.DecRef()
+ pipeMount, err := k.vfs.NewDisconnectedMount(pipeFilesystem, nil, &vfs.MountOptions{})
+ if err != nil {
+ return fmt.Errorf("failed to create pipefs mount: %v", err)
+ }
+ k.pipeMount = pipeMount
+
+ socketFilesystem := sockfs.NewFilesystem(&k.vfs)
+ defer socketFilesystem.DecRef()
+ socketMount, err := k.vfs.NewDisconnectedMount(socketFilesystem, nil, &vfs.MountOptions{})
+ if err != nil {
+ return fmt.Errorf("failed to initialize socket mount: %v", err)
+ }
+ k.socketMount = socketMount
+ }
+
return nil
}
@@ -545,15 +579,25 @@ func (ts *TaskSet) unregisterEpollWaiters() {
ts.mu.RLock()
defer ts.mu.RUnlock()
+
+ // Tasks that belong to the same process could potentially point to the
+ // same FDTable. So we retain a map of processed ones to avoid
+ // processing the same FDTable multiple times.
+ processed := make(map[*FDTable]struct{})
for t := range ts.Root.tids {
// We can skip locking Task.mu here since the kernel is paused.
- if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
- if e, ok := file.FileOperations.(*epoll.EventPoll); ok {
- e.UnregisterEpollWaiters()
- }
- })
+ if t.fdTable == nil {
+ continue
}
+ if _, ok := processed[t.fdTable]; ok {
+ continue
+ }
+ t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ if e, ok := file.FileOperations.(*epoll.EventPoll); ok {
+ e.UnregisterEpollWaiters()
+ }
+ })
+ processed[t.fdTable] = struct{}{}
}
}
@@ -1015,14 +1059,17 @@ func (k *Kernel) pauseTimeLocked() {
// This means we'll iterate FDTables shared by multiple tasks repeatedly,
// but ktime.Timer.Pause is idempotent so this is harmless.
if t.fdTable != nil {
- // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
- if !VFS2Enabled {
- t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
+ if VFS2Enabled {
+ if tfd, ok := fd.Impl().(*vfs.TimerFileDescription); ok {
+ tfd.PauseTimer()
+ }
+ } else {
if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
tfd.PauseTimer()
}
- })
- }
+ }
+ })
}
}
k.timekeeper.PauseUpdates()
@@ -1047,15 +1094,18 @@ func (k *Kernel) resumeTimeLocked() {
it.ResumeTimer()
}
}
- // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
- if !VFS2Enabled {
- if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ if t.fdTable != nil {
+ t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
+ if VFS2Enabled {
+ if tfd, ok := fd.Impl().(*vfs.TimerFileDescription); ok {
+ tfd.ResumeTimer()
+ }
+ } else {
if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
tfd.ResumeTimer()
}
- })
- }
+ }
+ })
}
}
}
@@ -1416,9 +1466,10 @@ func (k *Kernel) SupervisorContext() context.Context {
// +stateify savable
type SocketEntry struct {
socketEntry
- k *Kernel
- Sock *refs.WeakRef
- ID uint64 // Socket table entry number.
+ k *Kernel
+ Sock *refs.WeakRef
+ SockVFS2 *vfs.FileDescription
+ ID uint64 // Socket table entry number.
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
@@ -1441,7 +1492,30 @@ func (k *Kernel) RecordSocket(sock *fs.File) {
k.extMu.Unlock()
}
+// RecordSocketVFS2 adds a VFS2 socket to the system-wide socket table for
+// tracking.
+//
+// Precondition: Caller must hold a reference to sock.
+//
+// Note that the socket table will not hold a reference on the
+// vfs.FileDescription, because we do not support weak refs on VFS2 files.
+func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) {
+ k.extMu.Lock()
+ id := k.nextSocketEntry
+ k.nextSocketEntry++
+ s := &SocketEntry{
+ k: k,
+ ID: id,
+ SockVFS2: sock,
+ }
+ k.sockets.PushBack(s)
+ k.extMu.Unlock()
+}
+
// ListSockets returns a snapshot of all sockets.
+//
+// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef()
+// to get a reference on a socket in the table.
func (k *Kernel) ListSockets() []*SocketEntry {
k.extMu.Lock()
var socks []*SocketEntry
@@ -1452,6 +1526,11 @@ func (k *Kernel) ListSockets() []*SocketEntry {
return socks
}
+// SocketMount returns the global socket mount.
+func (k *Kernel) SocketMount() *vfs.Mount {
+ return k.socketMount
+}
+
// supervisorContext is a privileged context.
type supervisorContext struct {
context.NoopSleeper
@@ -1549,3 +1628,8 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
func (k *Kernel) VFS() *vfs.VirtualFilesystem {
return &k.vfs
}
+
+// PipeMount returns the pipefs mount.
+func (k *Kernel) PipeMount() *vfs.Mount {
+ return k.pipeMount
+}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 725e9db7d..62c8691f1 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -255,7 +255,8 @@ func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
// POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
// atomic, but requires no atomicity for writes larger than this.
wanted := ops.left()
- if avail := p.max - p.view.Size(); wanted > avail {
+ avail := p.max - p.view.Size()
+ if wanted > avail {
if wanted <= p.atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
@@ -268,8 +269,14 @@ func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
return done, err
}
- if wanted > done {
- // Partial write due to full pipe.
+ if done < avail {
+ // Non-failure, but short write.
+ return done, nil
+ }
+ if done < wanted {
+ // Partial write due to full pipe. Note that this could also be
+ // the short write case above; in that case we would expect a
+ // second call and the write to return zero bytes.
return done, syserror.ErrWouldBlock
}
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index a5675bd70..b54f08a30 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -49,38 +49,42 @@ type VFSPipe struct {
}
// NewVFSPipe returns an initialized VFSPipe.
-func NewVFSPipe(sizeBytes, atomicIOBytes int64) *VFSPipe {
+func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe {
var vp VFSPipe
- initPipe(&vp.pipe, true /* isNamed */, sizeBytes, atomicIOBytes)
+ initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes)
return &vp
}
-// NewVFSPipeFD opens a named pipe. Named pipes have special blocking semantics
-// during open:
+// ReaderWriterPair returns read-only and write-only FDs for vp.
//
-// "Normally, opening the FIFO blocks until the other end is opened also. A
-// process can open a FIFO in nonblocking mode. In this case, opening for
-// read-only will succeed even if no-one has opened on the write side yet,
-// opening for write-only will fail with ENXIO (no such device or address)
-// unless the other end has already been opened. Under Linux, opening a FIFO
-// for read and write will succeed both in blocking and nonblocking mode. POSIX
-// leaves this behavior undefined. This can be used to open a FIFO for writing
-// while there are no readers available." - fifo(7)
-func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
+// Preconditions: statusFlags should not contain an open access mode.
+func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+ return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags)
+}
+
+// Open opens the pipe represented by vp.
+func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, error) {
vp.mu.Lock()
defer vp.mu.Unlock()
- readable := vfs.MayReadFileWithOpenFlags(flags)
- writable := vfs.MayWriteFileWithOpenFlags(flags)
+ readable := vfs.MayReadFileWithOpenFlags(statusFlags)
+ writable := vfs.MayWriteFileWithOpenFlags(statusFlags)
if !readable && !writable {
return nil, syserror.EINVAL
}
- vfd, err := vp.open(vfsd, vfsfd, flags)
- if err != nil {
- return nil, err
- }
+ fd := vp.newFD(mnt, vfsd, statusFlags)
+ // Named pipes have special blocking semantics during open:
+ //
+ // "Normally, opening the FIFO blocks until the other end is opened also. A
+ // process can open a FIFO in nonblocking mode. In this case, opening for
+ // read-only will succeed even if no-one has opened on the write side yet,
+ // opening for write-only will fail with ENXIO (no such device or address)
+ // unless the other end has already been opened. Under Linux, opening a
+ // FIFO for read and write will succeed both in blocking and nonblocking
+ // mode. POSIX leaves this behavior undefined. This can be used to open a
+ // FIFO for writing while there are no readers available." - fifo(7)
switch {
case readable && writable:
// Pipes opened for read-write always succeed without blocking.
@@ -89,23 +93,26 @@ func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, vfsd *vfs.Dentry, vfsfd *vf
case readable:
newHandleLocked(&vp.rWakeup)
- // If this pipe is being opened as nonblocking and there's no
+ // If this pipe is being opened as blocking and there's no
// writer, we have to wait for a writer to open the other end.
- if flags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
+ if vp.pipe.isNamed && statusFlags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
+ fd.DecRef()
return nil, syserror.EINTR
}
case writable:
newHandleLocked(&vp.wWakeup)
- if !vp.pipe.HasReaders() {
- // Nonblocking, write-only opens fail with ENXIO when
- // the read side isn't open yet.
- if flags&linux.O_NONBLOCK != 0 {
+ if vp.pipe.isNamed && !vp.pipe.HasReaders() {
+ // Non-blocking, write-only opens fail with ENXIO when the read
+ // side isn't open yet.
+ if statusFlags&linux.O_NONBLOCK != 0 {
+ fd.DecRef()
return nil, syserror.ENXIO
}
// Wait for a reader to open the other end.
if !waitFor(&vp.mu, &vp.rWakeup, ctx) {
+ fd.DecRef()
return nil, syserror.EINTR
}
}
@@ -114,96 +121,93 @@ func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, vfsd *vfs.Dentry, vfsfd *vf
panic("invalid pipe flags: must be readable, writable, or both")
}
- return vfd, nil
+ return fd, nil
}
// Preconditions: vp.mu must be held.
-func (vp *VFSPipe) open(vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
- var fd VFSPipeFD
- fd.flags = flags
- fd.readable = vfs.MayReadFileWithOpenFlags(flags)
- fd.writable = vfs.MayWriteFileWithOpenFlags(flags)
- fd.vfsfd = vfsfd
- fd.pipe = &vp.pipe
+func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) *vfs.FileDescription {
+ fd := &VFSPipeFD{
+ pipe: &vp.pipe,
+ }
+ fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ })
switch {
- case fd.readable && fd.writable:
+ case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable():
vp.pipe.rOpen()
vp.pipe.wOpen()
- case fd.readable:
+ case fd.vfsfd.IsReadable():
vp.pipe.rOpen()
- case fd.writable:
+ case fd.vfsfd.IsWritable():
vp.pipe.wOpen()
default:
panic("invalid pipe flags: must be readable, writable, or both")
}
- return &fd, nil
+ return &fd.vfsfd
}
-// VFSPipeFD implements a subset of vfs.FileDescriptionImpl for pipes. It is
-// expected that filesystesm will use this in a struct implementing
-// vfs.FileDescriptionImpl.
+// VFSPipeFD implements vfs.FileDescriptionImpl for pipes.
type VFSPipeFD struct {
- pipe *Pipe
- flags uint32
- readable bool
- writable bool
- vfsfd *vfs.FileDescription
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+
+ pipe *Pipe
}
// Release implements vfs.FileDescriptionImpl.Release.
func (fd *VFSPipeFD) Release() {
var event waiter.EventMask
- if fd.readable {
+ if fd.vfsfd.IsReadable() {
fd.pipe.rClose()
- event |= waiter.EventIn
+ event |= waiter.EventOut
}
- if fd.writable {
+ if fd.vfsfd.IsWritable() {
fd.pipe.wClose()
- event |= waiter.EventOut
+ event |= waiter.EventIn | waiter.EventHUp
}
if event == 0 {
panic("invalid pipe flags: must be readable, writable, or both")
}
- if fd.writable {
- fd.vfsfd.VirtualDentry().Mount().EndWrite()
- }
-
fd.pipe.Notify(event)
}
-// OnClose implements vfs.FileDescriptionImpl.OnClose.
-func (fd *VFSPipeFD) OnClose(_ context.Context) error {
- return nil
+// Readiness implements waiter.Waitable.Readiness.
+func (fd *VFSPipeFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ switch {
+ case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable():
+ return fd.pipe.rwReadiness()
+ case fd.vfsfd.IsReadable():
+ return fd.pipe.rReadiness()
+ case fd.vfsfd.IsWritable():
+ return fd.pipe.wReadiness()
+ default:
+ panic("pipe FD is neither readable nor writable")
+ }
}
-// PRead implements vfs.FileDescriptionImpl.PRead.
-func (fd *VFSPipeFD) PRead(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.ReadOptions) (int64, error) {
- return 0, syserror.ESPIPE
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *VFSPipeFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.pipe.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *VFSPipeFD) EventUnregister(e *waiter.Entry) {
+ fd.pipe.EventUnregister(e)
}
// Read implements vfs.FileDescriptionImpl.Read.
func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
- if !fd.readable {
- return 0, syserror.EINVAL
- }
-
return fd.pipe.Read(ctx, dst)
}
-// PWrite implements vfs.FileDescriptionImpl.PWrite.
-func (fd *VFSPipeFD) PWrite(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.WriteOptions) (int64, error) {
- return 0, syserror.ESPIPE
-}
-
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
- if !fd.writable {
- return 0, syserror.EINVAL
- }
-
return fd.pipe.Write(ctx, src)
}
@@ -211,3 +215,17 @@ func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.Wr
func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
return fd.pipe.Ioctl(ctx, uio, args)
}
+
+// PipeSize implements fcntl(F_GETPIPE_SZ).
+func (fd *VFSPipeFD) PipeSize() int64 {
+ // Inline Pipe.FifoSize() rather than calling it with nil Context and
+ // fs.File and ignoring the returned error (which is always nil).
+ fd.pipe.mu.Lock()
+ defer fd.pipe.mu.Unlock()
+ return fd.pipe.max
+}
+
+// SetPipeSize implements fcntl(F_SETPIPE_SZ).
+func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) {
+ return fd.pipe.SetFifoSize(size)
+}
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 35ad97d5d..e23e796ef 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -184,7 +184,6 @@ func (t *Task) CanTrace(target *Task, attach bool) bool {
if targetCreds.PermittedCaps&^callerCreds.PermittedCaps != 0 {
return false
}
- // TODO: Yama LSM
return true
}
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index ded95f532..18416643b 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -304,7 +304,7 @@ func (t *Task) rseqAddrInterrupt() {
}
var cs linux.RSeqCriticalSection
- if err := cs.CopyIn(t, critAddr); err != nil {
+ if _, err := cs.CopyIn(t, critAddr); err != nil {
t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index 208569057..f66cfcc7f 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -461,7 +461,7 @@ func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.A
func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) {
s.mu.Lock()
defer s.mu.Unlock()
- // TODO(b/38173783): RemoveMapping may be called during task exit, when ctx
+ // RemoveMapping may be called during task exit, when ctx
// is context.Background. Gracefully handle missing clocks. Failing to
// update the detach time in these cases is ok, since no one can observe the
// omission.
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
index 93c4fe969..84156d5a1 100644
--- a/pkg/sentry/kernel/syscalls.go
+++ b/pkg/sentry/kernel/syscalls.go
@@ -209,65 +209,61 @@ type Stracer interface {
// SyscallEnter is called on syscall entry.
//
// The returned private data is passed to SyscallExit.
- //
- // TODO(gvisor.dev/issue/155): remove kernel imports from the strace
- // package so that the type can be used directly.
SyscallEnter(t *Task, sysno uintptr, args arch.SyscallArguments, flags uint32) interface{}
// SyscallExit is called on syscall exit.
SyscallExit(context interface{}, t *Task, sysno, rval uintptr, err error)
}
-// SyscallTable is a lookup table of system calls. Critically, a SyscallTable
-// is *immutable*. In order to make supporting suspend and resume sane, they
-// must be uniquely registered and may not change during operation.
+// SyscallTable is a lookup table of system calls.
//
-// +stateify savable
+// Note that a SyscallTable is not savable directly. Instead, it is saved as
+// an OS/Arch pair and lookup happens again on restore.
type SyscallTable struct {
// OS is the operating system that this syscall table implements.
- OS abi.OS `state:"wait"`
+ OS abi.OS
// Arch is the architecture that this syscall table targets.
- Arch arch.Arch `state:"wait"`
+ Arch arch.Arch
// The OS version that this syscall table implements.
- Version Version `state:"manual"`
+ Version Version
// AuditNumber is a numeric constant that represents the syscall table. If
// non-zero, auditNumber must be one of the AUDIT_ARCH_* values defined by
// linux/audit.h.
- AuditNumber uint32 `state:"manual"`
+ AuditNumber uint32
// Table is the collection of functions.
- Table map[uintptr]Syscall `state:"manual"`
+ Table map[uintptr]Syscall
// lookup is a fixed-size array that holds the syscalls (indexed by
// their numbers). It is used for fast look ups.
- lookup []SyscallFn `state:"manual"`
+ lookup []SyscallFn
// Emulate is a collection of instruction addresses to emulate. The
// keys are addresses, and the values are system call numbers.
- Emulate map[usermem.Addr]uintptr `state:"manual"`
+ Emulate map[usermem.Addr]uintptr
// The function to call in case of a missing system call.
- Missing MissingFn `state:"manual"`
+ Missing MissingFn
// Stracer traces this syscall table.
- Stracer Stracer `state:"manual"`
+ Stracer Stracer
// External is used to handle an external callback.
- External func(*Kernel) `state:"manual"`
+ External func(*Kernel)
// ExternalFilterBefore is called before the pre-syscall invocation of External.
// External is not called if it returns false.
- ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"`
+ ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool
// ExternalFilterAfter is called before the post-syscall invocation of External.
// External is not called if it returns false.
- ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"`
+ ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool
// FeatureEnable stores the strace and one-shot enable bits.
- FeatureEnable SyscallFlagsTable `state:"manual"`
+ FeatureEnable SyscallFlagsTable
}
// allSyscallTables contains all known tables.
@@ -330,6 +326,13 @@ func RegisterSyscallTable(s *SyscallTable) {
allSyscallTables = append(allSyscallTables, s)
}
+// FlushSyscallTablesTestOnly flushes the syscall tables for tests. Used for
+// parameterized VFSv2 tests.
+// TODO(gvisor.dev/issue/1624): Remove when VFS1 is no longer supported.
+func FlushSyscallTablesTestOnly() {
+ allSyscallTables = nil
+}
+
// Lookup returns the syscall implementation, if one exists.
func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn {
if sysno < uintptr(len(s.lookup)) {
diff --git a/pkg/sentry/kernel/syscalls_state.go b/pkg/sentry/kernel/syscalls_state.go
index 00358326b..90f890495 100644
--- a/pkg/sentry/kernel/syscalls_state.go
+++ b/pkg/sentry/kernel/syscalls_state.go
@@ -14,16 +14,34 @@
package kernel
-import "fmt"
+import (
+ "fmt"
-// afterLoad is invoked by stateify.
-func (s *SyscallTable) afterLoad() {
- otherTable, ok := LookupSyscallTable(s.OS, s.Arch)
- if !ok {
- // Couldn't find a reference?
- panic(fmt.Sprintf("syscall table not found for OS %v Arch %v", s.OS, s.Arch))
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// syscallTableInfo is used to reload the SyscallTable.
+//
+// +stateify savable
+type syscallTableInfo struct {
+ OS abi.OS
+ Arch arch.Arch
+}
+
+// saveSt saves the SyscallTable.
+func (tc *TaskContext) saveSt() syscallTableInfo {
+ return syscallTableInfo{
+ OS: tc.st.OS,
+ Arch: tc.st.Arch,
}
+}
- // Copy the table.
- *s = *otherTable
+// loadSt loads the SyscallTable.
+func (tc *TaskContext) loadSt(sti syscallTableInfo) {
+ st, ok := LookupSyscallTable(sti.OS, sti.Arch)
+ if !ok {
+ panic(fmt.Sprintf("syscall table not found for OS %v, Arch %v", sti.OS, sti.Arch))
+ }
+ tc.st = st // Save the table reference.
}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index d6546735e..e5d133d6c 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -777,6 +777,15 @@ func (t *Task) NewFDs(fd int32, files []*fs.File, flags FDFlags) ([]int32, error
return t.fdTable.NewFDs(t, fd, files, flags)
}
+// NewFDsVFS2 is a convenience wrapper for t.FDTable().NewFDsVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.NewFDsVFS2.
+func (t *Task) NewFDsVFS2(fd int32, files []*vfs.FileDescription, flags FDFlags) ([]int32, error) {
+ return t.fdTable.NewFDsVFS2(t, fd, files, flags)
+}
+
// NewFDFrom is a convenience wrapper for t.FDTable().NewFDs with a single file.
//
// This automatically passes the task as the context.
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 0158b1788..9fa528384 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -49,7 +49,7 @@ type TaskContext struct {
fu *futex.Manager
// st is the task's syscall table.
- st *SyscallTable
+ st *SyscallTable `state:".(syscallTableInfo)"`
}
// release releases all resources held by the TaskContext. release is called by
@@ -58,7 +58,6 @@ func (tc *TaskContext) release() {
// Nil out pointers so that if the task is saved after release, it doesn't
// follow the pointers to possibly now-invalid objects.
if tc.MemoryManager != nil {
- // TODO(b/38173783)
tc.MemoryManager.DecUsers(context.Background())
tc.MemoryManager = nil
}
diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go
index ce3e6ef28..0325967e4 100644
--- a/pkg/sentry/kernel/task_identity.go
+++ b/pkg/sentry/kernel/task_identity.go
@@ -455,7 +455,7 @@ func (t *Task) SetKeepCaps(k bool) {
t.creds.Store(creds)
}
-// updateCredsForExec updates t.creds to reflect an execve().
+// updateCredsForExecLocked updates t.creds to reflect an execve().
//
// NOTE(b/30815691): We currently do not implement privileged executables
// (set-user/group-ID bits and file capabilities). This allows us to make a lot
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index 799cbcd93..2ba8d7e63 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -353,7 +353,7 @@ func (app *runApp) execute(t *Task) taskRunState {
default:
// What happened? Can't continue.
t.Warningf("Unexpected SwitchToApp error: %v", err)
- t.PrepareExit(ExitStatus{Code: t.ExtractErrno(err, -1)})
+ t.PrepareExit(ExitStatus{Code: ExtractErrno(err, -1)})
return (*runExit)(nil)
}
}
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 8802db142..f07de2089 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -174,7 +174,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS
fallthrough
case (sre == ERESTARTSYS && !act.IsRestart()):
t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
- t.Arch().SetReturn(uintptr(-t.ExtractErrno(syserror.EINTR, -1)))
+ t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1)))
default:
t.Debugf("Restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
t.Arch().RestartSyscall()
@@ -513,8 +513,6 @@ func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool {
if t.stop != nil {
return false
}
- // - TODO(b/38173783): No special case for when t is also the sending task,
- // because the identity of the sender is unknown.
// - Do not choose tasks that have already been interrupted, as they may be
// busy handling another signal.
if len(t.interruptChan) != 0 {
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index d555d69a8..c9db78e06 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -194,6 +194,19 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
//
// The syscall path is very hot; avoid defer.
func (t *Task) doSyscall() taskRunState {
+ // Save the value of the register that is clobbered by the following
+ // t.Arch().SetReturn(-ENOSYS) operation. This is needed only on arm64.
+ //
+ // On x86, register rax is shared by the syscall number and the return
+ // value, and at the entry of the syscall handler, rax is saved to
+ // regs.orig_rax, which is exposed to user space.
+ // On arm64, however, the syscall number is passed through X8, and X0
+ // is shared by the first syscall argument and the return value. X0
+ // is saved to regs.orig_x0, which is not exposed to user space.
+ // So we have to perform the same operation here to save the X0 value
+ // into the task context.
+ t.Arch().SyscallSaveOrig()
+
sysno := t.Arch().SyscallNo()
args := t.Arch().SyscallArgs()
@@ -269,6 +282,7 @@ func (*runSyscallAfterSyscallEnterStop) execute(t *Task) taskRunState {
return (*runSyscallExit)(nil)
}
args := t.Arch().SyscallArgs()
+
return t.doSyscallInvoke(sysno, args)
}
@@ -298,7 +312,7 @@ func (t *Task) doSyscallInvoke(sysno uintptr, args arch.SyscallArguments) taskRu
return ctrl.next
}
} else if err != nil {
- t.Arch().SetReturn(uintptr(-t.ExtractErrno(err, int(sysno))))
+ t.Arch().SetReturn(uintptr(-ExtractErrno(err, int(sysno))))
t.haveSyscallReturn = true
} else {
t.Arch().SetReturn(rval)
@@ -417,7 +431,7 @@ func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, calle
// A return is not emulated in this case.
return (*runApp)(nil)
}
- t.Arch().SetReturn(uintptr(-t.ExtractErrno(err, int(sysno))))
+ t.Arch().SetReturn(uintptr(-ExtractErrno(err, int(sysno))))
}
t.Arch().SetIP(t.Arch().Value(caller))
t.Arch().SetStack(t.Arch().Stack() + uintptr(t.Arch().Width()))
@@ -427,7 +441,7 @@ func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, calle
// ExtractErrno extracts an integer error number from the error.
// The syscall number is purely for context in the error case. Use -1 if
// syscall number is unknown.
-func (t *Task) ExtractErrno(err error, sysno int) int {
+func ExtractErrno(err error, sysno int) int {
switch err := err.(type) {
case nil:
return 0
@@ -441,11 +455,11 @@ func (t *Task) ExtractErrno(err error, sysno int) int {
// handled (and the SIGBUS is delivered).
return int(syscall.EFAULT)
case *os.PathError:
- return t.ExtractErrno(err.Err, sysno)
+ return ExtractErrno(err.Err, sysno)
case *os.LinkError:
- return t.ExtractErrno(err.Err, sysno)
+ return ExtractErrno(err.Err, sysno)
case *os.SyscallError:
- return t.ExtractErrno(err.Err, sysno)
+ return ExtractErrno(err.Err, sysno)
default:
if errno, ok := syserror.TranslateError(err); ok {
return int(errno)
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
index 706de83ef..e959700f2 100644
--- a/pkg/sentry/kernel/time/time.go
+++ b/pkg/sentry/kernel/time/time.go
@@ -245,7 +245,7 @@ type Clock interface {
type WallRateClock struct{}
// WallTimeUntil implements Clock.WallTimeUntil.
-func (WallRateClock) WallTimeUntil(t, now Time) time.Duration {
+func (*WallRateClock) WallTimeUntil(t, now Time) time.Duration {
return t.Sub(now)
}
@@ -254,16 +254,16 @@ func (WallRateClock) WallTimeUntil(t, now Time) time.Duration {
type NoClockEvents struct{}
// Readiness implements waiter.Waitable.Readiness.
-func (NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (*NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask {
return 0
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (*NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (NoClockEvents) EventUnregister(e *waiter.Entry) {
+func (*NoClockEvents) EventUnregister(e *waiter.Entry) {
}
// ClockEventsQueue implements waiter.Waitable by wrapping waiter.Queue and
@@ -273,7 +273,7 @@ type ClockEventsQueue struct {
}
// Readiness implements waiter.Waitable.Readiness.
-func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (*ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
return 0
}
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
index 0332fc71c..5c667117c 100644
--- a/pkg/sentry/mm/address_space.go
+++ b/pkg/sentry/mm/address_space.go
@@ -201,8 +201,10 @@ func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, pre
if pma.needCOW {
perms.Write = false
}
- if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil {
- return err
+ if perms.Any() { // MapFile precondition
+ if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil {
+ return err
+ }
}
pseg = pseg.NextSegment()
}
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index cb29d94b0..379148903 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -59,25 +59,27 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool {
}
a.contexts[id] = &AIOContext{
- done: make(chan struct{}, 1),
+ requestReady: make(chan struct{}, 1),
maxOutstanding: events,
}
return true
}
-// destroyAIOContext destroys an asynchronous I/O context.
+// destroyAIOContext destroys an asynchronous I/O context. It doesn't wait for
+// pending requests to complete. Returns the destroyed AIOContext so it can be
+// drained.
//
-// False is returned if the context does not exist.
-func (a *aioManager) destroyAIOContext(id uint64) bool {
+// Nil is returned if the context does not exist.
+func (a *aioManager) destroyAIOContext(id uint64) *AIOContext {
a.mu.Lock()
defer a.mu.Unlock()
ctx, ok := a.contexts[id]
if !ok {
- return false
+ return nil
}
delete(a.contexts, id)
ctx.destroy()
- return true
+ return ctx
}
// lookupAIOContext looks up the given context.
@@ -102,8 +104,8 @@ type ioResult struct {
//
// +stateify savable
type AIOContext struct {
- // done is the notification channel used for all requests.
- done chan struct{} `state:"nosave"`
+ // requestReady is the notification channel used for all requests.
+ requestReady chan struct{} `state:"nosave"`
// mu protects below.
mu sync.Mutex `state:"nosave"`
@@ -129,8 +131,14 @@ func (ctx *AIOContext) destroy() {
ctx.mu.Lock()
defer ctx.mu.Unlock()
ctx.dead = true
- if ctx.outstanding == 0 {
- close(ctx.done)
+ ctx.checkForDone()
+}
+
+// Preconditions: ctx.mu must be held by caller.
+func (ctx *AIOContext) checkForDone() {
+ if ctx.dead && ctx.outstanding == 0 {
+ close(ctx.requestReady)
+ ctx.requestReady = nil
}
}
@@ -154,11 +162,12 @@ func (ctx *AIOContext) PopRequest() (interface{}, bool) {
// Is there anything ready?
if e := ctx.results.Front(); e != nil {
- ctx.results.Remove(e)
- ctx.outstanding--
- if ctx.outstanding == 0 && ctx.dead {
- close(ctx.done)
+ if ctx.outstanding == 0 {
+ panic("AIOContext outstanding is going negative")
}
+ ctx.outstanding--
+ ctx.results.Remove(e)
+ ctx.checkForDone()
return e.data, true
}
return nil, false
@@ -172,26 +181,58 @@ func (ctx *AIOContext) FinishRequest(data interface{}) {
// Push to the list and notify opportunistically. The channel notify
// here is guaranteed to be safe because outstanding must be non-zero.
- // The done channel is only closed when outstanding reaches zero.
+ // The requestReady channel is only closed when outstanding reaches zero.
ctx.results.PushBack(&ioResult{data: data})
select {
- case ctx.done <- struct{}{}:
+ case ctx.requestReady <- struct{}{}:
default:
}
}
// WaitChannel returns a channel that is notified when an AIO request is
-// completed.
-//
-// The boolean return value indicates whether or not the context is active.
-func (ctx *AIOContext) WaitChannel() (chan struct{}, bool) {
+// completed. Returns nil if the context is destroyed and there are no more
+// outstanding requests.
+func (ctx *AIOContext) WaitChannel() chan struct{} {
ctx.mu.Lock()
defer ctx.mu.Unlock()
- if ctx.outstanding == 0 && ctx.dead {
- return nil, false
+ return ctx.requestReady
+}
+
+// Dead returns true if the context has been destroyed.
+func (ctx *AIOContext) Dead() bool {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+ return ctx.dead
+}
+
+// CancelPendingRequest forgets about a request that hasn't yet completed.
+func (ctx *AIOContext) CancelPendingRequest() {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ if ctx.outstanding == 0 {
+ panic("AIOContext outstanding is going negative")
}
- return ctx.done, true
+ ctx.outstanding--
+ ctx.checkForDone()
+}
+
+// Drain drops all completed requests. Pending requests remain untouched.
+func (ctx *AIOContext) Drain() {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ if ctx.outstanding == 0 {
+ return
+ }
+ size := uint32(ctx.results.Len())
+ if ctx.outstanding < size {
+ panic("AIOContext outstanding is going negative")
+ }
+ ctx.outstanding -= size
+ ctx.results.Reset()
+ ctx.checkForDone()
}
// aioMappable implements memmap.MappingIdentity and memmap.Mappable for AIO
@@ -332,9 +373,9 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
Length: aioRingBufferSize,
MappingIdentity: m,
Mappable: m,
- // TODO(fvoznika): Linux does "do_mmap_pgoff(..., PROT_READ |
- // PROT_WRITE, ...)" in fs/aio.c:aio_setup_ring(); why do we make this
- // mapping read-only?
+ // Linux uses "do_mmap_pgoff(..., PROT_READ | PROT_WRITE, ...)" in
+ // fs/aio.c:aio_setup_ring(). Since we don't implement AIO_RING_MAGIC,
+ // user mode should not write to this page.
Perms: usermem.Read,
MaxPerms: usermem.Read,
})
@@ -349,11 +390,11 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
return id, nil
}
-// DestroyAIOContext destroys an asynchronous I/O context. It returns false if
-// the context does not exist.
-func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) bool {
+// DestroyAIOContext destroys an asynchronous I/O context. It returns the
+// destroyed context, or nil if the context does not exist.
+func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext {
if _, ok := mm.LookupAIOContext(ctx, id); !ok {
- return false
+ return nil
}
// Only unmaps after it assured that the address is a valid aio context to
diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go
index c37fc9f7b..3dabac1af 100644
--- a/pkg/sentry/mm/aio_context_state.go
+++ b/pkg/sentry/mm/aio_context_state.go
@@ -16,5 +16,5 @@ package mm
// afterLoad is invoked by stateify.
func (a *AIOContext) afterLoad() {
- a.done = make(chan struct{}, 1)
+ a.requestReady = make(chan struct{}, 1)
}
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index d8a5b9d29..aac56679b 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -84,6 +84,7 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
dumpability: mm.dumpability,
aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
sleepForActivation: mm.sleepForActivation,
+ vdsoSigReturnAddr: mm.vdsoSigReturnAddr,
}
// Copy vmas.
diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go
index 6a49334f4..28e5057f7 100644
--- a/pkg/sentry/mm/metadata.go
+++ b/pkg/sentry/mm/metadata.go
@@ -167,3 +167,17 @@ func (mm *MemoryManager) SetExecutable(file fsbridge.File) {
orig.DecRef()
}
}
+
+// VDSOSigReturn returns the address of vdso_sigreturn.
+func (mm *MemoryManager) VDSOSigReturn() uint64 {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.vdsoSigReturnAddr
+}
+
+// SetVDSOSigReturn sets the address of vdso_sigreturn.
+func (mm *MemoryManager) SetVDSOSigReturn(addr uint64) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.vdsoSigReturnAddr = addr
+}
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index c2195ae11..34d3bde7a 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -231,6 +231,9 @@ type MemoryManager struct {
// before trying to activate the address space. When set to true, delays in
// activation are not reported as stuck tasks by the watchdog.
sleepForActivation bool
+
+ // vdsoSigReturnAddr is the address of 'vdso_sigreturn'.
+ vdsoSigReturnAddr uint64
}
// vma represents a virtual memory area.
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index e27f57536..159f7eafd 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -70,7 +70,6 @@ go_test(
"requires-kvm",
],
deps = [
- "//pkg/procid",
"//pkg/sentry/arch",
"//pkg/sentry/platform",
"//pkg/sentry/platform/kvm/testutil",
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index 555b5fa96..4b23f7803 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -46,14 +46,6 @@ var (
// bounceSignalMask has only bounceSignal set.
bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1))
- // otherSignalsMask includes all other signals that will be cause the
- // vCPU to exit during execution.
- //
- // Currently, this includes the preemption signal and the profiling
- // signal. In general, these should be signals whose delivery actually
- // influences the way the program executes as the switch can be costly.
- otherSignalsMask = uint64(1<<(uint64(syscall.SIGURG)-1)) | uint64(1<<(uint64(syscall.SIGPROF)-1))
-
// bounce is the interrupt vector used to return to the kernel.
bounce = uint32(ring0.VirtualizationException)
@@ -94,8 +86,8 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
}
func init() {
- // Install the handler, masking all signals.
- if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler, ^uint64(0)); err != nil {
+ // Install the handler.
+ if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 4e9d80765..9add7c944 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -24,7 +24,6 @@ import (
"syscall"
"unsafe"
- "gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/sentry/arch"
)
@@ -59,19 +58,6 @@ func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
return &((*arch.UContext64)(context).MContext)
}
-// injectInterrupt is a helper to inject an interrupt.
-//
-//go:nosplit
-func injectInterrupt(c *vCPU) {
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_INTERRUPT,
- uintptr(unsafe.Pointer(&bounce))); errno != 0 {
- throw("interrupt injection failed")
- }
-}
-
// bluepillHandler is called from the signal stub.
//
// The world may be stopped while this is executing, and it executes on the
@@ -83,9 +69,6 @@ func bluepillHandler(context unsafe.Pointer) {
// Sanitize the registers; interrupts must always be disabled.
c := bluepillArchEnter(bluepillArchContext(context))
- // Enable preemption.
- c.setSignalMask(true)
-
// Increment the number of switches.
atomic.AddUint32(&c.switches, 1)
@@ -106,9 +89,6 @@ func bluepillHandler(context unsafe.Pointer) {
// interrupted KVM. Since we're in a signal handler
// currently, all signals are masked and the signal
// must have been delivered directly to this thread.
- //
- // We will not be able to actually do subsequent
- // KVM_RUNs until this signal is processed.
timeout := syscall.Timespec{}
sig, _, errno := syscall.RawSyscall6(
syscall.SYS_RT_SIGTIMEDWAIT,
@@ -118,24 +98,12 @@ func bluepillHandler(context unsafe.Pointer) {
8, // sigset size.
0, 0)
if errno == syscall.EAGAIN {
- // If weren't able to process this signal, then
- // it must not have been in the bounceMask. By
- // elimination, it must have been the
- // preemption signal. We can't process this
- // signal right now, so we need to disable
- // preemption until the interrupt is actually
- // handled.
- c.setSignalMask(false)
- // Note that there is a waiter for this vCPU.
- // This will cause the vCPU to exit at some
- // point in the future (releasing the user lock
- // and guest mode).
- atomicbitops.OrUint32(&c.state, vCPUWaiter)
- } else if errno != 0 {
- // We only expect success or a timeout.
+ continue
+ }
+ if errno != 0 {
throw("error waiting for pending signal")
- } else if sig != uintptr(bounceSignal) {
- // Only the bounce should be processed.
+ }
+ if sig != uintptr(bounceSignal) {
throw("unexpected signal")
}
@@ -146,10 +114,11 @@ func bluepillHandler(context unsafe.Pointer) {
// ready.
if c.runData.readyForInterruptInjection == 0 {
c.runData.requestInterruptWindow = 1
+ continue // Rerun vCPU.
} else {
- injectInterrupt(c)
+ // Force injection below; the vCPU is ready.
+ c.runData.exitReason = _KVM_EXIT_IRQ_WINDOW_OPEN
}
- continue // Rerun vCPU.
case syscall.EFAULT:
// If a fault is not serviceable due to the host
// backing pages having page permissions, instead of an
@@ -168,30 +137,6 @@ func bluepillHandler(context unsafe.Pointer) {
}
switch c.runData.exitReason {
- case _KVM_EXIT_HLT:
- // Copy out registers.
- bluepillArchExit(c, bluepillArchContext(context))
-
- // Return to the vCPUReady state; notify any waiters.
- user := atomic.LoadUint32(&c.state) & vCPUUser
- switch atomic.SwapUint32(&c.state, user) {
- case user | vCPUGuest: // Expected case.
- case user | vCPUGuest | vCPUWaiter:
- c.notify()
- default:
- throw("invalid state")
- }
- return
- case _KVM_EXIT_IRQ_WINDOW_OPEN:
- // Inject an interrupt now.
- injectInterrupt(c)
- // Clear previous injection request.
- c.runData.requestInterruptWindow = 0
- case _KVM_EXIT_INTR:
- // This is fine, it is the normal exit reason during
- // signal delivery. However, we still need to handle
- // other potential exit reasons *combined* with EINTR,
- // so this switch must be hit even after the above.
case _KVM_EXIT_EXCEPTION:
c.die(bluepillArchContext(context), "exception")
return
@@ -210,6 +155,20 @@ func bluepillHandler(context unsafe.Pointer) {
case _KVM_EXIT_DEBUG:
c.die(bluepillArchContext(context), "debug")
return
+ case _KVM_EXIT_HLT:
+ // Copy out registers.
+ bluepillArchExit(c, bluepillArchContext(context))
+
+ // Return to the vCPUReady state; notify any waiters.
+ user := atomic.LoadUint32(&c.state) & vCPUUser
+ switch atomic.SwapUint32(&c.state, user) {
+ case user | vCPUGuest: // Expected case.
+ case user | vCPUGuest | vCPUWaiter:
+ c.notify()
+ default:
+ throw("invalid state")
+ }
+ return
case _KVM_EXIT_MMIO:
// Increment the fault count.
atomic.AddUint32(&c.faults, 1)
@@ -241,6 +200,18 @@ func bluepillHandler(context unsafe.Pointer) {
data[i] = *b
}
}
+ case _KVM_EXIT_IRQ_WINDOW_OPEN:
+ // Interrupt: we must have requested an interrupt
+ // window; set the interrupt line.
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_INTERRUPT,
+ uintptr(unsafe.Pointer(&bounce))); errno != 0 {
+ throw("interrupt injection failed")
+ }
+ // Clear previous injection request.
+ c.runData.requestInterruptWindow = 0
case _KVM_EXIT_SHUTDOWN:
c.die(bluepillArchContext(context), "shutdown")
return
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
index 79045651e..716198712 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -18,6 +18,8 @@ package kvm
import (
"syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
)
type kvmOneReg struct {
@@ -46,6 +48,6 @@ type userRegs struct {
func updateGlobalOnce(fd int) error {
physicalInit()
err := updateSystemValues(int(fd))
- updateVectorTable()
+ ring0.Init()
return err
}
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 07d9c9a98..1d5c77ff4 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -48,7 +48,6 @@ const (
_KVM_EXIT_IRQ_WINDOW_OPEN = 0x7
_KVM_EXIT_SHUTDOWN = 0x8
_KVM_EXIT_FAIL_ENTRY = 0x9
- _KVM_EXIT_INTR = 0xa
_KVM_EXIT_INTERNAL_ERROR = 0x11
_KVM_EXIT_SYSTEM_EVENT = 0x18
)
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index d42ba3f24..c42752d50 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -16,15 +16,12 @@ package kvm
import (
"math/rand"
- "os"
"reflect"
- "runtime"
"sync/atomic"
"syscall"
"testing"
"time"
- "gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
@@ -323,18 +320,15 @@ func TestBounce(t *testing.T) {
})
}
-// randomSleep is used by some race tests below.
-//
-// O(hundreds of microseconds) is appropriate to ensure different overlaps and
-// different schedules.
-func randomSleep() {
- if n := rand.Intn(1000); n > 100 {
- time.Sleep(time.Duration(n) * time.Microsecond)
- }
-}
-
func TestBounceStress(t *testing.T) {
applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ randomSleep := func() {
+ // O(hundreds of microseconds) is appropriate to ensure
+ // different overlaps and different schedules.
+ if n := rand.Intn(1000); n > 100 {
+ time.Sleep(time.Duration(n) * time.Microsecond)
+ }
+ }
for i := 0; i < 1000; i++ {
// Start an asynchronously executing goroutine that
// calls Bounce at pseudo-random point in time.
@@ -361,50 +355,6 @@ func TestBounceStress(t *testing.T) {
})
}
-func TestPreemption(t *testing.T) {
- applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
- // Lock the main vCPU thread.
- runtime.LockOSThread()
- pid := os.Getpid()
- tid := procid.Current()
- running := uint32(1)
- defer atomic.StoreUint32(&running, 0)
-
- // Start generating "preemptions".
- go func() {
- for atomic.LoadUint32(&running) != 0 {
- // Kick via a preemption: best effort.
- syscall.Tgkill(pid, int(tid), syscall.SIGURG)
- randomSleep()
- }
- }()
-
- for i := 0; i < 1000; i++ {
- randomSleep()
- var si arch.SignalInfo
- if _, err := c.SwitchToUser(ring0.SwitchOpts{
- Registers: regs,
- FloatingPointState: dummyFPState,
- PageTables: pt,
- }, &si); err != platform.ErrContextInterrupt {
- t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt)
- }
- // Was this caused by a preemption signal?
- if got := atomic.LoadUint32(&c.state); got&vCPUGuest != 0 && got&vCPUWaiter == 0 {
- continue
- }
- c.unlock()
- // Should have dropped from guest mode, processed preemption.
- if got := atomic.LoadUint32(&c.state); got != vCPUReady {
- t.Errorf("vCPU not in ready state: got %v", got)
- }
- randomSleep()
- c.lock()
- }
- return false
- })
-}
-
func TestInvalidate(t *testing.T) {
var data uintptr // Used below.
applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 345b71e8f..f1afc74dc 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -108,9 +108,6 @@ type vCPU struct {
// This is a bitmask of the three fields (vCPU*) described above.
state uint32
- // signalMask is the vCPU signal mask.
- signalMask uint64
-
// runData for this vCPU.
runData *runData
@@ -124,7 +121,6 @@ type vCPU struct {
// vCPUArchState is the architecture-specific state.
vCPUArchState
- // dieState is the temporary state associated with throwing exceptions.
dieState dieState
}
@@ -157,6 +153,11 @@ func (m *machine) newVCPU() *vCPU {
c.CPU.Init(&m.kernel, c)
m.vCPUsByID[c.id] = c
+ // Ensure the signal mask is correct.
+ if err := c.setSignalMask(); err != nil {
+ panic(fmt.Sprintf("error setting signal mask: %v", err))
+ }
+
// Map the run data.
runData, err := mapRunData(int(fd))
if err != nil {
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index 52286e56d..7156c245f 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -111,6 +111,31 @@ func (c *vCPU) setSystemTime() error {
return nil
}
+// setSignalMask sets the vCPU signal mask.
+//
+// This must be called prior to running the vCPU.
+func (c *vCPU) setSignalMask() error {
+ // The layout of this structure implies that it will not necessarily be
+ // the same layout chosen by the Go compiler. It gets fudged here.
+ var data struct {
+ length uint32
+ mask1 uint32
+ mask2 uint32
+ _ uint32
+ }
+ data.length = 8 // Fixed sigset size.
+ data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
+ data.mask2 = ^uint32(bounceSignalMask >> 32)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SIGNAL_MASK,
+ uintptr(unsafe.Pointer(&data))); errno != 0 {
+ return fmt.Errorf("error setting signal mask: %v", errno)
+ }
+ return nil
+}
+
// setUserRegisters sets user registers in the vCPU.
func (c *vCPU) setUserRegisters(uregs *userRegs) error {
if _, _, errno := syscall.RawSyscall(
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 185eeb4f0..3b35858ae 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -48,69 +48,6 @@ func (m *machine) initArchState() error {
return nil
}
-func getPageWithReflect(p uintptr) []byte {
- return (*(*[0xFFFFFF]byte)(unsafe.Pointer(p & ^uintptr(syscall.Getpagesize()-1))))[:syscall.Getpagesize()]
-}
-
-// Work around: move ring0.Vectors() into a specific address with 11-bits alignment.
-//
-// According to the design documentation of Arm64,
-// the start address of exception vector table should be 11-bits aligned.
-// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S
-// But, we can't align a function's start address to a specific address by using golang.
-// We have raised this question in golang community:
-// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
-// This function will be removed when golang supports this feature.
-//
-// There are 2 jobs were implemented in this function:
-// 1, move the start address of exception vector table into the specific address.
-// 2, modify the offset of each instruction.
-func updateVectorTable() {
- fromLocation := reflect.ValueOf(ring0.Vectors).Pointer()
- offset := fromLocation & (1<<11 - 1)
- if offset != 0 {
- offset = 1<<11 - offset
- }
-
- toLocation := fromLocation + offset
- page := getPageWithReflect(toLocation)
- if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil {
- panic(err)
- }
-
- page = getPageWithReflect(toLocation + 4096)
- if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil {
- panic(err)
- }
-
- // Move exception-vector-table into the specific address.
- var entry *uint32
- var entryFrom *uint32
- for i := 1; i <= 0x800; i++ {
- entry = (*uint32)(unsafe.Pointer(toLocation + 0x800 - uintptr(i)))
- entryFrom = (*uint32)(unsafe.Pointer(fromLocation + 0x800 - uintptr(i)))
- *entry = *entryFrom
- }
-
- // The offset from the address of each unconditionally branch is changed.
- // We should modify the offset of each instruction.
- nums := []uint32{0x0, 0x80, 0x100, 0x180, 0x200, 0x280, 0x300, 0x380, 0x400, 0x480, 0x500, 0x580, 0x600, 0x680, 0x700, 0x780}
- for _, num := range nums {
- entry = (*uint32)(unsafe.Pointer(toLocation + uintptr(num)))
- *entry = *entry - (uint32)(offset/4)
- }
-
- page = getPageWithReflect(toLocation)
- if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil {
- panic(err)
- }
-
- page = getPageWithReflect(toLocation + 4096)
- if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil {
- panic(err)
- }
-}
-
// initArchState initializes architecture-specific state.
func (c *vCPU) initArchState() error {
var (
@@ -268,6 +205,32 @@ func (c *vCPU) setSystemTime() error {
return nil
}
+// setSignalMask sets the vCPU signal mask.
+//
+// This must be called prior to running the vCPU.
+func (c *vCPU) setSignalMask() error {
+ // The layout of this structure implies that it will not necessarily be
+ // the same layout chosen by the Go compiler. It gets fudged here.
+ var data struct {
+ length uint32
+ mask1 uint32
+ mask2 uint32
+ _ uint32
+ }
+ data.length = 8 // Fixed sigset size.
+ data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
+ data.mask2 = ^uint32(bounceSignalMask >> 32)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SIGNAL_MASK,
+ uintptr(unsafe.Pointer(&data))); errno != 0 {
+ return fmt.Errorf("error setting signal mask: %v", errno)
+ }
+
+ return nil
+}
+
// SwitchToUser unpacks architectural-details.
func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
// Check for canonical addresses.
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index e4de0a889..f04be2ab5 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -87,47 +87,6 @@ func unmapRunData(r *runData) error {
return nil
}
-// setSignalMask sets the vCPU signal mask.
-//
-// This will be called from the bluepill handler, and therefore must not
-// perform any allocation.
-//
-//go:nosplit
-func (c *vCPU) setSignalMask(enableOthers bool) {
- // The signal mask is either:
- // *) Only the bounce signal, which we need to use to execute the
- // machine state up until the bounce interrupt can be processed.
- // or
- // *) All signals, which is the default state unless we need to
- // continue execution to exit guest mode (the case above).
- mask := bounceSignalMask
- if enableOthers {
- mask |= otherSignalsMask
- }
- if c.signalMask == mask {
- return // Already set.
- }
-
- // The layout of this structure implies that it will not necessarily be
- // the same layout chosen by the Go compiler. It gets fudged here.
- var data struct {
- length uint32
- mask1 uint32
- mask2 uint32
- _ uint32
- }
- data.length = 8 // Fixed sigset size.
- data.mask1 = ^uint32(mask & 0xffffffff)
- data.mask2 = ^uint32(mask >> 32)
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_SET_SIGNAL_MASK,
- uintptr(unsafe.Pointer(&data))); errno != 0 {
- throw("setSignal mask failed")
- }
-}
-
// atomicAddressSpace is an atomic address space pointer.
type atomicAddressSpace struct {
pointer unsafe.Pointer
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index 934b6fbcd..b69520030 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -72,11 +72,13 @@ go_library(
"lib_amd64.s",
"lib_arm64.go",
"lib_arm64.s",
+ "lib_arm64_unsafe.go",
"ring0.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/cpuid",
+ "//pkg/safecopy",
"//pkg/sentry/platform/ring0/pagetables",
"//pkg/usermem",
],
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index af075aae4..242b9305c 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -37,3 +37,10 @@ func SaveVRegs(*byte)
// LoadVRegs loads V0-V31 registers.
func LoadVRegs(*byte)
+
+// Init prepares ring0 for use on arm64 by calling rewriteVectors.
+//
+// This must be called prior to using ring0.
+func Init() {
+ rewriteVectors()
+}
diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
new file mode 100644
index 000000000..c05166fea
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
@@ -0,0 +1,108 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+import (
+ "reflect"
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ nopInstruction = 0xd503201f // Instruction word treated as NOP; rewriteVectors leaves such words unpatched.
+ instSize = unsafe.Sizeof(uint32(0)) // Size in bytes of one uint32 instruction word.
+ vectorsRawLen = 0x800 // Size in bytes of the exception vector table that gets moved.
+)
+
+// unsafeSlice returns a []uint32 view of the memory that starts at addr.
+//
+// length is given in bytes; both the length and the capacity of the returned
+// slice are length/instSize elements. No memory is copied.
+func unsafeSlice(addr uintptr, length int) []uint32 {
+	var view []uint32
+	words := length / int(instSize)
+	hdr := (*reflect.SliceHeader)(unsafe.Pointer(&view))
+	hdr.Data, hdr.Len, hdr.Cap = addr, words, words
+	return view
+}
+
+// rewriteVectors works around the alignment requirement of the arm64
+// exception vector table by moving ring0.Vectors to a 2KiB-aligned address.
+//
+// The architecture requires the vector table base to be aligned to 1<<11
+// bytes (see arch/arm64/kernel/entry.S in the Linux kernel for reference),
+// but Go offers no way to align a function's start address. The question
+// was raised with the Go community:
+// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
+// This function will be removed once Go supports this.
+//
+// It does two things:
+// 1. moves the vector table forward to the next 2KiB-aligned address;
+// 2. adjusts the offset encoded in each moved instruction.
+func rewriteVectors() {
+ vectorsBegin := reflect.ValueOf(Vectors).Pointer()
+
+ // The exception vector table must be aligned to 1<<11 bytes and is
+ // 0x800 bytes long. See the Arm documentation for reference:
+ // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table
+ //
+ // Go does not allow placing a function at a specific address, so the
+ // Vectors function is defined to be 4KiB long: the first 2KiB holds the
+ // real table and the second 2KiB is filled with NOPs. The first 2KiB can
+ // therefore be moved forward to a 2KiB-aligned address without running
+ // past the end of the function.
+ //
+ // The prerequisites for this function to work correctly are:
+ // vectorsSafeLen >= 0x1000
+ // vectorsRawLen = 0x800
+ vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin)
+ if vectorsSafeLen < 2*vectorsRawLen {
+ panic("Can't update vectors")
+ }
+
+ vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32
+ vectorsRawLen32 := vectorsRawLen / int(instSize)
+
+ offset := vectorsBegin & (1<<11 - 1) // Bytes past the previous 2KiB boundary.
+ if offset != 0 {
+ offset = 1<<11 - offset // Bytes up to the next 2KiB boundary.
+ }
+
+ pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1) // Page containing the aligned destination.
+
+ _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC)) // Make that page writable for patching.
+ if errno != 0 {
+ panic(errno.Error())
+ }
+
+ offset = offset / instSize // By index, not bytes.
+ // Move the table to the aligned address, copying from the last word backwards (as memmove would) because source and destination may overlap.
+ for i := 1; i <= vectorsRawLen32; i++ {
+ vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i]
+ }
+
+ // Subtract the distance moved from every non-NOP word. NOTE(review): presumably each such word is a branch with its offset in the low bits; confirm against Vectors.
+ for i := 0; i < vectorsRawLen32; i++ {
+ if vectorsSafeTable[int(offset)+i] != nopInstruction {
+ vectorsSafeTable[int(offset)+i] -= uint32(offset)
+ }
+ }
+
+ _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC)) // Drop write permission again.
+ if errno != 0 {
+ panic(errno.Error())
+ }
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 581841555..16d5f478b 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -81,6 +81,9 @@ go_library(
"pagetables_arm64.go",
"pagetables_x86.go",
"pcids.go",
+ "pcids_aarch64.go",
+ "pcids_aarch64.s",
+ "pcids_x86.go",
"walker_amd64.go",
"walker_arm64.go",
"walker_empty.go",
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
index dcf061df9..157438d9b 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package pagetables
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids.go b/pkg/sentry/platform/ring0/pagetables/pcids.go
index 9206030bf..964496aac 100644
--- a/pkg/sentry/platform/ring0/pagetables/pcids.go
+++ b/pkg/sentry/platform/ring0/pagetables/pcids.go
@@ -18,9 +18,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
)
-// limitPCID is the number of valid PCIDs.
-const limitPCID = 4096
-
// PCIDs is a simple PCID database.
//
// This is not protected by locks and is thus suitable for use only with a
@@ -44,7 +41,7 @@ type PCIDs struct {
//
// Nil is returned iff the start and size are out of range.
func NewPCIDs(start, size uint16) *PCIDs {
- if start+uint16(size) >= limitPCID {
+ if start+uint16(size) > limitPCID {
return nil // See comment.
}
p := &PCIDs{
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go
new file mode 100644
index 000000000..fbfd41d83
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+// limitPCID is the maximum value of PCIDs; it is computed in init.
+//
+// In VMSAv8-64, the PCID (ASID) size is an IMPLEMENTATION DEFINED choice
+// of 8 bits or 16 bits, and ID_AA64MMFR0_EL1.ASIDBits identifies the
+// supported size. When an implementation supports a 16-bit ASID, TCR_ELx.AS
+// selects whether the top 8 bits of the ASID are used.
+var limitPCID uint16
+
+// GetASIDBits returns the number of ASID bits in use by the system, 8 or 16.
+func GetASIDBits() uint8
+
+func init() {
+ limitPCID = uint16(1)<<GetASIDBits() - 1 // 2^bits - 1; for 16 bits the uint16 shift yields 0 and the subtraction wraps to 0xffff.
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s
new file mode 100644
index 000000000..e9d62d768
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+#include "funcdata.h"
+#include "textflag.h"
+
+#define ID_AA64MMFR0_ASIDBITS_SHIFT 4
+#define ID_AA64MMFR0_ASIDBITS_16 2
+#define TCR_EL1_AS_BIT 36
+
+// GetASIDBits returns the number of ASID bits in use by the system, 8 or 16.
+//
+// func GetASIDBits() uint8
+TEXT ·GetASIDBits(SB),NOSPLIT,$0-1
+ // First, check whether 16-bit ASIDs are supported:
+ // ID_AA64MMFR0_EL1.ASIDBits[7:4] == 0b0010.
+ WORD $0xd5380700 // MRS ID_AA64MMFR0_EL1, R0
+ UBFX $ID_AA64MMFR0_ASIDBITS_SHIFT, R0, $4, R0 // R0 = 4-bit field starting at the ASIDBits shift.
+ CMPW $ID_AA64MMFR0_ASIDBITS_16, R0
+ BNE bits_8 // Not the 16-bit encoding: report 8.
+
+ // Second, check whether 16-bit ASIDs are enabled:
+ // TCR_EL1.AS[36] == 1.
+ WORD $0xd5382040 // MRS TCR_EL1, R0
+ TBZ $TCR_EL1_AS_BIT, R0, bits_8 // AS bit clear: report 8.
+ MOVD $16, R0
+ B done
+bits_8:
+ MOVD $8, R0
+done:
+ MOVB R0, ret+0(FP) // Store the uint8 result.
+ RET
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
new file mode 100644
index 000000000..91fc5e8dd
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build 386 amd64
+
+package pagetables
+
+// limitPCID is the maximum value of valid PCIDs.
+//
+// It is a compile-time constant on x86; NewPCIDs rejects ranges that extend
+// beyond it.
+const limitPCID = 4095
diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go
index 5f80d64e8..9da0ea685 100644
--- a/pkg/sentry/platform/ring0/x86.go
+++ b/pkg/sentry/platform/ring0/x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package ring0
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 611fa22c3..c40c6d673 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index f14c336b9..7ac38764d 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -300,7 +300,7 @@ type SocketOperations struct {
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil {
+ if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
}
@@ -535,7 +535,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
if resCh != nil {
- t := ctx.(*kernel.Task)
+ t := kernel.TaskFromContext(ctx)
if err := t.Block(resCh); err != nil {
return 0, syserr.FromError(err).ToError()
}
@@ -608,7 +608,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
}
if resCh != nil {
- t := ctx.(*kernel.Task)
+ t := kernel.TaskFromContext(ctx)
if err := t.Block(resCh); err != nil {
return 0, syserr.FromError(err).ToError()
}
@@ -663,7 +663,7 @@ func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error
// This is a hack to work around the fact that both IPv4 and IPv6 ANY are
// represented by the empty string.
//
-// TODO(gvisor.dev/issues/1556): remove this function.
+// TODO(gvisor.dev/issue/1556): remove this function.
func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress {
if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET {
addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
@@ -940,7 +940,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -965,8 +965,15 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int,
return nil, syserr.ErrProtocolNotAvailable
}
+// boolToInt32 converts a boolean into the int32 representation returned for
+// boolean socket options: 1 for true and 0 for false.
+func boolToInt32(v bool) int32 {
+	var n int32
+	if v {
+		n = 1
+	}
+	return n
+}
+
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -998,12 +1005,11 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.PasscredOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.PasscredOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1042,24 +1048,22 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.ReuseAddressOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.ReusePortOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.ReusePortOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1089,24 +1093,22 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.BroadcastOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.BroadcastOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.KeepaliveEnabledOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
@@ -1156,47 +1158,41 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptInt(tcpip.DelayOption)
+ v, err := ep.GetSockOptBool(tcpip.DelayOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- if v == 0 {
- return int32(1), nil
- }
- return int32(0), nil
+ return boolToInt32(!v), nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.CorkOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.CorkOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.QuickAckOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MaxSegOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.MaxSegOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1328,11 +1324,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1342,8 +1334,8 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if outLen == 0 {
return make([]byte, 0), nil
}
- var v tcpip.IPv6TrafficClassOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1365,12 +1357,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
default:
emitUnimplementedEventIPv6(t, name)
@@ -1386,8 +1373,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.TTLOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.TTLOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1403,8 +1390,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MulticastTTLOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1429,23 +1416,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MulticastLoopOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- if v {
- return int32(1), nil
- }
- return int32(0), nil
+ return boolToInt32(v), nil
case linux.IP_TOS:
// Length handling for parity with Linux.
if outLen == 0 {
return []byte(nil), nil
}
- var v tcpip.IPv4TOSOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
if outLen < sizeOfInt32 {
@@ -1462,11 +1445,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
@@ -1477,11 +1456,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
default:
emitUnimplementedEventIP(t, name)
@@ -1541,7 +1516,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
// SetSockOpt can be used to implement the linux syscall setsockopt(2) for
// sockets backed by a commonEndpoint.
-func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
+func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
switch level {
case linux.SOL_SOCKET:
return setSockOptSocket(t, s, ep, name, optVal)
@@ -1568,7 +1543,7 @@ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, n
}
// setSockOptSocket implements SetSockOpt when level is SOL_SOCKET.
-func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
switch name {
case linux.SO_SNDBUF:
if len(optVal) < sizeOfInt32 {
@@ -1592,7 +1567,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0))
case linux.SO_REUSEPORT:
if len(optVal) < sizeOfInt32 {
@@ -1600,7 +1575,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0))
case linux.SO_BINDTODEVICE:
n := bytes.IndexByte(optVal, 0)
@@ -1628,7 +1603,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0))
case linux.SO_PASSCRED:
if len(optVal) < sizeOfInt32 {
@@ -1636,7 +1611,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0))
case linux.SO_KEEPALIVE:
if len(optVal) < sizeOfInt32 {
@@ -1644,7 +1619,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0))
case linux.SO_SNDTIMEO:
if len(optVal) < linux.SizeOfTimeval {
@@ -1716,11 +1691,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- var o int
- if v == 0 {
- o = 1
- }
- return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -1728,7 +1699,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
case linux.TCP_QUICKACK:
if len(optVal) < sizeOfInt32 {
@@ -1736,7 +1707,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
case linux.TCP_MAXSEG:
if len(optVal) < sizeOfInt32 {
@@ -1744,7 +1715,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v)))
case linux.TCP_KEEPIDLE:
if len(optVal) < sizeOfInt32 {
@@ -1855,7 +1826,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
if v == -1 {
v = 0
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v)))
case linux.IPV6_RECVTCLASS:
v, err := parseIntOrChar(optVal)
@@ -1940,7 +1911,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
if v < 0 || v > 255 {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v)))
case linux.IP_ADD_MEMBERSHIP:
req, err := copyInMulticastRequest(optVal, false /* allowAddr */)
@@ -1987,9 +1958,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(
- tcpip.MulticastLoopOption(v != 0),
- ))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
case linux.MCAST_JOIN_GROUP:
// FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
@@ -2008,7 +1977,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
} else if v < 1 || v > 255 {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v)))
case linux.IP_TOS:
if len(optVal) == 0 {
@@ -2018,7 +1987,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v)))
case linux.IP_RECVTOS:
v, err := parseIntOrChar(optVal)
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index eb090e79b..c3f04b613 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -62,10 +62,6 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
}
case linux.SOCK_RAW:
- // TODO(b/142504697): "In order to create a raw socket, a
- // process must have the CAP_NET_RAW capability in the user
- // namespace that governs its network namespace." - raw(7)
-
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
@@ -141,10 +137,6 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
}
func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- // TODO(b/142504697): "In order to create a packet socket, a process
- // must have the CAP_NET_RAW capability in the user namespace that
- // governs its network namespace." - packet(7)
-
// Packet sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(t)
if !creds.HasCapability(linux.CAP_NET_RAW) {
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 50d9744e6..6580bd6e9 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
@@ -48,11 +49,25 @@ func (c *ControlMessages) Release() {
c.Unix.Release()
}
-// Socket is the interface containing socket syscalls used by the syscall layer
-// to redirect them to the appropriate implementation.
+// Socket is an interface combining fs.FileOperations and SocketOps,
+// representing a VFS1 socket file.
type Socket interface {
fs.FileOperations
+ SocketOps
+}
+
+// SocketVFS2 is an interface combining vfs.FileDescriptionImpl and SocketOps,
+// representing a VFS2 socket file.
+type SocketVFS2 interface {
+ vfs.FileDescriptionImpl
+ SocketOps
+}
+// SocketOps is the interface containing socket syscalls used by the syscall
+// layer to redirect them to the appropriate implementation.
+//
+// It is implemented by both Socket and SocketVFS2.
+type SocketOps interface {
// Connect implements the connect(2) linux syscall.
Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error
@@ -153,6 +168,8 @@ var families = make(map[int][]Provider)
// RegisterProvider registers the provider of a given address family so that
// sockets of that type can be created via socket() and/or socketpair()
// syscalls.
+//
+// This should only be called during the initialization of the address family.
func RegisterProvider(family int, provider Provider) {
families[family] = append(families[family], provider)
}
@@ -216,6 +233,74 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
return fs.NewDirent(ctx, inode, fmt.Sprintf("socket:[%d]", ino))
}
+// ProviderVFS2 is the VFS2 interface implemented by providers of sockets for
+// specific address families (e.g., AF_INET).
+type ProviderVFS2 interface {
+ // Socket creates a new socket.
+ //
+ // If a nil FileDescription _and_ a nil error are returned, it means that
+ // the protocol is not supported. A non-nil error should only be returned
+ // if the protocol is supported, but an error occurs during creation.
+ Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error)
+
+ // Pair creates a pair of connected sockets.
+ //
+ // See Socket for error information.
+ Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error)
+}
+
+// familiesVFS2 maps each known address family to the VFS2 providers that
+// have been registered for it.
+var familiesVFS2 = make(map[int][]ProviderVFS2)
+
+// RegisterProviderVFS2 adds provider to the list of VFS2 providers for the
+// given address family, so that sockets of that family can be created via the
+// socket() and/or socketpair() syscalls.
+//
+// This should only be called during the initialization of the address family.
+func RegisterProviderVFS2(family int, provider ProviderVFS2) {
+	registered := familiesVFS2[family]
+	familiesVFS2[family] = append(registered, provider)
+}
+
+// NewVFS2 creates a new socket with the given family, type and protocol.
+//
+// The providers registered for family are tried in order. The first error
+// reported by a provider is returned as is; the first non-nil socket is
+// recorded with the task's kernel and returned. If no provider yields a
+// socket, ErrAddressFamilyNotSupported is returned.
+func NewVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+	for _, provider := range familiesVFS2[family] {
+		fd, err := provider.Socket(t, stype, protocol)
+		switch {
+		case err != nil:
+			return nil, err
+		case fd != nil:
+			t.Kernel().RecordSocketVFS2(fd)
+			return fd, nil
+		}
+	}
+
+	return nil, syserr.ErrAddressFamilyNotSupported
+}
+
+// PairVFS2 creates a new connected socket pair with the given family, type
+// and protocol.
+//
+// ErrAddressFamilyNotSupported is returned when no provider is registered for
+// family. Otherwise the providers are tried in order: the first error is
+// returned as is, and the first pair in which both sockets are non-nil is
+// recorded with the task's kernel and returned. If no provider yields such a
+// pair, ErrSocketNotSupported is returned.
+func PairVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+	candidates, known := familiesVFS2[family]
+	if !known {
+		return nil, nil, syserr.ErrAddressFamilyNotSupported
+	}
+
+	for _, provider := range candidates {
+		first, second, err := provider.Pair(t, stype, protocol)
+		if err != nil {
+			return nil, nil, err
+		}
+		if first == nil || second == nil {
+			// This provider does not support the request; try the next.
+			continue
+		}
+		kern := t.Kernel()
+		kern.RecordSocketVFS2(first)
+		kern.RecordSocketVFS2(second)
+		return first, second, nil
+	}
+
+	return nil, nil, syserr.ErrSocketNotSupported
+}
+
// SendReceiveTimeout stores timeouts for send and receive calls.
//
// It is meant to be embedded into Socket implementations to help satisfy the
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index 08743deba..de2cc4bdf 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -8,23 +8,27 @@ go_library(
"device.go",
"io.go",
"unix.go",
+ "unix_vfs2.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/fspath",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/netstack",
"//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 74bcd6300..c708b6030 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/ilist",
+ "//pkg/log",
"//pkg/refs",
"//pkg/sync",
"//pkg/syserr",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 2ef654235..2f1b127df 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -838,24 +839,43 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option. Currently not supported.
func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
- case tcpip.PasscredOption:
- e.setPasscred(v != 0)
- return nil
- }
return nil
}
func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.BroadcastOption:
+ case tcpip.PasscredOption:
+ e.setPasscred(v)
+ case tcpip.ReuseAddressOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
return nil
}
func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
return nil
}
func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.PasscredOption:
+ return e.Passcred(), nil
+
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ return false, tcpip.ErrUnknownProtocolOption
+ }
}
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
@@ -914,29 +934,19 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return int(v), nil
default:
+ log.Warningf("Unsupported socket option: %d", opt)
return -1, tcpip.ErrUnknownProtocolOption
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
+ switch opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
- }
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
default:
+ log.Warningf("Unsupported socket option: %T", opt)
return tcpip.ErrUnknownProtocolOption
}
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4d30aa714..7c64f30fa 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -33,6 +34,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -52,11 +54,8 @@ type SocketOperations struct {
fsutil.FileNoSplice `state:"nosave"`
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- refs.AtomicRefCount
- socket.SendReceiveTimeout
- ep transport.Endpoint
- stype linux.SockType
+ socketOpsCommon
}
// New creates a new unix socket.
@@ -75,16 +74,29 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
}
s := SocketOperations{
- ep: ep,
- stype: stype,
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
}
s.EnableLeakCheck("unix.SocketOperations")
return fs.NewFile(ctx, d, flags, &s)
}
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ refs.AtomicRefCount
+ socket.SendReceiveTimeout
+
+ ep transport.Endpoint
+ stype linux.SockType
+}
+
// DecRef implements RefCounter.DecRef.
-func (s *SocketOperations) DecRef() {
+func (s *socketOpsCommon) DecRef() {
s.DecRefWithDestructor(func() {
s.ep.Close()
})
@@ -97,7 +109,7 @@ func (s *SocketOperations) Release() {
s.DecRef()
}
-func (s *SocketOperations) isPacket() bool {
+func (s *socketOpsCommon) isPacket() bool {
switch s.stype {
case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
return true
@@ -110,7 +122,7 @@ func (s *SocketOperations) isPacket() bool {
}
// Endpoint extracts the transport.Endpoint.
-func (s *SocketOperations) Endpoint() transport.Endpoint {
+func (s *socketOpsCommon) Endpoint() transport.Endpoint {
return s.ep
}
@@ -143,7 +155,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) {
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -155,7 +167,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32,
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -178,7 +190,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// Listen implements the linux syscall listen(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
return s.ep.Listen(backlog)
}
@@ -310,6 +322,8 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Create the socket.
+ //
+ // TODO(gvisor.dev/issue/2324): Correctly set file permissions.
childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}})
if err != nil {
return syserr.ErrPortInUse
@@ -345,6 +359,31 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
return ep, nil
}
+ if kernel.VFS2Enabled {
+ p := fspath.Parse(path)
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ relPath := !p.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: p,
+ FollowFinalSymlink: true,
+ }
+ ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop)
+ root.DecRef()
+ if relPath {
+ start.DecRef()
+ }
+ if e != nil {
+ return nil, syserr.FromError(e)
+ }
+ return ep, nil
+ }
+
// Find the node in the filesystem.
root := t.FSContext().RootDirectory()
cwd := t.FSContext().WorkingDirectory()
@@ -363,12 +402,11 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
// No socket!
return nil, syserr.ErrConnectionRefused
}
-
return ep, nil
}
// Connect implements the linux syscall connect(2) for unix sockets.
-func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
ep, err := extractEndpoint(t, sockaddr)
if err != nil {
return err
@@ -379,7 +417,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
return s.ep.Connect(t, ep)
}
-// Writev implements fs.FileOperations.Write.
+// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
t := kernel.TaskFromContext(ctx)
ctrl := control.New(t, s.ep, nil)
@@ -399,7 +437,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
w := EndpointWriter{
Ctx: t,
Endpoint: s.ep,
@@ -453,27 +491,27 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
}
// Passcred implements transport.Credentialer.Passcred.
-func (s *SocketOperations) Passcred() bool {
+func (s *socketOpsCommon) Passcred() bool {
return s.ep.Passcred()
}
// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
-func (s *SocketOperations) ConnectedPasscred() bool {
+func (s *socketOpsCommon) ConnectedPasscred() bool {
return s.ep.ConnectedPasscred()
}
// Readiness implements waiter.Waitable.Readiness.
-func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.ep.Readiness(mask)
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.ep.EventRegister(e, mask)
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *SocketOperations) EventUnregister(e *waiter.Entry) {
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.ep.EventUnregister(e)
}
@@ -485,7 +523,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
f, err := netstack.ConvertShutdown(how)
if err != nil {
return err
@@ -511,7 +549,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -648,12 +686,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
// State implements socket.Socket.State.
-func (s *SocketOperations) State() uint32 {
+func (s *socketOpsCommon) State() uint32 {
return s.ep.State()
}
// Type implements socket.Socket.Type.
-func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
// Unix domain sockets always have a protocol of 0.
return linux.AF_UNIX, s.stype, 0
}
@@ -706,4 +744,5 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F
func init() {
socket.RegisterProvider(linux.AF_UNIX, &provider{})
+ socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{})
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
new file mode 100644
index 000000000..3e54d49c4
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -0,0 +1,348 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketVFS2 implements socket.SocketVFS2 (and by extension,
+// vfs.FileDescriptionImpl) for Unix sockets.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+
+ socketOpsCommon
+}
+
+// NewVFS2File creates and returns a new vfs.FileDescription for a unix socket.
+func NewVFS2File(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
+ sock := NewFDImpl(ep, stype)
+ vfsfd := &sock.vfsfd
+ if err := sockfs.InitSocket(sock, vfsfd, t.Kernel().SocketMount(), t.Credentials()); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// NewFDImpl creates and returns a new SocketVFS2.
+func NewFDImpl(ep transport.Endpoint, stype linux.SockType) *SocketVFS2 {
+ // You can create AF_UNIX, SOCK_RAW sockets. They're the same as
+ // SOCK_DGRAM and don't require CAP_NET_RAW.
+ if stype == linux.SOCK_RAW {
+ stype = linux.SOCK_DGRAM
+ }
+
+ return &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
+ }
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accepted, it will block until one becomes ready.
+func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.socketOpsCommon.EventRegister(&e, waiter.EventIn)
+ defer s.socketOpsCommon.EventUnregister(&e)
+
+ // Try to accept the connection; if it fails, then wait until we get a
+ // notification.
+ for {
+ if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ return ep, err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, err := s.ep.Accept()
+ if err != nil {
+ if err != syserr.ErrWouldBlock || !blocking {
+ return 0, nil, 0, err
+ }
+
+ var err *syserr.Error
+ ep, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ // We expect this to be a FileDescription here.
+ ns, err := NewVFS2File(t, ep, s.stype)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK)
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer.
+ var err *syserr.Error
+ addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+ if e != nil {
+ return 0, nil, 0, syserr.FromError(e)
+ }
+
+ t.Kernel().RecordSocketVFS2(ns)
+ return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ p, e := extractPath(sockaddr)
+ if e != nil {
+ return e
+ }
+
+ bep, ok := s.ep.(transport.BoundEndpoint)
+ if !ok {
+ // This socket can't be bound.
+ return syserr.ErrInvalidArgument
+ }
+
+ return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error {
+ // Is it abstract?
+ if p[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return syserr.ErrInvalidEndpointState
+ }
+ if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ // syserr.ErrPortInUse corresponds to EADDRINUSE.
+ return syserr.ErrPortInUse
+ }
+ } else {
+ path := fspath.Parse(p)
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ relPath := !path.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ }
+ err := t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{
+ // TODO(gvisor.dev/issue/2324): The file permissions should be taken
+ // from s and t.FSContext().Umask() (see net/unix/af_unix.c:unix_bind),
+ // but VFS1 just always uses 0400. Resolve this inconsistency.
+ Mode: linux.S_IFSOCK | 0400,
+ Endpoint: bep,
+ })
+ if err == syserror.EEXIST {
+ return syserr.ErrAddressInUse
+ }
+ return syserr.FromError(err)
+ }
+
+ return nil
+ })
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return netstack.Ioctl(ctx, s.ep, uio, args)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &EndpointReader{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ NumRights: 0,
+ Peek: false,
+ From: nil,
+ })
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ t := kernel.TaskFromContext(ctx)
+ ctrl := control.New(t, s.ep, nil)
+
+ if src.NumBytes() == 0 {
+ nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
+ return int64(nInt), err.ToError()
+ }
+
+ return src.CopyInTo(ctx, &EndpointWriter{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ Control: ctrl,
+ To: nil,
+ })
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *SocketVFS2) Release() {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef()
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// providerVFS2 is a unix domain socket provider for VFS2.
+type providerVFS2 struct{}
+
+func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // Create the endpoint and socket.
+ var ep transport.Endpoint
+ switch stype {
+ case linux.SOCK_DGRAM, linux.SOCK_RAW:
+ ep = transport.NewConnectionless(t)
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
+ ep = transport.NewConnectioned(t, stype, t.Kernel())
+ default:
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ f, err := NewVFS2File(t, ep, stype)
+ if err != nil {
+ ep.Close()
+ return nil, err
+ }
+ return f, nil
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, nil, syserr.ErrProtocolNotSupported
+ }
+
+ switch stype {
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
+ // Ok
+ default:
+ return nil, nil, syserr.ErrInvalidArgument
+ }
+
+ // Create the endpoints and sockets.
+ ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
+ s1, err := NewVFS2File(t, ep1, stype)
+ if err != nil {
+ ep1.Close()
+ ep2.Close()
+ return nil, nil, err
+ }
+ s2, err := NewVFS2File(t, ep2, stype)
+ if err != nil {
+ s1.DecRef()
+ ep2.Close()
+ return nil, nil, err
+ }
+
+ return s1, s2, nil
+}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 77655558e..68ca537c8 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -719,7 +719,7 @@ func (s SyscallMap) SyscallEnter(t *kernel.Task, sysno uintptr, args arch.Syscal
// SyscallExit implements kernel.Stracer.SyscallExit. It logs the syscall
// exit trace.
func (s SyscallMap) SyscallExit(context interface{}, t *kernel.Task, sysno, rval uintptr, err error) {
- errno := t.ExtractErrno(err, int(sysno))
+ errno := kernel.ExtractErrno(err, int(sysno))
c := context.(*syscallContext)
elapsed := time.Since(c.start)
@@ -778,9 +778,6 @@ func (s SyscallMap) Name(sysno uintptr) string {
//
// N.B. This is not in an init function because we can't be sure all syscall
// tables are registered with the kernel when init runs.
-//
-// TODO(gvisor.dev/issue/155): remove kernel package dependencies from this
-// package and have the kernel package self-initialize all syscall tables.
func Initialize() {
for _, table := range kernel.SyscallTables() {
// Is this known?
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index b401978db..d781d6a04 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -114,14 +114,28 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
func IoDestroy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
id := args[0].Uint64()
- // Destroy the given context.
- if !t.MemoryManager().DestroyAIOContext(t, id) {
+ ctx := t.MemoryManager().DestroyAIOContext(t, id)
+ if ctx == nil {
// Does not exist.
return 0, nil, syserror.EINVAL
}
- // FIXME(fvoznika): Linux blocks until all AIO to the destroyed context is
- // done.
- return 0, nil, nil
+
+ // Drain completed requests and wait for pending requests until there are no
+ // more.
+ for {
+ ctx.Drain()
+
+ ch := ctx.WaitChannel()
+ if ch == nil {
+ // No more requests, we're done.
+ return 0, nil, nil
+ }
+ // The task cannot be interrupted during the wait. Equivalent to
+ // TASK_UNINTERRUPTIBLE in Linux.
+ t.UninterruptibleSleepStart(true /* deactivate */)
+ <-ch
+ t.UninterruptibleSleepFinish(true /* activate */)
+ }
}
// IoGetevents implements linux syscall io_getevents(2).
@@ -200,13 +214,13 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadline ktime.Time) (interface{}, error) {
for {
if v, ok := ctx.PopRequest(); ok {
- // Request was readly available. Just return it.
+ // Request was readily available. Just return it.
return v, nil
}
// Need to wait for request completion.
- done, active := ctx.WaitChannel()
- if !active {
+ done := ctx.WaitChannel()
+ if done == nil {
// Context has been destroyed.
return nil, syserror.EINVAL
}
@@ -248,6 +262,10 @@ func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) {
}
func performCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *ioCallback, ioseq usermem.IOSequence, ctx *mm.AIOContext, eventFile *fs.File) {
+ if ctx.Dead() {
+ ctx.CancelPendingRequest()
+ return
+ }
ev := &ioEvent{
Data: cb.Data,
Obj: uint64(cbAddr),
@@ -272,7 +290,7 @@ func performCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *ioC
// Update the result.
if err != nil {
err = handleIOError(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", file)
- ev.Result = -int64(t.ExtractErrno(err, 0))
+ ev.Result = -int64(kernel.ExtractErrno(err, 0))
}
file.DecRef()
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index 798344042..43c510930 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -24,6 +24,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// pipe2 implements the actual system call with flags.
func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
@@ -45,10 +47,12 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
}
if _, err := t.CopyOut(addr, fds); err != nil {
- // The files are not closed in this case, the exact semantics
- // of this error case are not well defined, but they could have
- // already been observed by user space.
- return 0, syserror.EFAULT
+ for _, fd := range fds {
+ if file, _ := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
+ return 0, err
}
return 0, nil
}
@@ -69,3 +73,5 @@ func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
n, err := pipe2(t, addr, flags)
return n, nil, err
}
+
+// LINT.ThenChange(vfs2/pipe.go)
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index 9c6728530..f92bf8096 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -161,8 +161,8 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if args[1].Int() != 1 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 {
return 0, nil, syserror.EINVAL
}
- // no_new_privs is assumed to always be set. See
- // kernel.Task.updateCredsForExec.
+ // PR_SET_NO_NEW_PRIVS is assumed to always be set.
+ // See kernel.Task.updateCredsForExecLocked.
return 0, nil, nil
case linux.PR_GET_NO_NEW_PRIVS:
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 78a2cb750..071b4bacc 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -96,8 +96,8 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, syserror.EINVAL
}
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
@@ -120,8 +120,8 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
defer file.DecRef()
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go
index e08c333d6..d5d5b6959 100644
--- a/pkg/sentry/syscalls/linux/sys_rlimit.go
+++ b/pkg/sentry/syscalls/linux/sys_rlimit.go
@@ -197,7 +197,7 @@ func Prlimit64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
// saved set user IDs of the target process must match the real user ID of
// the caller and the real, effective, and saved set group IDs of the
// target process must match the real group ID of the caller."
- if !t.HasCapabilityIn(linux.CAP_SYS_RESOURCE, t.PIDNamespace().UserNamespace()) {
+ if ot != t && !t.HasCapabilityIn(linux.CAP_SYS_RESOURCE, t.PIDNamespace().UserNamespace()) {
cred, tcred := t.Credentials(), ot.Credentials()
if cred.RealKUID != tcred.RealKUID ||
cred.RealKUID != tcred.EffectiveKUID ||
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 2919228d0..0760af77b 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -31,6 +31,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// minListenBacklog is the minimum reasonable backlog for listening sockets.
const minListenBacklog = 8
@@ -244,7 +246,11 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// Copy the file descriptors out.
if _, err := t.CopyOut(socks, fds); err != nil {
- // Note that we don't close files here; see pipe(2) also.
+ for _, fd := range fds {
+ if file, _ := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
return 0, nil, err
}
@@ -1128,3 +1134,5 @@ func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen)
return n, nil, err
}
+
+// LINT.ThenChange(./vfs2/socket.go)
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index fd642834b..df0d0f461 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -25,10 +26,15 @@ import (
// doSplice implements a blocking splice operation.
func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonBlocking bool) (int64, error) {
- if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 {
+ log.Infof("NLAC: doSplice opts: %+v", opts)
+ if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 || (opts.SrcStart+opts.Length < 0) {
return 0, syserror.EINVAL
}
+ if opts.Length > int64(kernel.MAX_RW_COUNT) {
+ opts.Length = int64(kernel.MAX_RW_COUNT)
+ }
+
var (
total int64
n int64
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index a11a87cd1..46ebf27a2 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -115,7 +115,8 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err
return err
}
s := statFromAttrs(t, d.Inode.StableAttr, uattr)
- return s.CopyOut(t, statAddr)
+ _, err = s.CopyOut(t, statAddr)
+ return err
}
// fstat implements fstat for the given *fs.File.
@@ -125,7 +126,8 @@ func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error {
return err
}
s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr)
- return s.CopyOut(t, statAddr)
+ _, err = s.CopyOut(t, statAddr)
+ return err
}
// Statx implements linux syscall statx(2).
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index 506ee54ce..6ec0de96e 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -87,8 +87,8 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
defer file.DecRef()
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 2eb210014..6ff2d84d2 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -18,34 +18,44 @@ go_library(
"linux64_override_arm64.go",
"mmap.go",
"path.go",
+ "pipe.go",
"poll.go",
"read_write.go",
"setstat.go",
+ "socket.go",
"stat.go",
"stat_amd64.go",
"stat_arm64.go",
"sync.go",
+ "sys_timerfd.go",
"xattr.go",
],
marshal = True,
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/binary",
"//pkg/bits",
"//pkg/fspath",
"//pkg/gohacks",
"//pkg/sentry/arch",
"//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/pipefs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/loader",
"//pkg/sentry/memmap",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/syscalls",
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/vfs",
"//pkg/sync",
+ "//pkg/syserr",
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
index d6cb0e79a..5a938cee2 100644
--- a/pkg/sentry/syscalls/linux/vfs2/epoll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -101,14 +101,14 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
var event linux.EpollEvent
switch op {
case linux.EPOLL_CTL_ADD:
- if err := event.CopyIn(t, eventAddr); err != nil {
+ if _, err := event.CopyIn(t, eventAddr); err != nil {
return 0, nil, err
}
return 0, nil, ep.AddInterest(file, fd, event)
case linux.EPOLL_CTL_DEL:
return 0, nil, ep.DeleteInterest(file, fd)
case linux.EPOLL_CTL_MOD:
- if err := event.CopyIn(t, eventAddr); err != nil {
+ if _, err := event.CopyIn(t, eventAddr); err != nil {
return 0, nil, err
}
return 0, nil, ep.ModifyInterest(file, fd, event)
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 3afcea665..8181d80f4 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -140,6 +141,22 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(file.StatusFlags()), nil, nil
case linux.F_SETFL:
return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint())
+ case linux.F_SETPIPE_SZ:
+ pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
+ if !ok {
+ return 0, nil, syserror.EBADF
+ }
+ n, err := pipefile.SetPipeSize(int64(args[2].Int()))
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+ case linux.F_GETPIPE_SZ:
+ pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
+ if !ok {
+ return 0, nil, syserror.EBADF
+ }
+ return uintptr(pipefile.PipeSize()), nil, nil
default:
// TODO(gvisor.dev/issue/1623): Everything else is not yet supported.
return 0, nil, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
index a859095e2..46d3e189c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -172,7 +172,7 @@ func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mo
defer tpop.Release()
file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{
- Flags: flags,
+ Flags: flags | linux.O_LARGEFILE,
Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()),
})
if err != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
index a61cc5059..62e98817d 100644
--- a/pkg/sentry/syscalls/linux/vfs2/getdents.go
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -97,6 +97,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
// char d_name[]; /* Filename (null-terminated) */
// };
size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
+ size = (size + 7) &^ 7 // round up to multiple of 8
if size > cb.remaining {
return syserror.EINVAL
}
@@ -106,7 +107,12 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
buf[18] = dirent.Type
copy(buf[19:], dirent.Name)
- buf[size-1] = 0 // NUL terminator
+ // Zero out all remaining bytes in buf, including the NUL terminator
+ // after dirent.Name.
+ bufTail := buf[19+len(dirent.Name):]
+ for i := range bufTail {
+ bufTail[i] = 0
+ }
} else {
// struct linux_dirent {
// unsigned long d_ino; /* Inode number */
@@ -125,6 +131,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width()))
}
size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name)
+ size = (size + 7) &^ 7 // round up to multiple of sizeof(long)
if size > cb.remaining {
return syserror.EINVAL
}
@@ -133,9 +140,14 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
copy(buf[18:], dirent.Name)
- buf[size-3] = 0 // NUL terminator
- buf[size-2] = 0 // zero padding byte
- buf[size-1] = dirent.Type
+ // Zero out all remaining bytes in buf, including the NUL terminator
+ // after dirent.Name and the zero padding byte between the name and
+ // dirent type.
+ bufTail := buf[18+len(dirent.Name):]
+ for i := range bufTail {
+ bufTail[i] = 0
+ }
+ bufTail[2] = dirent.Type
}
n, err := cb.t.CopyOutBytes(cb.addr, buf)
if err != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
index 7d220bc20..21eb98444 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
@@ -39,26 +39,27 @@ func Override(table map[uintptr]kernel.Syscall) {
table[19] = syscalls.Supported("readv", Readv)
table[20] = syscalls.Supported("writev", Writev)
table[21] = syscalls.Supported("access", Access)
- delete(table, 22) // pipe
+ table[22] = syscalls.Supported("pipe", Pipe)
table[23] = syscalls.Supported("select", Select)
table[32] = syscalls.Supported("dup", Dup)
table[33] = syscalls.Supported("dup2", Dup2)
delete(table, 40) // sendfile
- delete(table, 41) // socket
- delete(table, 42) // connect
- delete(table, 43) // accept
- delete(table, 44) // sendto
- delete(table, 45) // recvfrom
- delete(table, 46) // sendmsg
- delete(table, 47) // recvmsg
- delete(table, 48) // shutdown
- delete(table, 49) // bind
- delete(table, 50) // listen
- delete(table, 51) // getsockname
- delete(table, 52) // getpeername
- delete(table, 53) // socketpair
- delete(table, 54) // setsockopt
- delete(table, 55) // getsockopt
+ // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2.
+ table[41] = syscalls.PartiallySupported("socket", Socket, "In process of porting socket syscalls to VFS2.", nil)
+ table[42] = syscalls.PartiallySupported("connect", Connect, "In process of porting socket syscalls to VFS2.", nil)
+ table[43] = syscalls.PartiallySupported("accept", Accept, "In process of porting socket syscalls to VFS2.", nil)
+ table[44] = syscalls.PartiallySupported("sendto", SendTo, "In process of porting socket syscalls to VFS2.", nil)
+ table[45] = syscalls.PartiallySupported("recvfrom", RecvFrom, "In process of porting socket syscalls to VFS2.", nil)
+ table[46] = syscalls.PartiallySupported("sendmsg", SendMsg, "In process of porting socket syscalls to VFS2.", nil)
+ table[47] = syscalls.PartiallySupported("recvmsg", RecvMsg, "In process of porting socket syscalls to VFS2.", nil)
+ table[48] = syscalls.PartiallySupported("shutdown", Shutdown, "In process of porting socket syscalls to VFS2.", nil)
+ table[49] = syscalls.PartiallySupported("bind", Bind, "In process of porting socket syscalls to VFS2.", nil)
+ table[50] = syscalls.PartiallySupported("listen", Listen, "In process of porting socket syscalls to VFS2.", nil)
+ table[51] = syscalls.PartiallySupported("getsockname", GetSockName, "In process of porting socket syscalls to VFS2.", nil)
+ table[52] = syscalls.PartiallySupported("getpeername", GetPeerName, "In process of porting socket syscalls to VFS2.", nil)
+ table[53] = syscalls.PartiallySupported("socketpair", SocketPair, "In process of porting socket syscalls to VFS2.", nil)
+ table[54] = syscalls.PartiallySupported("setsockopt", SetSockOpt, "In process of porting socket syscalls to VFS2.", nil) // amd64 syscall 54 is setsockopt
+ table[55] = syscalls.PartiallySupported("getsockopt", GetSockOpt, "In process of porting socket syscalls to VFS2.", nil) // amd64 syscall 55 is getsockopt
table[59] = syscalls.Supported("execve", Execve)
table[72] = syscalls.Supported("fcntl", Fcntl)
delete(table, 73) // flock
@@ -139,23 +140,26 @@ func Override(table map[uintptr]kernel.Syscall) {
table[280] = syscalls.Supported("utimensat", Utimensat)
table[281] = syscalls.Supported("epoll_pwait", EpollPwait)
delete(table, 282) // signalfd
- delete(table, 283) // timerfd_create
+ table[283] = syscalls.Supported("timerfd_create", TimerfdCreate)
delete(table, 284) // eventfd
delete(table, 285) // fallocate
- delete(table, 286) // timerfd_settime
- delete(table, 287) // timerfd_gettime
- delete(table, 288) // accept4
+ table[286] = syscalls.Supported("timerfd_settime", TimerfdSettime)
+ table[287] = syscalls.Supported("timerfd_gettime", TimerfdGettime)
+ // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2.
+ table[288] = syscalls.PartiallySupported("accept4", Accept4, "In process of porting socket syscalls to VFS2.", nil)
delete(table, 289) // signalfd4
delete(table, 290) // eventfd2
table[291] = syscalls.Supported("epoll_create1", EpollCreate1)
table[292] = syscalls.Supported("dup3", Dup3)
- delete(table, 293) // pipe2
+ table[293] = syscalls.Supported("pipe2", Pipe2)
delete(table, 294) // inotify_init1
table[295] = syscalls.Supported("preadv", Preadv)
table[296] = syscalls.Supported("pwritev", Pwritev)
- delete(table, 299) // recvmmsg
+ // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2.
+ table[299] = syscalls.PartiallySupported("recvmmsg", RecvMMsg, "In process of porting socket syscalls to VFS2.", nil)
table[306] = syscalls.Supported("syncfs", Syncfs)
- delete(table, 307) // sendmmsg
+ // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2.
+ table[307] = syscalls.PartiallySupported("sendmmsg", SendMMsg, "In process of porting socket syscalls to VFS2.", nil)
table[316] = syscalls.Supported("renameat2", Renameat2)
delete(table, 319) // memfd_create
table[322] = syscalls.Supported("execveat", Execveat)
diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go
new file mode 100644
index 000000000..4a01e4209
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Pipe implements Linux syscall pipe(2). It is pipe2 with no flags set.
+func Pipe(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer() // user address that receives the two new FDs
+ return 0, nil, pipe2(t, addr, 0)
+}
+
+// Pipe2 implements Linux syscall pipe2(2). The work is done by pipe2.
+func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer() // user address that receives the two new FDs
+ flags := args[1].Int() // validated by pipe2: only O_NONBLOCK and O_CLOEXEC are accepted
+ return 0, nil, pipe2(t, addr, flags)
+}
+
+func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { // shared implementation of Pipe and Pipe2
+ if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { // reject any flag other than O_NONBLOCK/O_CLOEXEC
+ return syserror.EINVAL
+ }
+ r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK)) // only O_NONBLOCK is passed through here
+ defer r.DecRef()
+ defer w.DecRef()
+
+ fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{r, w}, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0, // O_CLOEXEC becomes a per-FD flag
+ })
+ if err != nil {
+ return err
+ }
+ if _, err := t.CopyOut(addr, fds); err != nil { // copy-out failed: remove the FDs just installed, then return the error
+ for _, fd := range fds {
+ if _, file := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
+ return err
+ }
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
index dbf4882da..ff1b25d7b 100644
--- a/pkg/sentry/syscalls/linux/vfs2/poll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -374,7 +374,8 @@ func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.D
}
remaining := timeoutRemaining(t, startNs, timeout)
tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
- return tsRemaining.CopyOut(t, timespecAddr)
+ _, err := tsRemaining.CopyOut(t, timespecAddr)
+ return err
}
// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
@@ -386,7 +387,8 @@ func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Du
}
remaining := timeoutRemaining(t, startNs, timeout)
tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
- return tvRemaining.CopyOut(t, timevalAddr)
+ _, err := tvRemaining.CopyOut(t, timevalAddr)
+ return err
}
// pollRestartBlock encapsulates the state required to restart poll(2) via
@@ -477,7 +479,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
timeout := time.Duration(-1)
if timevalAddr != 0 {
var timeval linux.Timeval
- if err := timeval.CopyIn(t, timevalAddr); err != nil {
+ if _, err := timeval.CopyIn(t, timevalAddr); err != nil {
return 0, nil, err
}
if timeval.Sec < 0 || timeval.Usec < 0 {
@@ -519,7 +521,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width()))
}
var maskStruct sigSetWithSize
- if err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
+ if _, err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
return 0, nil, err
}
if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
@@ -554,7 +556,7 @@ func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.D
timeout := time.Duration(-1)
if timespecAddr != 0 {
var timespec linux.Timespec
- if err := timespec.CopyIn(t, timespecAddr); err != nil {
+ if _, err := timespec.CopyIn(t, timespecAddr); err != nil {
return 0, err
}
if !timespec.Valid() {
@@ -573,7 +575,7 @@ func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) erro
return syserror.EINVAL
}
var mask linux.SignalSet
- if err := mask.CopyIn(t, maskAddr); err != nil {
+ if _, err := mask.CopyIn(t, maskAddr); err != nil {
return err
}
mask &^= kernel.UnblockableSignals
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
index 35f6308d6..6c6998f45 100644
--- a/pkg/sentry/syscalls/linux/vfs2/read_write.go
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -103,7 +103,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt
// Issue the request and break out if it completes with anything other than
// "would block".
- n, err := file.Read(t, dst, opts)
+ n, err = file.Read(t, dst, opts)
total += n
if err != syserror.ErrWouldBlock {
break
@@ -130,8 +130,8 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
defer file.DecRef()
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
@@ -248,7 +248,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of
// Issue the request and break out if it completes with anything other than
// "would block".
- n, err := file.PRead(t, dst, offset+total, opts)
+ n, err = file.PRead(t, dst, offset+total, opts)
total += n
if err != syserror.ErrWouldBlock {
break
@@ -335,7 +335,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op
// Issue the request and break out if it completes with anything other than
// "would block".
- n, err := file.Write(t, src, opts)
+ n, err = file.Write(t, src, opts)
total += n
if err != syserror.ErrWouldBlock {
break
@@ -362,8 +362,8 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
defer file.DecRef()
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
@@ -480,7 +480,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o
// Issue the request and break out if it completes with anything other than
// "would block".
- n, err := file.PWrite(t, src, offset+total, opts)
+ n, err = file.PWrite(t, src, offset+total, opts)
total += n
if err != syserror.ErrWouldBlock {
break
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 136453ccc..4e61f1452 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -226,7 +226,7 @@ func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
opts.Stat.Mtime.Nsec = linux.UTIME_NOW
} else {
var times linux.Utime
- if err := times.CopyIn(t, timesAddr); err != nil {
+ if _, err := times.CopyIn(t, timesAddr); err != nil {
return 0, nil, err
}
opts.Stat.Atime.Sec = times.Actime
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
new file mode 100644
index 000000000..b1ede32f0
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -0,0 +1,1139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// minListenBacklog is the minimum reasonable backlog for listening sockets.
+const minListenBacklog = 8
+
+// maxListenBacklog is the maximum allowed backlog for listening sockets.
+const maxListenBacklog = 1024
+
+// maxAddrLen is the maximum socket address length we're willing to accept.
+const maxAddrLen = 200
+
+// maxOptLen is the maximum sockopt parameter length we're willing to accept.
+const maxOptLen = 1024 * 8
+
+// maxControlLen is the maximum length of the msghdr.msg_control buffer we're
+// willing to accept. Note that this limit is smaller than Linux, which allows
+// buffers up to INT_MAX.
+const maxControlLen = 10 * 1024 * 1024
+
+// nameLenOffset is the offset from the start of the MessageHeader64 struct to
+// the NameLen field.
+const nameLenOffset = 8
+
+// controlLenOffset is the offset from the start of the MessageHeader64 struct
+// to the ControlLen field.
+const controlLenOffset = 40
+
+// flagsOffset is the offset from the start of the MessageHeader64 struct
+// to the Flags field.
+const flagsOffset = 48
+
+const sizeOfInt32 = 4 // size of an int32 in bytes
+
+// messageHeader64Len is the length of a MessageHeader64 struct.
+var messageHeader64Len = uint64(binary.Size(MessageHeader64{}))
+
+// multipleMessageHeader64Len is the length of a multipleMessageHeader64 struct.
+var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{}))
+
+// baseRecvFlags are the flags that are accepted across recvmsg(2),
+// recvmmsg(2), and recvfrom(2).
+const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | linux.MSG_NOSIGNAL | linux.MSG_WAITALL | linux.MSG_TRUNC | linux.MSG_CTRUNC
+
+// MessageHeader64 is the 64-bit representation of the msghdr struct used in
+// the recvmsg and sendmsg syscalls.
+type MessageHeader64 struct {
+ // Name is the optional pointer to a network address buffer.
+ Name uint64
+
+ // NameLen is the length of the buffer pointed to by Name.
+ NameLen uint32
+ _ uint32 // padding: CopyInMessageHeader64 reads Iov at byte offset 16
+
+ // Iov is a pointer to an array of io vectors that describe the memory
+ // locations involved in the io operation.
+ Iov uint64
+
+ // IovLen is the length of the array pointed to by Iov.
+ IovLen uint64
+
+ // Control is the optional pointer to ancillary control data.
+ Control uint64
+
+ // ControlLen is the length of the data pointed to by Control.
+ ControlLen uint64
+
+ // Flags on the sent/received message.
+ Flags int32
+ _ int32 // unnamed trailing field; not read by CopyInMessageHeader64
+}
+
+// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
+// the recvmmsg and sendmmsg syscalls.
+type multipleMessageHeader64 struct {
+ msgHdr MessageHeader64 // the msghdr representation for this entry
+ msgLen uint32
+ _ int32 // unnamed trailing field
+}
+
+// CopyInMessageHeader64 copies a message header from user to kernel memory,
+// decoding each field of msg from its fixed byte offset in the copied buffer.
+func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error {
+ b := t.CopyScratchBuffer(52) // 52 bytes: everything up to and including Flags (offset 48, 4 bytes)
+ if _, err := t.CopyInBytes(addr, b); err != nil {
+ return err
+ }
+
+ msg.Name = usermem.ByteOrder.Uint64(b[0:])
+ msg.NameLen = usermem.ByteOrder.Uint32(b[8:]) // bytes 12-15 are not decoded
+ msg.Iov = usermem.ByteOrder.Uint64(b[16:])
+ msg.IovLen = usermem.ByteOrder.Uint64(b[24:])
+ msg.Control = usermem.ByteOrder.Uint64(b[32:])
+ msg.ControlLen = usermem.ByteOrder.Uint64(b[40:])
+ msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:]))
+
+ return nil
+}
+
+// CaptureAddress allocates memory for and copies a socket address structure
+// from the untrusted address space range. It returns EINVAL if addrlen exceeds maxAddrLen.
+func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
+ if addrlen > maxAddrLen { // bound the allocation below by maxAddrLen
+ return nil, syserror.EINVAL
+ }
+
+ addrBuf := make([]byte, addrlen)
+ if _, err := t.CopyInBytes(addr, addrBuf); err != nil {
+ return nil, err
+ }
+
+ return addrBuf, nil
+}
+
+// writeAddress writes a sockaddr structure and its length to an output buffer
+// in the untrusted address space range. If the address is bigger than the
+// buffer, it is truncated.
+func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+ // Get the buffer length.
+ var bufLen uint32
+ if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil {
+ return err
+ }
+
+ if int32(bufLen) < 0 { // a length that is negative as an int32 is rejected
+ return syserror.EINVAL
+ }
+
+ // Write the length unconditionally.
+ if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil {
+ return err
+ }
+
+ if addr == nil { // no address to write; only the length was reported
+ return nil
+ }
+
+ if bufLen > addrLen { // never copy more than addrLen bytes
+ bufLen = addrLen
+ }
+
+ // Copy as much of the address as will fit in the buffer.
+ encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr)
+ if bufLen > uint32(len(encodedAddr)) { // nor more than the encoded address actually holds
+ bufLen = uint32(len(encodedAddr))
+ }
+ _, err := t.CopyOutBytes(addrPtr, encodedAddr[:int(bufLen)])
+ return err
+}
+
+// Socket implements the linux syscall socket(2). It returns the new FD.
+func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int() // socket type in the low 4 bits, optionally ORed with SOCK_NONBLOCK/SOCK_CLOEXEC
+ protocol := int(args[2].Int())
+
+ // Check the flags: only the type bits (0xf), SOCK_NONBLOCK and SOCK_CLOEXEC may be set.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create the new socket.
+ s, e := socket.NewVFS2(t, domain, linux.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ defer s.DecRef()
+
+ if err := s.SetStatusFlags(t, t.Credentials(), uint32(stype&linux.SOCK_NONBLOCK)); err != nil {
+ return 0, nil, err
+ }
+
+ fd, err := t.NewFDFromVFS2(0, s, kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0, // SOCK_CLOEXEC becomes a per-FD flag
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// SocketPair implements the linux syscall socketpair(2).
+func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int() // socket type in the low 4 bits, optionally ORed with SOCK_NONBLOCK/SOCK_CLOEXEC
+ protocol := int(args[2].Int())
+ addr := args[3].Pointer() // user address that receives the two new FDs
+
+ // Check the flags: only the type bits (0xf), SOCK_NONBLOCK and SOCK_CLOEXEC may be set.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create the socket pair.
+ s1, s2, e := socket.PairVFS2(t, domain, linux.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ // Adding to the FD table will cause an extra reference to be acquired.
+ defer s1.DecRef()
+ defer s2.DecRef()
+
+ nonblocking := uint32(stype & linux.SOCK_NONBLOCK) // applied as a status flag to both ends
+ if err := s1.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil {
+ return 0, nil, err
+ }
+ if err := s2.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil {
+ return 0, nil, err
+ }
+
+ // Create the FDs for the sockets.
+ flags := kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ }
+ fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{s1, s2}, flags)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if _, err := t.CopyOut(addr, fds); err != nil { // copy-out failed: remove the FDs just installed, then return the error
+ for _, fd := range fds {
+ if _, file := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// Connect implements the linux syscall connect(2).
+func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer() // user address of the sockaddr to connect to
+ addrlen := args[2].Uint() // length of that sockaddr; bounded by CaptureAddress
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0 // blocking unless SOCK_NONBLOCK is set in the status flags
+ return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS)
+}
+
+// accept is the implementation of the accept syscall. It is called by accept
+// and accept4 syscall handlers. On success it returns the new FD.
+func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) {
+ // Check that no unsupported flags are passed in.
+ if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ // Call the syscall implementation for this socket, then copy the
+ // output address if one is specified.
+ blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0
+
+ peerRequested := addrLen != 0 // the peer address is requested only when an addrLen pointer is supplied
+ nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ if peerRequested {
+ // NOTE(magi): Linux does not give you an error if it can't
+ // write the data back out so neither do we; only EINVAL is returned.
+ if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL {
+ return 0, err
+ }
+ }
+ return uintptr(nfd), nil
+}
+
+// Accept4 implements the linux syscall accept4(2).
+func Accept4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer() // user buffer for the peer address
+ addrlen := args[2].Pointer() // user pointer to the buffer length; 0 means no peer address is requested
+ flags := int(args[3].Int()) // validated by accept: only SOCK_NONBLOCK and SOCK_CLOEXEC are accepted
+
+ n, err := accept(t, fd, addr, addrlen, flags)
+ return n, nil, err
+}
+
+// Accept implements the linux syscall accept(2).
+func Accept(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer() // user buffer for the peer address
+ addrlen := args[2].Pointer() // user pointer to the buffer length; 0 means no peer address is requested
+
+ n, err := accept(t, fd, addr, addrlen, 0) // same path as Accept4, with no flags
+ return n, nil, err
+}
+
+// Bind implements the linux syscall bind(2).
+func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer() // user address of the sockaddr to bind to
+ addrlen := args[2].Uint() // length of that sockaddr; bounded by CaptureAddress
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, s.Bind(t, a).ToError()
+}
+
+// Listen implements the linux syscall listen(2).
+//
+// The requested backlog is clamped before being passed to the socket's Listen
+// implementation.
+func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd      = args[0].Int()
+		backlog = args[1].Int()
+	)
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, nil, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be backed by a socket implementation.
+	sock, isSock := fdesc.Impl().(socket.SocketVFS2)
+	if !isSock {
+		return 0, nil, syserror.ENOTSOCK
+	}
+
+	// Per Linux, the backlog is silently capped to reasonable values:
+	// non-positive requests are raised to the minimum first, then anything
+	// above the maximum is lowered to it.
+	if backlog <= 0 {
+		backlog = minListenBacklog
+	}
+	if backlog > maxListenBacklog {
+		backlog = maxListenBacklog
+	}
+
+	listenErr := sock.Listen(t, int(backlog))
+	return 0, nil, listenErr.ToError()
+}
+
+// Shutdown implements the linux syscall shutdown(2).
+//
+// Only SHUT_RD, SHUT_WR and SHUT_RDWR are accepted for how.
+func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd  = args[0].Int()
+		how = args[1].Int()
+	)
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, nil, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be backed by a socket implementation.
+	sock, isSock := fdesc.Impl().(socket.SocketVFS2)
+	if !isSock {
+		return 0, nil, syserror.ENOTSOCK
+	}
+
+	// Reject any value of how outside the three known modes.
+	if how != linux.SHUT_RD && how != linux.SHUT_WR && how != linux.SHUT_RDWR {
+		return 0, nil, syserror.EINVAL
+	}
+
+	shutErr := sock.Shutdown(t, int(how))
+	return 0, nil, shutErr.ToError()
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2).
+//
+// The option length is read from optLenAddr and must be non-negative. On
+// success the size of the option value is written back to optLenAddr and the
+// value itself, when there is one, to optValAddr.
+func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd         = args[0].Int()
+		level      = args[1].Int()
+		name       = args[2].Int()
+		optValAddr = args[3].Pointer()
+		optLenAddr = args[4].Pointer()
+	)
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, nil, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be backed by a socket implementation.
+	sock, isSock := fdesc.Impl().(socket.SocketVFS2)
+	if !isSock {
+		return 0, nil, syserror.ENOTSOCK
+	}
+
+	// Fetch the caller-provided length; negative lengths are invalid.
+	var optLen int32
+	if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+		return 0, nil, err
+	}
+	if optLen < 0 {
+		return 0, nil, syserror.EINVAL
+	}
+
+	val, sockErr := getSockOpt(t, sock, int(level), int(name), optValAddr, int(optLen))
+	if sockErr != nil {
+		return 0, nil, sockErr.ToError()
+	}
+
+	// The length goes out first, followed by the value if one was produced.
+	if _, err := t.CopyOut(optLenAddr, int32(binary.Size(val))); err != nil {
+		return 0, nil, err
+	}
+	if val == nil {
+		return 0, nil, nil
+	}
+	if _, err := t.CopyOut(optValAddr, val); err != nil {
+		return 0, nil, err
+	}
+	return 0, nil, nil
+}
+
+// getSockOpt answers the SOL_SOCKET options SO_TYPE, SO_DOMAIN and
+// SO_PROTOCOL directly from s.Type(); every other option is dispatched to the
+// socket implementation's GetSockOpt.
+func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+	if level == linux.SOL_SOCKET {
+		switch name {
+		case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
+			// All three options are returned as an int32, so the
+			// caller's buffer must be able to hold one.
+			if len < sizeOfInt32 {
+				return nil, syserr.ErrInvalidArgument
+			}
+
+			family, skType, protocol := s.Type()
+			switch name {
+			case linux.SO_TYPE:
+				return int32(skType), nil
+			case linux.SO_DOMAIN:
+				return int32(family), nil
+			default:
+				// Only SO_PROTOCOL remains.
+				return int32(protocol), nil
+			}
+		}
+	}
+
+	return s.GetSockOpt(t, level, name, optValAddr, len)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2).
+//
+// Note that unlike Linux, enabling SO_PASSCRED does not autobind the socket.
+func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd         = args[0].Int()
+		level      = args[1].Int()
+		name       = args[2].Int()
+		optValAddr = args[3].Pointer()
+		optLen     = args[4].Int()
+	)
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, nil, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be backed by a socket implementation.
+	sock, isSock := fdesc.Impl().(socket.SocketVFS2)
+	if !isSock {
+		return 0, nil, syserror.ENOTSOCK
+	}
+
+	// The option length must lie within [0, maxOptLen].
+	if optLen < 0 || optLen > maxOptLen {
+		return 0, nil, syserror.EINVAL
+	}
+
+	// Pull the option value into a scratch buffer of the requested size.
+	optVal := t.CopyScratchBuffer(int(optLen))
+	if _, err := t.CopyIn(optValAddr, &optVal); err != nil {
+		return 0, nil, err
+	}
+
+	if sockErr := sock.SetSockOpt(t, int(level), int(name), optVal); sockErr != nil {
+		return 0, nil, sockErr.ToError()
+	}
+	return 0, nil, nil
+}
+
+// GetSockName implements the linux syscall getsockname(2).
+//
+// The socket's own address is fetched and written back to the caller through
+// writeAddress.
+func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd      = args[0].Int()
+		addr    = args[1].Pointer()
+		addrlen = args[2].Pointer()
+	)
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, nil, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be backed by a socket implementation.
+	sock, isSock := fdesc.Impl().(socket.SocketVFS2)
+	if !isSock {
+		return 0, nil, syserror.ENOTSOCK
+	}
+
+	name, nameLen, sockErr := sock.GetSockName(t)
+	if sockErr != nil {
+		return 0, nil, sockErr.ToError()
+	}
+
+	writeErr := writeAddress(t, name, nameLen, addr, addrlen)
+	return 0, nil, writeErr
+}
+
+// GetPeerName implements the linux syscall getpeername(2).
+//
+// The address of the socket's peer is fetched and written back to the caller
+// through writeAddress.
+func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd      = args[0].Int()
+		addr    = args[1].Pointer()
+		addrlen = args[2].Pointer()
+	)
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, nil, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be backed by a socket implementation.
+	sock, isSock := fdesc.Impl().(socket.SocketVFS2)
+	if !isSock {
+		return 0, nil, syserror.ENOTSOCK
+	}
+
+	name, nameLen, sockErr := sock.GetPeerName(t)
+	if sockErr != nil {
+		return 0, nil, sockErr.ToError()
+	}
+
+	writeErr := writeAddress(t, name, nameLen, addr, addrlen)
+	return 0, nil, writeErr
+}
+
+// RecvMsg implements the linux syscall recvmsg(2).
+//
+// After validating the descriptor and flags it derives the blocking behaviour
+// from the file status flags and the socket's receive timeout, then delegates
+// to recvSingleMsg.
+func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd     = args[0].Int()
+		msgPtr = args[1].Pointer()
+		flags  = args[2].Int()
+	)
+
+	// We only handle 64-bit for now.
+	if t.Arch().Width() != 8 {
+		return 0, nil, syserror.EINVAL
+	}
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, nil, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be backed by a socket implementation.
+	sock, isSock := fdesc.Impl().(socket.SocketVFS2)
+	if !isSock {
+		return 0, nil, syserror.ENOTSOCK
+	}
+
+	// Reject flags that we don't handle yet.
+	if flags&^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+		return 0, nil, syserror.EINVAL
+	}
+
+	// A non-blocking descriptor behaves as if MSG_DONTWAIT were passed.
+	if fdesc.StatusFlags()&linux.SOCK_NONBLOCK != 0 {
+		flags |= linux.MSG_DONTWAIT
+	}
+
+	// A positive receive timeout becomes a deadline on the monotonic clock;
+	// a negative one forces a non-blocking receive.
+	var (
+		haveDeadline bool
+		deadline     ktime.Time
+	)
+	switch dl := sock.RecvTimeout(); {
+	case dl > 0:
+		deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+		haveDeadline = true
+	case dl < 0:
+		flags |= linux.MSG_DONTWAIT
+	}
+
+	received, err := recvSingleMsg(t, sock, msgPtr, flags, haveDeadline, deadline)
+	return received, nil, err
+}
+
+// RecvMMsg implements the linux syscall recvmmsg(2).
+//
+// It walks up to vlen message headers starting at msgPtr, spaced
+// multipleMessageHeader64Len apart, receiving one message per header via
+// recvSingleMsg and storing each received length messageHeader64Len bytes
+// past the start of its header. The result is the number of messages
+// received. An error from receiving or from copying out a length stops the
+// loop and is only reported when no message was received before it; an
+// address overflow while computing a header or length pointer returns EFAULT
+// immediately.
+func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+ toPtr := args[4].Pointer()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // A non-blocking descriptor behaves as if MSG_DONTWAIT were passed.
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // A timeout supplied by the caller through toPtr takes precedence over
+ // the socket's own receive timeout. The deadline is computed once and
+ // shared by every message in the batch.
+ var haveDeadline bool
+ var deadline ktime.Time
+ if toPtr != 0 {
+ var ts linux.Timespec
+ if _, err := ts.CopyIn(t, toPtr); err != nil {
+ return 0, nil, err
+ }
+ if !ts.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+ deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration())
+ haveDeadline = true
+ }
+
+ // Without an explicit timeout, fall back to the socket's receive
+ // timeout: positive values become a deadline, negative values force a
+ // non-blocking receive.
+ if !haveDeadline {
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ // Locate the i-th header; AddLength reports address overflow.
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ var n uintptr
+ if n, err = recvSingleMsg(t, s, mp, flags, haveDeadline, deadline); err != nil {
+ break
+ }
+
+ // Copy the received length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
+ // Only surface the error (possibly nil when vlen is zero) if nothing
+ // was received; otherwise report the partial count.
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+// recvSingleMsg receives one message on s into the io vectors described by
+// the message header at msgPtr. It writes the resulting message flags back to
+// the header and, when the header asks for them, the sender address and
+// packed control data as well. It returns the received length reported by
+// s.RecvMsg.
+//
+// haveDeadline and deadline are passed through to s.RecvMsg unchanged.
+func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
+ // Capture the message header and io vectors.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
+
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syserror.EMSGSIZE
+ }
+ dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ // FIXME(b/63594852): Pretend we have an empty error queue.
+ if flags&linux.MSG_ERRQUEUE != 0 {
+ return 0, syserror.EAGAIN
+ }
+
+ // Fast path when no control message nor name buffers are provided.
+ if msg.ControlLen == 0 && msg.NameLen == 0 {
+ n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
+ if err != nil {
+ return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS)
+ }
+ // Unix control messages arrived but the caller gave no control
+ // buffer: flag the truncation and release them.
+ if !cms.Unix.Empty() {
+ mflags |= linux.MSG_CTRUNC
+ cms.Release()
+ }
+
+ // Skip the copy-out when the header already holds these flags.
+ if int(msg.Flags) != mflags {
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+ }
+
+ return uintptr(n), nil
+ }
+
+ // Bound the control buffer size; it is used as an allocation capacity
+ // below.
+ if msg.ControlLen > maxControlLen {
+ return 0, syserror.ENOBUFS
+ }
+ n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ defer cms.Release()
+
+ // Pack the received control messages into a buffer whose capacity is
+ // the caller's control length.
+ controlData := make([]byte, 0, msg.ControlLen)
+ controlData = control.PackControlMessages(t, cms, controlData)
+
+ // Credentials are only packed for sockets that implement
+ // transport.Credentialer and report Passcred.
+ if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
+ creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
+ controlData, mflags = control.PackCredentials(t, creds, controlData, mflags)
+ }
+
+ if cms.Unix.Rights != nil {
+ controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags)
+ }
+
+ // Copy the address to the caller.
+ if msg.NameLen != 0 {
+ if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy the control data to the caller.
+ // The packed length is always written, even when it is zero.
+ if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
+ return 0, err
+ }
+ if len(controlData) > 0 {
+ if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+
+ return uintptr(n), nil
+}
+
+// recvFrom is the implementation of the recvfrom syscall. It is called by
+// recvfrom and recv syscall handlers.
+//
+// Data is received into the single buffer at bufPtr. When nameLenPtr is
+// non-zero the sender address is requested and written back through namePtr
+// and nameLenPtr.
+func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) {
+	// A length that turns negative when viewed as an int is invalid.
+	if int(bufLen) < 0 {
+		return 0, syserror.EINVAL
+	}
+
+	// Reject flags that we don't handle yet.
+	if flags&^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 {
+		return 0, syserror.EINVAL
+	}
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be backed by a socket implementation.
+	sock, isSock := fdesc.Impl().(socket.SocketVFS2)
+	if !isSock {
+		return 0, syserror.ENOTSOCK
+	}
+
+	// A non-blocking descriptor behaves as if MSG_DONTWAIT were passed.
+	if fdesc.StatusFlags()&linux.SOCK_NONBLOCK != 0 {
+		flags |= linux.MSG_DONTWAIT
+	}
+
+	ioOpts := usermem.IOOpts{
+		AddressSpaceActive: true,
+	}
+	dst, err := t.SingleIOSequence(bufPtr, int(bufLen), ioOpts)
+	if err != nil {
+		return 0, err
+	}
+
+	// A positive receive timeout becomes a deadline on the monotonic clock;
+	// a negative one forces a non-blocking receive.
+	var (
+		haveDeadline bool
+		deadline     ktime.Time
+	)
+	switch dl := sock.RecvTimeout(); {
+	case dl > 0:
+		deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+		haveDeadline = true
+	case dl < 0:
+		flags |= linux.MSG_DONTWAIT
+	}
+
+	wantSender := nameLenPtr != 0
+	n, _, sender, senderLen, cm, e := sock.RecvMsg(t, dst, int(flags), haveDeadline, deadline, wantSender, 0)
+	// Control messages are never returned by this call; release them
+	// before looking at the result.
+	cm.Release()
+	if e != nil {
+		return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+	}
+
+	if !wantSender {
+		return uintptr(n), nil
+	}
+
+	// Copy the address to the caller.
+	if err := writeAddress(t, sender, senderLen, namePtr, nameLenPtr); err != nil {
+		return 0, err
+	}
+	return uintptr(n), nil
+}
+
+// RecvFrom implements the linux syscall recvfrom(2).
+//
+// It decodes the six syscall arguments and forwards them to recvFrom.
+func RecvFrom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd         = args[0].Int()
+		bufPtr     = args[1].Pointer()
+		bufLen     = args[2].Uint64()
+		flags      = args[3].Int()
+		namePtr    = args[4].Pointer()
+		nameLenPtr = args[5].Pointer()
+	)
+
+	received, err := recvFrom(t, fd, bufPtr, bufLen, flags, namePtr, nameLenPtr)
+	return received, nil, err
+}
+
+// SendMsg implements the linux syscall sendmsg(2).
+//
+// After validating the descriptor and flags it delegates to sendSingleMsg.
+func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd     = args[0].Int()
+		msgPtr = args[1].Pointer()
+		flags  = args[2].Int()
+	)
+
+	// We only handle 64-bit for now.
+	if t.Arch().Width() != 8 {
+		return 0, nil, syserror.EINVAL
+	}
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, nil, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be backed by a socket implementation.
+	sock, isSock := fdesc.Impl().(socket.SocketVFS2)
+	if !isSock {
+		return 0, nil, syserror.ENOTSOCK
+	}
+
+	// Reject flags that we don't handle yet.
+	if flags&^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+		return 0, nil, syserror.EINVAL
+	}
+
+	// A non-blocking descriptor behaves as if MSG_DONTWAIT were passed.
+	if fdesc.StatusFlags()&linux.SOCK_NONBLOCK != 0 {
+		flags |= linux.MSG_DONTWAIT
+	}
+
+	sent, err := sendSingleMsg(t, sock, fdesc, msgPtr, flags)
+	return sent, nil, err
+}
+
+// SendMMsg implements the linux syscall sendmmsg(2).
+//
+// It walks up to vlen message headers starting at msgPtr, spaced
+// multipleMessageHeader64Len apart, sending one message per header via
+// sendSingleMsg and storing each sent length messageHeader64Len bytes past
+// the start of its header. The result is the number of messages sent. An
+// error from sending or from copying out a length stops the loop and is only
+// reported when no message was sent before it; an address overflow while
+// computing a header or length pointer returns EFAULT immediately.
+func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // A non-blocking descriptor behaves as if MSG_DONTWAIT were passed.
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ // Locate the i-th header; AddLength reports address overflow.
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ var n uintptr
+ if n, err = sendSingleMsg(t, s, file, mp, flags); err != nil {
+ break
+ }
+
+ // Copy the sent length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
+ // Only surface the error (possibly nil when vlen is zero) if nothing
+ // was sent; otherwise report the partial count.
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+// sendSingleMsg sends the one message described by the message header at
+// msgPtr on s. It copies in the optional control data and destination
+// address, builds the source io sequence from the header's io vectors, parses
+// the control data and calls s.SendMsg. It returns the length reported by
+// s.SendMsg together with the error produced by slinux.HandleIOErrorVFS2.
+//
+// file is only used when reporting the I/O error.
+func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) {
+ // Capture the message header.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
+
+ var controlData []byte
+ if msg.ControlLen > 0 {
+ // Put an upper bound to prevent large allocations.
+ if msg.ControlLen > maxControlLen {
+ return 0, syserror.ENOBUFS
+ }
+ controlData = make([]byte, msg.ControlLen)
+ if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Read the destination address if one is specified.
+ var to []byte
+ if msg.NameLen != 0 {
+ var err error
+ to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ // Read data then call the sendmsg implementation.
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syserror.EMSGSIZE
+ }
+ src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ // controlData is nil here when the header carried no control data.
+ controlMessages, err := control.Parse(t, s, controlData)
+ if err != nil {
+ return 0, err
+ }
+
+ // A positive send timeout becomes a deadline on the monotonic clock; a
+ // negative one forces a non-blocking send.
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Call the syscall implementation.
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
+ err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
+ // The parsed control messages are released here only on the error path.
+ if err != nil {
+ controlMessages.Release()
+ }
+ return uintptr(n), err
+}
+
+// sendTo is the implementation of the sendto syscall. It is called by sendto
+// and send syscall handlers.
+//
+// Data is sent from the single buffer at bufPtr. When namePtr is non-zero the
+// destination address is captured from it.
+func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) {
+	// A length that turns negative when viewed as an int is invalid.
+	size := int(bufLen)
+	if size < 0 {
+		return 0, syserror.EINVAL
+	}
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be backed by a socket implementation.
+	sock, isSock := fdesc.Impl().(socket.SocketVFS2)
+	if !isSock {
+		return 0, syserror.ENOTSOCK
+	}
+
+	// A non-blocking descriptor behaves as if MSG_DONTWAIT were passed.
+	if fdesc.StatusFlags()&linux.SOCK_NONBLOCK != 0 {
+		flags |= linux.MSG_DONTWAIT
+	}
+
+	// Read the destination address if one is specified.
+	var dest []byte
+	if namePtr != 0 {
+		captured, err := CaptureAddress(t, namePtr, nameLen)
+		if err != nil {
+			return 0, err
+		}
+		dest = captured
+	}
+
+	ioOpts := usermem.IOOpts{
+		AddressSpaceActive: true,
+	}
+	src, err := t.SingleIOSequence(bufPtr, size, ioOpts)
+	if err != nil {
+		return 0, err
+	}
+
+	// A positive send timeout becomes a deadline on the monotonic clock; a
+	// negative one forces a non-blocking send.
+	var (
+		haveDeadline bool
+		deadline     ktime.Time
+	)
+	switch dl := sock.SendTimeout(); {
+	case dl > 0:
+		deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+		haveDeadline = true
+	case dl < 0:
+		flags |= linux.MSG_DONTWAIT
+	}
+
+	// Call the syscall implementation with freshly created control messages.
+	cms := socket.ControlMessages{Unix: control.New(t, sock, nil)}
+	n, e := sock.SendMsg(t, src, dest, int(flags), haveDeadline, deadline, cms)
+	err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", fdesc)
+	return uintptr(n), err
+}
+
+// SendTo implements the linux syscall sendto(2).
+//
+// It decodes the six syscall arguments and forwards them to sendTo.
+func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd      = args[0].Int()
+		bufPtr  = args[1].Pointer()
+		bufLen  = args[2].Uint64()
+		flags   = args[3].Int()
+		namePtr = args[4].Pointer()
+		nameLen = args[5].Uint()
+	)
+
+	sent, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen)
+	return sent, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
index fdfe49243..bb1d5cac4 100644
--- a/pkg/sentry/syscalls/linux/vfs2/stat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -91,7 +91,8 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags
}
var stat linux.Stat
convertStatxToUserStat(t, &statx, &stat)
- return stat.CopyOut(t, statAddr)
+ _, err = stat.CopyOut(t, statAddr)
+ return err
}
start = dirfile.VirtualDentry()
start.IncRef()
@@ -111,7 +112,8 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags
}
var stat linux.Stat
convertStatxToUserStat(t, &statx, &stat)
- return stat.CopyOut(t, statAddr)
+ _, err = stat.CopyOut(t, statAddr)
+ return err
}
func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec {
@@ -140,7 +142,8 @@ func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
var stat linux.Stat
convertStatxToUserStat(t, &statx, &stat)
- return 0, nil, stat.CopyOut(t, statAddr)
+ _, err = stat.CopyOut(t, statAddr)
+ return 0, nil, err
}
// Statx implements Linux syscall statx(2).
@@ -199,7 +202,8 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
userifyStatx(t, &statx)
- return 0, nil, statx.CopyOut(t, statxAddr)
+ _, err = statx.CopyOut(t, statxAddr)
+ return 0, nil, err
}
start = dirfile.VirtualDentry()
start.IncRef()
@@ -218,7 +222,8 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
userifyStatx(t, &statx)
- return 0, nil, statx.CopyOut(t, statxAddr)
+ _, err = statx.CopyOut(t, statxAddr)
+ return 0, nil, err
}
func userifyStatx(t *kernel.Task, statx *linux.Statx) {
@@ -359,8 +364,8 @@ func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
-
- return 0, nil, statfs.CopyOut(t, bufAddr)
+ _, err = statfs.CopyOut(t, bufAddr)
+ return 0, nil, err
}
// Fstatfs implements Linux syscall fstatfs(2).
@@ -378,6 +383,6 @@ func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if err != nil {
return 0, nil, err
}
-
- return 0, nil, statfs.CopyOut(t, bufAddr)
+ _, err = statfs.CopyOut(t, bufAddr)
+ return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sys_timerfd.go b/pkg/sentry/syscalls/linux/vfs2/sys_timerfd.go
new file mode 100644
index 000000000..7938a5249
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/sys_timerfd.go
@@ -0,0 +1,123 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// TimerfdCreate implements Linux syscall timerfd_create(2).
+//
+// CLOCK_REALTIME maps to the kernel's realtime clock; CLOCK_MONOTONIC and
+// CLOCK_BOOTTIME both map to its monotonic clock. Any other clock ID, or any
+// flag besides TFD_CLOEXEC and TFD_NONBLOCK, yields EINVAL.
+func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		clockID = args[0].Int()
+		flags   = args[1].Int()
+	)
+
+	if flags&^(linux.TFD_CLOEXEC|linux.TFD_NONBLOCK) != 0 {
+		return 0, nil, syserror.EINVAL
+	}
+
+	// TFD_NONBLOCK translates into O_NONBLOCK on the new file.
+	var statusFlags uint32
+	if flags&linux.TFD_NONBLOCK != 0 {
+		statusFlags = linux.O_NONBLOCK
+	}
+
+	// Pick the clock that drives the timer.
+	k := t.Kernel()
+	var clk ktime.Clock
+	switch clockID {
+	case linux.CLOCK_REALTIME:
+		clk = k.RealtimeClock()
+	case linux.CLOCK_MONOTONIC, linux.CLOCK_BOOTTIME:
+		clk = k.MonotonicClock()
+	default:
+		return 0, nil, syserror.EINVAL
+	}
+
+	timerFile, err := t.Kernel().VFS().NewTimerFD(clk, statusFlags)
+	if err != nil {
+		return 0, nil, err
+	}
+	// Drop our reference once the descriptor has been installed or the
+	// installation has failed.
+	defer timerFile.DecRef()
+
+	fdFlags := kernel.FDFlags{
+		CloseOnExec: flags&linux.TFD_CLOEXEC != 0,
+	}
+	newFD, err := t.NewFDFromVFS2(0, timerFile, fdFlags)
+	if err != nil {
+		return 0, nil, err
+	}
+	return uintptr(newFD), nil, nil
+}
+
+// TimerfdSettime implements Linux syscall timerfd_settime(2).
+//
+// The new setting is read from newValAddr and applied to the timer. When
+// oldValAddr is non-zero, the previous setting is written there.
+func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd         = args[0].Int()
+		flags      = args[1].Int()
+		newValAddr = args[2].Pointer()
+		oldValAddr = args[3].Pointer()
+	)
+
+	// TFD_TIMER_ABSTIME is the only flag accepted.
+	if flags&^(linux.TFD_TIMER_ABSTIME) != 0 {
+		return 0, nil, syserror.EINVAL
+	}
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, nil, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be a timer file description.
+	timer, isTimer := fdesc.Impl().(*vfs.TimerFileDescription)
+	if !isTimer {
+		return 0, nil, syserror.EINVAL
+	}
+
+	var requested linux.Itimerspec
+	if _, err := t.CopyIn(newValAddr, &requested); err != nil {
+		return 0, nil, err
+	}
+	absolute := flags&linux.TFD_TIMER_ABSTIME != 0
+	setting, err := ktime.SettingFromItimerspec(requested, absolute, timer.Clock())
+	if err != nil {
+		return 0, nil, err
+	}
+
+	now, previous := timer.SetTime(setting)
+	if oldValAddr == 0 {
+		return 0, nil, nil
+	}
+
+	// Report the setting that was just replaced.
+	oldVal := ktime.ItimerspecFromSetting(now, previous)
+	if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+		return 0, nil, err
+	}
+	return 0, nil, nil
+}
+
+// TimerfdGettime implements Linux syscall timerfd_gettime(2).
+//
+// The timer's current setting is converted to an itimerspec and written to
+// curValAddr.
+func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	var (
+		fd         = args[0].Int()
+		curValAddr = args[1].Pointer()
+	)
+
+	// Resolve the descriptor; the reference is dropped on return.
+	fdesc := t.GetFileVFS2(fd)
+	if fdesc == nil {
+		return 0, nil, syserror.EBADF
+	}
+	defer fdesc.DecRef()
+
+	// The descriptor has to be a timer file description.
+	timer, isTimer := fdesc.Impl().(*vfs.TimerFileDescription)
+	if !isTimer {
+		return 0, nil, syserror.EINVAL
+	}
+
+	now, setting := timer.GetTime()
+	current := ktime.ItimerspecFromSetting(now, setting)
+	_, err := t.CopyOut(curValAddr, &current)
+	return 0, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
index 89e9ff4d7..af455d5c1 100644
--- a/pkg/sentry/syscalls/linux/vfs2/xattr.go
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -51,7 +51,7 @@ func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSyml
}
defer tpop.Release()
- names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop)
+ names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size))
if err != nil {
return 0, nil, err
}
@@ -74,7 +74,7 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
defer file.DecRef()
- names, err := file.Listxattr(t)
+ names, err := file.Listxattr(t, uint64(size))
if err != nil {
return 0, nil, err
}
@@ -116,7 +116,10 @@ func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli
return 0, nil, err
}
- value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, name)
+ value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetxattrOptions{
+ Name: name,
+ Size: uint64(size),
+ })
if err != nil {
return 0, nil, err
}
@@ -145,7 +148,7 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, err
}
- value, err := file.Getxattr(t, name)
+ value, err := file.Getxattr(t, &vfs.GetxattrOptions{Name: name, Size: uint64(size)})
if err != nil {
return 0, nil, err
}
@@ -230,7 +233,7 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, err
}
- return 0, nil, file.Setxattr(t, vfs.SetxattrOptions{
+ return 0, nil, file.Setxattr(t, &vfs.SetxattrOptions{
Name: name,
Value: value,
Flags: uint32(flags),
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index bf4d27c7d..9aeb83fb0 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -36,6 +36,7 @@ go_library(
"pathname.go",
"permissions.go",
"resolving_path.go",
+ "timerfd.go",
"vfs.go",
],
visibility = ["//pkg/sentry:internal"],
@@ -51,6 +52,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index d1f6dfb45..a64d86122 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -245,7 +245,7 @@ func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath
}
// ListxattrAt implements FilesystemImpl.ListxattrAt.
-func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) {
+func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) {
if !rp.Done() {
return nil, syserror.ENOTDIR
}
@@ -253,7 +253,7 @@ func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([
}
// GetxattrAt implements FilesystemImpl.GetxattrAt.
-func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) {
+func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) {
if !rp.Done() {
return "", syserror.ENOTDIR
}
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 3da45d744..8e0b40841 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -99,6 +99,8 @@ func (vfs *VirtualFilesystem) NewEpollInstanceFD() (*FileDescription, error) {
interest: make(map[epollInterestKey]*epollInterest),
}
if err := ep.vfsfd.Init(ep, linux.O_RDWR, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
UseDentryMetadata: true,
}); err != nil {
return nil, err
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 8ee549dc2..5976b5ccd 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -84,6 +84,17 @@ type FileDescriptionOptions struct {
// usually only the case if O_DIRECT would actually have an effect.
AllowDirectIO bool
+ // If DenyPRead is true, calls to FileDescription.PRead() return ESPIPE.
+ DenyPRead bool
+
+ // If DenyPWrite is true, calls to FileDescription.PWrite() return
+ // ESPIPE.
+ DenyPWrite bool
+
+ // If InvalidWrite is true, calls to FileDescription.Write() return
+ // EINVAL.
+ InvalidWrite bool
+
// If UseDentryMetadata is true, calls to FileDescription methods that
// interact with file and filesystem metadata (Stat, SetStat, StatFS,
// Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling
@@ -111,7 +122,7 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mn
}
fd.refs = 1
- fd.statusFlags = statusFlags | linux.O_LARGEFILE
+ fd.statusFlags = statusFlags
fd.vd = VirtualDentry{
mount: mnt,
dentry: d,
@@ -175,6 +186,12 @@ func (fd *FileDescription) DecRef() {
}
}
+// Refs returns the current number of references. The returned count
+// is inherently racy and is unsafe to use without external synchronization.
+func (fd *FileDescription) Refs() int64 {
+ return atomic.LoadInt64(&fd.refs)
+}
+
// Mount returns the mount on which fd was opened. It does not take a reference
// on the returned Mount.
func (fd *FileDescription) Mount() *Mount {
@@ -306,6 +323,7 @@ type FileDescriptionImpl interface {
// - If opts.Flags specifies unsupported options, PRead returns EOPNOTSUPP.
//
// Preconditions: The FileDescription was opened for reading.
+ // FileDescriptionOptions.DenyPRead == false.
PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error)
// Read is similar to PRead, but does not specify an offset.
@@ -337,6 +355,7 @@ type FileDescriptionImpl interface {
// EOPNOTSUPP.
//
// Preconditions: The FileDescription was opened for writing.
+ // FileDescriptionOptions.DenyPWrite == false.
PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error)
// Write is similar to PWrite, but does not specify an offset, which is
@@ -382,11 +401,11 @@ type FileDescriptionImpl interface {
Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error)
// Listxattr returns all extended attribute names for the file.
- Listxattr(ctx context.Context) ([]string, error)
+ Listxattr(ctx context.Context, size uint64) ([]string, error)
// Getxattr returns the value associated with the given extended attribute
// for the file.
- Getxattr(ctx context.Context, name string) (string, error)
+ Getxattr(ctx context.Context, opts GetxattrOptions) (string, error)
// Setxattr changes the value associated with the given extended attribute
// for the file.
@@ -515,6 +534,9 @@ func (fd *FileDescription) EventUnregister(e *waiter.Entry) {
// offset, and returns the number of bytes read. PRead is permitted to return
// partial reads with a nil error.
func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ if fd.opts.DenyPRead {
+ return 0, syserror.ESPIPE
+ }
if !fd.readable {
return 0, syserror.EBADF
}
@@ -533,6 +555,9 @@ func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opt
// offset, and returns the number of bytes written. PWrite is permitted to
// return partial writes with a nil error.
func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ if fd.opts.DenyPWrite {
+ return 0, syserror.ESPIPE
+ }
if !fd.writable {
return 0, syserror.EBADF
}
@@ -541,6 +566,9 @@ func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, o
// Write is similar to PWrite, but does not specify an offset.
func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ if fd.opts.InvalidWrite {
+ return 0, syserror.EINVAL
+ }
if !fd.writable {
return 0, syserror.EBADF
}
@@ -577,18 +605,23 @@ func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
// Listxattr returns all extended attribute names for the file represented by
// fd.
-func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) {
+//
+// If the size of the list (including a NUL terminating byte after every entry)
+// would exceed size, ERANGE may be returned (note that implementations
+// are free to ignore size entirely and return without error). In all cases,
+// if size is 0, the list should be returned without error, regardless of size.
+func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
if fd.opts.UseDentryMetadata {
vfsObj := fd.vd.mount.vfs
rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
Root: fd.vd,
Start: fd.vd,
})
- names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp)
+ names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size)
vfsObj.putResolvingPath(rp)
return names, err
}
- names, err := fd.impl.Listxattr(ctx)
+ names, err := fd.impl.Listxattr(ctx, size)
if err == syserror.ENOTSUP {
// Linux doesn't actually return ENOTSUP in this case; instead,
// fs/xattr.c:vfs_listxattr() falls back to allowing the security
@@ -601,34 +634,39 @@ func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) {
// Getxattr returns the value associated with the given extended attribute for
// the file represented by fd.
-func (fd *FileDescription) Getxattr(ctx context.Context, name string) (string, error) {
+//
+// If the size of the return value exceeds opts.Size, ERANGE may be returned
+// (note that implementations are free to ignore opts.Size entirely and return
+// without error). In all cases, if opts.Size is 0, the value should be
+// returned without error, regardless of size.
+func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions) (string, error) {
if fd.opts.UseDentryMetadata {
vfsObj := fd.vd.mount.vfs
rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
Root: fd.vd,
Start: fd.vd,
})
- val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, name)
+ val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
vfsObj.putResolvingPath(rp)
return val, err
}
- return fd.impl.Getxattr(ctx, name)
+ return fd.impl.Getxattr(ctx, *opts)
}
// Setxattr changes the value associated with the given extended attribute for
// the file represented by fd.
-func (fd *FileDescription) Setxattr(ctx context.Context, opts SetxattrOptions) error {
+func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions) error {
if fd.opts.UseDentryMetadata {
vfsObj := fd.vd.mount.vfs
rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
Root: fd.vd,
Start: fd.vd,
})
- err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, opts)
+ err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
vfsObj.putResolvingPath(rp)
return err
}
- return fd.impl.Setxattr(ctx, opts)
+ return fd.impl.Setxattr(ctx, *opts)
}
// Removexattr removes the given extended attribute from the file represented
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index d45e602ce..f4c111926 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -130,14 +130,14 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg
// Listxattr implements FileDescriptionImpl.Listxattr analogously to
// inode_operations::listxattr == NULL in Linux.
-func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context) ([]string, error) {
+func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context, size uint64) ([]string, error) {
// This isn't exactly accurate; see FileDescription.Listxattr.
return nil, syserror.ENOTSUP
}
// Getxattr implements FileDescriptionImpl.Getxattr analogously to
// inode::i_opflags & IOP_XATTR == 0 in Linux.
-func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, name string) (string, error) {
+func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) {
return "", syserror.ENOTSUP
}
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index cd34782ff..a537a29d1 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -442,7 +442,13 @@ type FilesystemImpl interface {
// - If extended attributes are not supported by the filesystem,
// ListxattrAt returns nil. (See FileDescription.Listxattr for an
// explanation.)
- ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error)
+ //
+ // - If the size of the list (including a NUL terminating byte after every
+ // entry) would exceed size, ERANGE may be returned (note that
+ // implementations are free to ignore size entirely and return without
+ // error). In all cases, if size is 0, the list should be returned without
+ // error, regardless of size.
+ ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error)
// GetxattrAt returns the value associated with the given extended
// attribute for the file at rp.
@@ -451,7 +457,15 @@ type FilesystemImpl interface {
//
// - If extended attributes are not supported by the filesystem, GetxattrAt
// returns ENOTSUP.
- GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error)
+ //
+ // - If an extended attribute named opts.Name does not exist, ENODATA is
+ // returned.
+ //
+ // - If the size of the return value exceeds opts.Size, ERANGE may be
+ // returned (note that implementations are free to ignore opts.Size entirely
+ // and return without error). In all cases, if opts.Size is 0, the value
+ // should be returned without error, regardless of size.
+ GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error)
// SetxattrAt changes the value associated with the given extended
// attribute for the file at rp.
@@ -460,6 +474,10 @@ type FilesystemImpl interface {
//
// - If extended attributes are not supported by the filesystem, SetxattrAt
// returns ENOTSUP.
+ //
+ // - If XATTR_CREATE is set in opts.Flags and opts.Name already exists,
+ // EEXIST is returned. If XATTR_REPLACE is set and opts.Name does not exist,
+ // ENODATA is returned.
SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error
// RemovexattrAt removes the given extended attribute from the file at rp.
@@ -468,6 +486,8 @@ type FilesystemImpl interface {
//
// - If extended attributes are not supported by the filesystem,
// RemovexattrAt returns ENOTSUP.
+ //
+ // - If name does not exist, ENODATA is returned.
RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error
// BoundEndpointAt returns the Unix socket endpoint bound at the path rp.
@@ -497,7 +517,7 @@ type FilesystemImpl interface {
// Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl.
PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error
- // TODO: inotify_add_watch()
+ // TODO(gvisor.dev/issue/1479): inotify_add_watch()
}
// PrependPathAtVFSRootError is returned by implementations of
diff --git a/pkg/sentry/vfs/memxattr/BUILD b/pkg/sentry/vfs/memxattr/BUILD
new file mode 100644
index 000000000..d8c4d27b9
--- /dev/null
+++ b/pkg/sentry/vfs/memxattr/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "memxattr",
+ srcs = ["xattr.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go
new file mode 100644
index 000000000..cc1e7d764
--- /dev/null
+++ b/pkg/sentry/vfs/memxattr/xattr.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memxattr provides a default, in-memory extended attribute
+// implementation.
+package memxattr
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// SimpleExtendedAttributes implements extended attributes using a map of
+// names to values.
+//
+// +stateify savable
+type SimpleExtendedAttributes struct {
+ // mu protects the below fields.
+ mu sync.RWMutex `state:"nosave"`
+ xattrs map[string]string
+}
+
+// Getxattr returns the value stored under opts.Name.
+func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, error) {
+ x.mu.RLock()
+ value, ok := x.xattrs[opts.Name]
+ x.mu.RUnlock()
+ if !ok {
+ return "", syserror.ENODATA
+ }
+ // Check that the size of the buffer provided in getxattr(2) is large enough
+ // to contain the value.
+ if opts.Size != 0 && uint64(len(value)) > opts.Size {
+ return "", syserror.ERANGE
+ }
+ return value, nil
+}
+
+// Setxattr stores opts.Value under opts.Name, subject to opts.Flags.
+func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error {
+ x.mu.Lock()
+ defer x.mu.Unlock()
+ if x.xattrs == nil {
+ if opts.Flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+ x.xattrs = make(map[string]string)
+ }
+
+ _, ok := x.xattrs[opts.Name]
+ if ok && opts.Flags&linux.XATTR_CREATE != 0 {
+ return syserror.EEXIST
+ }
+ if !ok && opts.Flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+
+ x.xattrs[opts.Name] = opts.Value
+ return nil
+}
+
+// Listxattr returns all names in xattrs.
+func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) {
+ // Keep track of the size of the buffer needed in listxattr(2) for the list.
+ listSize := 0
+ x.mu.RLock()
+ names := make([]string, 0, len(x.xattrs))
+ for n := range x.xattrs {
+ names = append(names, n)
+ // Add one byte per null terminator.
+ listSize += len(n) + 1
+ }
+ x.mu.RUnlock()
+ if size != 0 && uint64(listSize) > size {
+ return nil, syserror.ERANGE
+ }
+ return names, nil
+}
+
+// Removexattr removes the xattr at 'name'.
+func (x *SimpleExtendedAttributes) Removexattr(name string) error {
+ x.mu.Lock()
+ defer x.mu.Unlock()
+ if _, ok := x.xattrs[name]; !ok {
+ return syserror.ENODATA
+ }
+ delete(x.xattrs, name)
+ return nil
+}
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 7792eb1a0..f06946103 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -233,9 +233,9 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia
}
vd.dentry.mu.Lock()
}
- // TODO: Linux requires that either both the mount point and the mount root
- // are directories, or neither are, and returns ENOTDIR if this is not the
- // case.
+ // TODO(gvisor.dev/issue/1035): Linux requires that either both the mount
+ // point and the mount root are directories, or neither are, and returns
+ // ENOTDIR if this is not the case.
mntns := vd.mount.ns
mnt := newMount(vfs, fs, root, mntns, opts)
vfs.mounts.seq.BeginWrite()
@@ -274,9 +274,9 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti
}
}
- // TODO(jamieliu): Linux special-cases umount of the caller's root, which
- // we don't implement yet (we'll just fail it since the caller holds a
- // reference on it).
+ // TODO(gvisor.dev/issue/1035): Linux special-cases umount of the caller's
+ // root, which we don't implement yet (we'll just fail it since the caller
+ // holds a reference on it).
vfs.mounts.seq.BeginWrite()
if opts.Flags&linux.MNT_DETACH == 0 {
@@ -835,7 +835,8 @@ func superBlockOpts(mountPath string, mnt *Mount) string {
// NOTE(b/147673608): If the mount is a cgroup, we also need to include
// the cgroup name in the options. For now we just read that from the
// path.
- // TODO(gvisor.dev/issues/190): Once gVisor has full cgroup support, we
+ //
+ // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
// should get this value from the cgroup itself, and not rely on the
// path.
if mnt.fs.FilesystemType().Name() == "cgroup" {
diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go
index 3b933468d..3335e4057 100644
--- a/pkg/sentry/vfs/mount_test.go
+++ b/pkg/sentry/vfs/mount_test.go
@@ -55,7 +55,7 @@ func TestMountTableInsertLookup(t *testing.T) {
}
}
-// TODO: concurrent lookup/insertion/removal
+// TODO(gvisor.dev/issue/1035): concurrent lookup/insertion/removal.
// must be powers of 2
var benchNumMounts = []int{1 << 2, 1 << 5, 1 << 8}
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 3e90dc4ed..534528ce6 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -16,6 +16,7 @@ package vfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
)
// GetDentryOptions contains options to VirtualFilesystem.GetDentryAt() and
@@ -44,6 +45,10 @@ type MknodOptions struct {
// DevMinor are the major and minor device numbers for the created device.
DevMajor uint32
DevMinor uint32
+
+ // Endpoint is the endpoint to bind to the created file, if a socket file is
+ // being created for bind(2) on a Unix domain socket.
+ Endpoint transport.BoundEndpoint
}
// MountFlags contains flags as specified for mount(2), e.g. MS_NOEXEC.
@@ -127,6 +132,20 @@ type SetStatOptions struct {
Stat linux.Statx
}
+// GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(),
+// FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and
+// FileDescriptionImpl.Getxattr().
+type GetxattrOptions struct {
+ // Name is the name of the extended attribute to retrieve.
+ Name string
+
+ // Size is the maximum value size that the caller will tolerate. If the value
+ // is larger than size, getxattr methods may return ERANGE, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ Size uint64
+}
+
// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(),
// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and
// FileDescriptionImpl.Setxattr().
diff --git a/pkg/sentry/vfs/timerfd.go b/pkg/sentry/vfs/timerfd.go
new file mode 100644
index 000000000..42b880656
--- /dev/null
+++ b/pkg/sentry/vfs/timerfd.go
@@ -0,0 +1,142 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TimerFileDescription implements FileDescriptionImpl for timer fds. It also
+// implements ktime.TimerListener.
+type TimerFileDescription struct {
+ vfsfd FileDescription
+ FileDescriptionDefaultImpl
+ DentryMetadataFileDescriptionImpl
+
+ events waiter.Queue
+ timer *ktime.Timer
+
+ // val is the number of timer expirations since the last successful
+ // call to Read or SetTime. val must be accessed using atomic memory
+ // operations.
+ val uint64
+}
+
+var _ FileDescriptionImpl = (*TimerFileDescription)(nil)
+var _ ktime.TimerListener = (*TimerFileDescription)(nil)
+
+// NewTimerFD returns a new timer fd.
+func (vfs *VirtualFilesystem) NewTimerFD(clock ktime.Clock, flags uint32) (*FileDescription, error) {
+ vd := vfs.NewAnonVirtualDentry("[timerfd]")
+ defer vd.DecRef()
+ tfd := &TimerFileDescription{}
+ tfd.timer = ktime.NewTimer(clock, tfd)
+ if err := tfd.vfsfd.Init(tfd, flags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{
+ UseDentryMetadata: true,
+ DenyPRead: true,
+ DenyPWrite: true,
+ InvalidWrite: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &tfd.vfsfd, nil
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ const sizeofUint64 = 8
+ if dst.NumBytes() < sizeofUint64 {
+ return 0, syserror.EINVAL
+ }
+ if val := atomic.SwapUint64(&tfd.val, 0); val != 0 {
+ var buf [sizeofUint64]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ if _, err := dst.CopyOut(ctx, buf[:]); err != nil {
+ // Linux does not undo consuming the number of
+ // expirations even if writing to userspace fails.
+ return 0, err
+ }
+ return sizeofUint64, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+// Clock returns the timer fd's Clock.
+func (tfd *TimerFileDescription) Clock() ktime.Clock {
+ return tfd.timer.Clock()
+}
+
+// GetTime returns the associated Timer's setting and the time at which it was
+// observed.
+func (tfd *TimerFileDescription) GetTime() (ktime.Time, ktime.Setting) {
+ return tfd.timer.Get()
+}
+
+// SetTime atomically changes the associated Timer's setting, resets the number
+// of expirations to 0, and returns the previous setting and the time at which
+// it was observed.
+func (tfd *TimerFileDescription) SetTime(s ktime.Setting) (ktime.Time, ktime.Setting) {
+ return tfd.timer.SwapAnd(s, func() { atomic.StoreUint64(&tfd.val, 0) })
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (tfd *TimerFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ var ready waiter.EventMask
+ if atomic.LoadUint64(&tfd.val) != 0 {
+ ready |= waiter.EventIn
+ }
+ return ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (tfd *TimerFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ tfd.events.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (tfd *TimerFileDescription) EventUnregister(e *waiter.Entry) {
+ tfd.events.EventUnregister(e)
+}
+
+// PauseTimer pauses the associated Timer.
+func (tfd *TimerFileDescription) PauseTimer() {
+ tfd.timer.Pause()
+}
+
+// ResumeTimer resumes the associated Timer.
+func (tfd *TimerFileDescription) ResumeTimer() {
+ tfd.timer.Resume()
+}
+
+// Release implements FileDescriptionImpl.Release.
+func (tfd *TimerFileDescription) Release() {
+ tfd.timer.Destroy()
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (tfd *TimerFileDescription) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
+ atomic.AddUint64(&tfd.val, exp)
+ tfd.events.Notify(waiter.EventIn)
+ return ktime.Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (tfd *TimerFileDescription) Destroy() {}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 720b90d8f..cb5bbd781 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -335,7 +335,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
rp := vfs.getResolvingPath(creds, pop)
for {
err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts)
- if err != nil {
+ if err == nil {
vfs.putResolvingPath(rp)
return nil
}
@@ -383,14 +383,11 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C
func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
// Remove:
//
- // - O_LARGEFILE, which we always report in FileDescription status flags
- // since only 64-bit architectures are supported at this time.
- //
// - O_CLOEXEC, which affects file descriptors and therefore must be
// handled outside of VFS.
//
// - Unknown flags.
- opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_LARGEFILE | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE
// Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC.
if opts.Flags&linux.O_SYNC != 0 {
opts.Flags |= linux.O_DSYNC
@@ -680,10 +677,10 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
// ListxattrAt returns all extended attribute names for the file at the given
// path.
-func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) ([]string, error) {
+func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) {
rp := vfs.getResolvingPath(creds, pop)
for {
- names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp)
+ names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size)
if err == nil {
vfs.putResolvingPath(rp)
return names, nil
@@ -705,10 +702,10 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede
// GetxattrAt returns the value associated with the given extended attribute
// for the file at the given path.
-func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) (string, error) {
+func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetxattrOptions) (string, error) {
rp := vfs.getResolvingPath(creds, pop)
for {
- val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, name)
+ val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
if err == nil {
vfs.putResolvingPath(rp)
return val, nil
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index f7d6009a0..fcc46420f 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -319,8 +319,8 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
// Dump stack only if a new task is detected or if it sometime has
// passed since the last time a stack dump was generated.
- skipStack := newTaskFound || time.Since(w.lastStackDump) >= stackDumpSameTaskPeriod
- w.doAction(w.TaskTimeoutAction, skipStack, &buf)
+ showStack := newTaskFound || time.Since(w.lastStackDump) >= stackDumpSameTaskPeriod
+ w.doAction(w.TaskTimeoutAction, showStack, &buf)
}
func (w *Watchdog) reportStuckWatchdog() {
@@ -329,16 +329,15 @@ func (w *Watchdog) reportStuckWatchdog() {
w.doAction(w.TaskTimeoutAction, false, &buf)
}
-// doAction will take the given action. If the action is LogWarnind and
-// skipStack is true, then the stack printing will be skipped.
-func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) {
+// doAction will take the given action. If the action is LogWarning and
+// showStack is false, then the stack printing will be skipped.
+func (w *Watchdog) doAction(action Action, showStack bool, msg *bytes.Buffer) {
switch action {
case LogWarning:
- if skipStack {
+ if !showStack {
msg.WriteString("\n...[stack dump skipped]...")
log.Warningf(msg.String())
return
-
}
log.TracebackAll(msg.String())
w.lastStackDump = time.Now()
diff --git a/pkg/state/state.go b/pkg/state/state.go
index dbe507ab4..03ae2dbb0 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -241,10 +241,7 @@ func Register(name string, instance interface{}, fns Fns) {
//
// This function is used by the stateify tool.
func IsZeroValue(val interface{}) bool {
- if val == nil {
- return true
- }
- return reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface())
+ return val == nil || reflect.ValueOf(val).Elem().IsZero()
}
// step captures one encoding / decoding step. On each step, there is up to one
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 5340cf0d6..0e35d7d17 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -31,13 +31,13 @@ go_library(
name = "sync",
srcs = [
"aliases.go",
- "downgradable_rwmutex_unsafe.go",
"memmove_unsafe.go",
+ "mutex_unsafe.go",
"norace_unsafe.go",
"race_unsafe.go",
+ "rwmutex_unsafe.go",
"seqcount.go",
- "syncutil.go",
- "tmutex_unsafe.go",
+ "sync.go",
],
)
@@ -45,9 +45,9 @@ go_test(
name = "sync_test",
size = "small",
srcs = [
- "downgradable_rwmutex_test.go",
+ "mutex_test.go",
+ "rwmutex_test.go",
"seqcount_test.go",
- "tmutex_test.go",
],
library = ":sync",
)
diff --git a/pkg/sync/tmutex_test.go b/pkg/sync/mutex_test.go
index 0838248b4..0838248b4 100644
--- a/pkg/sync/tmutex_test.go
+++ b/pkg/sync/mutex_test.go
diff --git a/pkg/sync/tmutex_unsafe.go b/pkg/sync/mutex_unsafe.go
index 3dd15578b..3dd15578b 100644
--- a/pkg/sync/tmutex_unsafe.go
+++ b/pkg/sync/mutex_unsafe.go
diff --git a/pkg/sync/downgradable_rwmutex_test.go b/pkg/sync/rwmutex_test.go
index ce667e825..ce667e825 100644
--- a/pkg/sync/downgradable_rwmutex_test.go
+++ b/pkg/sync/rwmutex_test.go
diff --git a/pkg/sync/downgradable_rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go
index ea6cdc447..ea6cdc447 100644
--- a/pkg/sync/downgradable_rwmutex_unsafe.go
+++ b/pkg/sync/rwmutex_unsafe.go
diff --git a/pkg/sync/syncutil.go b/pkg/sync/sync.go
index b16cf5333..b16cf5333 100644
--- a/pkg/sync/syncutil.go
+++ b/pkg/sync/sync.go
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 8d42cd066..8ec5d5d5c 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -17,6 +17,7 @@ package buffer
import (
"bytes"
+ "io"
)
// View is a slice of a buffer, with convenience methods.
@@ -89,6 +90,47 @@ func (vv *VectorisedView) TrimFront(count int) {
}
}
+// Read implements io.Reader.
+func (vv *VectorisedView) Read(v View) (copied int, err error) {
+ count := len(v)
+ for count > 0 && len(vv.views) > 0 {
+ if count < len(vv.views[0]) {
+ vv.size -= count
+ copy(v[copied:], vv.views[0][:count])
+ vv.views[0].TrimFront(count)
+ copied += count
+ return copied, nil
+ }
+ count -= len(vv.views[0])
+ copy(v[copied:], vv.views[0])
+ copied += len(vv.views[0])
+ vv.RemoveFirst()
+ }
+ if copied == 0 {
+ return 0, io.EOF
+ }
+ return copied, nil
+}
+
+// ReadToVV reads up to count bytes from vv to dstVV and removes them from vv. It
+// returns the number of bytes copied.
+func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int) {
+ for count > 0 && len(vv.views) > 0 {
+ if count < len(vv.views[0]) {
+ vv.size -= count
+ dstVV.AppendView(vv.views[0][:count])
+ vv.views[0].TrimFront(count)
+ copied += count
+ return
+ }
+ count -= len(vv.views[0])
+ dstVV.AppendView(vv.views[0])
+ copied += len(vv.views[0])
+ vv.RemoveFirst()
+ }
+ return copied
+}
+
// CapLength irreversibly reduces the length of the vectorised view.
func (vv *VectorisedView) CapLength(length int) {
if length < 0 {
@@ -116,12 +158,12 @@ func (vv *VectorisedView) CapLength(length int) {
// Clone returns a clone of this VectorisedView.
// If the buffer argument is large enough to contain all the Views of this VectorisedView,
// the method will avoid allocations and use the buffer to store the Views of the clone.
-func (vv VectorisedView) Clone(buffer []View) VectorisedView {
+func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
// First returns the first view of the vectorised view.
-func (vv VectorisedView) First() View {
+func (vv *VectorisedView) First() View {
if len(vv.views) == 0 {
return nil
}
@@ -134,11 +176,12 @@ func (vv *VectorisedView) RemoveFirst() {
return
}
vv.size -= len(vv.views[0])
+ vv.views[0] = nil
vv.views = vv.views[1:]
}
// Size returns the size in bytes of the entire content stored in the vectorised view.
-func (vv VectorisedView) Size() int {
+func (vv *VectorisedView) Size() int {
return vv.size
}
@@ -146,7 +189,7 @@ func (vv VectorisedView) Size() int {
//
// If the vectorised view contains a single view, that view will be returned
// directly.
-func (vv VectorisedView) ToView() View {
+func (vv *VectorisedView) ToView() View {
if len(vv.views) == 1 {
return vv.views[0]
}
@@ -158,7 +201,7 @@ func (vv VectorisedView) ToView() View {
}
// Views returns the slice containing the all views.
-func (vv VectorisedView) Views() []View {
+func (vv *VectorisedView) Views() []View {
return vv.views
}
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
index ebc3a17b7..106e1994c 100644
--- a/pkg/tcpip/buffer/view_test.go
+++ b/pkg/tcpip/buffer/view_test.go
@@ -233,3 +233,140 @@ func TestToClone(t *testing.T) {
})
}
}
+
+func TestVVReadToVV(t *testing.T) {
+ testCases := []struct {
+ comment string
+ vv VectorisedView
+ bytesToRead int
+ wantBytes string
+ leftVV VectorisedView
+ }{
+ {
+ comment: "large VV, short read",
+ vv: vv(30, "012345678901234567890123456789"),
+ bytesToRead: 10,
+ wantBytes: "0123456789",
+ leftVV: vv(20, "01234567890123456789"),
+ },
+ {
+ comment: "largeVV, multiple views, short read",
+ vv: vv(13, "123", "345", "567", "8910"),
+ bytesToRead: 6,
+ wantBytes: "123345",
+ leftVV: vv(7, "567", "8910"),
+ },
+ {
+ comment: "smallVV (multiple views), large read",
+ vv: vv(3, "1", "2", "3"),
+ bytesToRead: 10,
+ wantBytes: "123",
+ leftVV: vv(0, ""),
+ },
+ {
+ comment: "smallVV (single view), large read",
+ vv: vv(1, "1"),
+ bytesToRead: 10,
+ wantBytes: "1",
+ leftVV: vv(0, ""),
+ },
+ {
+ comment: "emptyVV, large read",
+ vv: vv(0, ""),
+ bytesToRead: 10,
+ wantBytes: "",
+ leftVV: vv(0, ""),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.comment, func(t *testing.T) {
+ var readTo VectorisedView
+ inSize := tc.vv.Size()
+ copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead)
+ if got, want := copied, len(tc.wantBytes); got != want {
+ t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc)
+ }
+ if got, want := string(readTo.ToView()), tc.wantBytes; got != want {
+ t.Errorf("unexpected content in readTo got: %s, want: %s", got, want)
+ }
+ if got, want := tc.vv.Size(), inSize-copied; got != want {
+ t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want {
+ t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want)
+ }
+ })
+ }
+}
+
+func TestVVRead(t *testing.T) {
+ testCases := []struct {
+ comment string
+ vv VectorisedView
+ bytesToRead int
+ readBytes string
+ leftBytes string
+ wantError bool
+ }{
+ {
+ comment: "large VV, short read",
+ vv: vv(30, "012345678901234567890123456789"),
+ bytesToRead: 10,
+ readBytes: "0123456789",
+ leftBytes: "01234567890123456789",
+ },
+ {
+ comment: "largeVV, multiple buffers, short read",
+ vv: vv(13, "123", "345", "567", "8910"),
+ bytesToRead: 6,
+ readBytes: "123345",
+ leftBytes: "5678910",
+ },
+ {
+ comment: "smallVV, large read",
+ vv: vv(3, "1", "2", "3"),
+ bytesToRead: 10,
+ readBytes: "123",
+ leftBytes: "",
+ },
+ {
+ comment: "smallVV, large read",
+ vv: vv(1, "1"),
+ bytesToRead: 10,
+ readBytes: "1",
+ leftBytes: "",
+ },
+ {
+ comment: "emptyVV, large read",
+ vv: vv(0, ""),
+ bytesToRead: 10,
+ readBytes: "",
+ wantError: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.comment, func(t *testing.T) {
+ readTo := NewView(tc.bytesToRead)
+ inSize := tc.vv.Size()
+ copied, err := tc.vv.Read(readTo)
+ if !tc.wantError && err != nil {
+ t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err)
+ }
+ readTo = readTo[:copied]
+ if got, want := copied, len(tc.readBytes); got != want {
+ t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(readTo), tc.readBytes; got != want {
+ t.Errorf("unexpected data in readTo got: %s, want: %s", got, want)
+ }
+ if got, want := tc.vv.Size(), inSize-copied; got != want {
+ t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want {
+ t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 8dc0f7c0e..c1745ba6a 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -107,6 +107,8 @@ func DstAddr(addr tcpip.Address) NetworkChecker {
// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
func TTL(ttl uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
var v uint8
switch ip := h[0].(type) {
case header.IPv4:
@@ -310,6 +312,8 @@ func SrcPort(port uint16) TransportChecker {
// DstPort creates a checker that checks the destination port.
func DstPort(port uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
if p := h.DestinationPort(); p != port {
t.Errorf("Bad destination port, got %v, want %v", p, port)
}
@@ -336,6 +340,7 @@ func SeqNum(seq uint32) TransportChecker {
func AckNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -350,6 +355,8 @@ func AckNum(seq uint32) TransportChecker {
// Window creates a checker that checks the tcp window.
func Window(window uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -381,6 +388,8 @@ func TCPFlags(flags uint8) TransportChecker {
// given mask, match the supplied flags.
func TCPFlagsMatch(flags, mask uint8) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -398,6 +407,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker {
// If wndscale is negative, the window scale option must not be present.
func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -494,6 +505,8 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
// skipped.
func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -612,6 +625,8 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
// Payload creates a checker that checks the payload.
func Payload(want []byte) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
if got := h.Payload(); !reflect.DeepEqual(got, want) {
t.Errorf("Wrong payload, got %v, want %v", got, want)
}
@@ -644,6 +659,7 @@ func ICMPv4(checkers ...TransportChecker) NetworkChecker {
func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv4, ok := h.(header.ICMPv4)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
@@ -658,6 +674,7 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
func ICMPv4Code(want byte) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv4, ok := h.(header.ICMPv4)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
@@ -700,6 +717,7 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker {
func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv6, ok := h.(header.ICMPv6)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
@@ -714,6 +732,7 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
func ICMPv6Code(want byte) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv6, ok := h.(header.ICMPv6)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
@@ -728,7 +747,7 @@ func ICMPv6Code(want byte) TransportChecker {
// message for type of ty, with potentially additional checks specified by
// checkers.
//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
// NDP message as far as the size of the message (minSize) is concerned. The
// values within the message are up to checkers to validate.
func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
@@ -760,9 +779,9 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N
// Neighbor Solicitation message (as per the raw wire format), with potentially
// additional checks specified by checkers.
//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPNS message as far as the size of the messages concerned. The values within
-// the message are up to checkers to validate.
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNS message as far as the size of the message is concerned. The values
+// within the message are up to checkers to validate.
func NDPNS(checkers ...TransportChecker) NetworkChecker {
return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
}
@@ -780,7 +799,54 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
ns := header.NDPNeighborSolicit(icmp.NDPPayload())
if got := ns.TargetAddress(); got != want {
- t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want)
+ t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
+// NDPNA creates a checker that checks that the packet contains a valid NDP
+// Neighbor Advertisement message (as per the raw wire format), with potentially
+// additional checks specified by checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNA message as far as the size of the message is concerned. The values
+// within the message are up to checkers to validate.
+func NDPNA(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...)
+}
+
+// NDPNATargetAddress creates a checker that checks the Target Address field of
+// a header.NDPNeighborAdvert.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNATargetAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+
+ if got := na.TargetAddress(); got != want {
+ t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
+ }
+ }
+}
+
+// NDPNASolicitedFlag creates a checker that checks the Solicited field of
+// a header.NDPNeighborAdvert.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNASolicitedFlag(want bool) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+
+ if got := na.SolicitedFlag(); got != want {
+ t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
}
}
}
@@ -819,6 +885,13 @@ func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption
} else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
}
+ case header.NDPTargetLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
default:
t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
}
@@ -831,6 +904,21 @@ func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption
}
}
+// NDPNAOptions creates a checker that checks that the packet contains the
+ // provided NDP options within an NDP Neighbor Advertisement message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNAOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ ndpOptions(t, na.Options(), opts)
+ }
+}
+
// NDPNSOptions creates a checker that checks that the packet contains the
// provided NDP options within an NDP Neighbor Solicitation message.
//
@@ -849,7 +937,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker {
// NDPRS creates a checker that checks that the packet contains a valid NDP
// Router Solicitation message (as per the raw wire format).
//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
// NDPRS as far as the size of the message is concerned. The values within the
// message are up to checkers to validate.
func NDPRS(checkers ...TransportChecker) NetworkChecker {
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 7094f3f0b..0cde694dc 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -21,6 +21,7 @@ go_library(
"ndp_options.go",
"ndp_router_advert.go",
"ndp_router_solicit.go",
+ "ndpoptionidentifier_string.go",
"tcp.go",
"udp.go",
],
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index 7a0014ad9..14413f2ce 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -88,7 +88,7 @@ func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr {
- t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", got, test.expectedLinkAddr)
+ t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", test.addr, got, test.expectedLinkAddr)
}
})
}
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 1b6c3f328..2c4591409 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -62,6 +62,10 @@ const (
// within an IPv6RoutingExtHdr.
ipv6RoutingExtHdrSegmentsLeftIdx = 1
+ // IPv6FragmentExtHdrLength is the length of an IPv6 fragment extension
+ // header, in bytes.
+ IPv6FragmentExtHdrLength = 8
+
// ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the
// Fragment Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetOffset = 0
@@ -391,17 +395,24 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa
}
// AsRawHeader returns the remaining payload of i as a raw header and
-// completes the iterator.
+// optionally consumes the iterator.
//
-// Calls to Next after calling AsRawHeader on i will indicate that the
-// iterator is done.
-func (i *IPv6PayloadIterator) AsRawHeader() IPv6RawPayloadHeader {
- buf := i.payload
+// If consume is true, calls to Next after calling AsRawHeader on i will
+// indicate that the iterator is done.
+func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
identifier := i.nextHdrIdentifier
- // Mark i as done.
- *i = IPv6PayloadIterator{
- nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
+ var buf buffer.VectorisedView
+ if consume {
+ // Since we consume the iterator, we return the payload as is.
+ buf = i.payload
+
+ // Mark i as done.
+ *i = IPv6PayloadIterator{
+ nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
+ }
+ } else {
+ buf = i.payload.Clone(nil)
}
return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf}
@@ -420,7 +431,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
// a fragment extension header as the data following the fragment extension
// header may not be complete.
if i.forceRaw {
- return i.AsRawHeader(), false, nil
+ return i.AsRawHeader(true /* consume */), false, nil
}
// Is the header we are parsing a known extension header?
@@ -452,10 +463,12 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
fragmentExtHdr := IPv6FragmentExtHdr(data)
- // If the packet is a fragmented packet, do not attempt to parse
- // anything after the fragment extension header as the data following
- // the extension header may not be complete.
- if fragmentExtHdr.More() || fragmentExtHdr.FragmentOffset() != 0 {
+ // If the packet is not the first fragment, do not attempt to parse anything
+ // after the fragment extension header as the payload following the fragment
+ // extension header should not contain any headers; the first fragment must
+ // hold all the headers up to and including any upper layer headers, as per
+ // RFC 8200 section 4.5.
+ if fragmentExtHdr.FragmentOffset() != 0 {
i.forceRaw = true
}
@@ -476,7 +489,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
default:
// The header we are parsing is not a known extension header. Return the
// raw payload.
- return i.AsRawHeader(), false, nil
+ return i.AsRawHeader(true /* consume */), false, nil
}
}
diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go
index 133ccc8b6..ab20c5f37 100644
--- a/pkg/tcpip/header/ipv6_extension_headers_test.go
+++ b/pkg/tcpip/header/ipv6_extension_headers_test.go
@@ -673,19 +673,26 @@ func TestIPv6ExtHdrIter(t *testing.T) {
payload buffer.VectorisedView
expected []IPv6PayloadHeader
}{
- // With a non-atomic fragment, the payload after the fragment will not be
- // parsed because the payload may not be complete.
+ // With a non-atomic fragment that is not the first fragment, the payload
+ // after the fragment will not be parsed because the payload is expected to
+ // only hold upper layer data.
{
- name: "hopbyhop - fragment - routing - upper",
+ name: "hopbyhop - fragment (not first) - routing - upper",
firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
payload: makeVectorisedViewFromByteBuffers([]byte{
// Hop By Hop extension header.
uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
// Fragment extension header.
+ //
+ // More = 1, Fragment Offset = 2177, ID = 2147746305
uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1,
// Routing extension header.
+ //
+ // Even though we have a routing ext header here, it should be
+ // interpreted as raw bytes as only the first fragment is expected
+ // to hold headers.
255, 0, 1, 2, 3, 4, 5, 6,
// Upper layer data.
@@ -701,6 +708,34 @@ func TestIPv6ExtHdrIter(t *testing.T) {
},
},
{
+ name: "hopbyhop - fragment (first) - routing - upper",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Fragment extension header.
+ //
+ // More = 1, Fragment Offset = 0, ID = 2147746305
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 0, 1, 128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2, 3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6FragmentExtHdr([6]byte{0, 1, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
name: "fragment - routing - upper (across views)",
firstNextHdr: IPv6FragmentExtHdrIdentifier,
payload: makeVectorisedViewFromByteBuffers([]byte{
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index e6a6ad39b..5d3975c56 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -15,32 +15,47 @@
package header
import (
+ "bytes"
"encoding/binary"
"errors"
"fmt"
+ "io"
"math"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// NDPOptionIdentifier is an NDP option type identifier.
+type NDPOptionIdentifier uint8
+
const (
// NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer
// Address option, as per RFC 4861 section 4.6.1.
- NDPSourceLinkLayerAddressOptionType = 1
+ NDPSourceLinkLayerAddressOptionType NDPOptionIdentifier = 1
// NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer
// Address option, as per RFC 4861 section 4.6.1.
- NDPTargetLinkLayerAddressOptionType = 2
+ NDPTargetLinkLayerAddressOptionType NDPOptionIdentifier = 2
+
+ // NDPPrefixInformationType is the type of the Prefix Information
+ // option, as per RFC 4861 section 4.6.2.
+ NDPPrefixInformationType NDPOptionIdentifier = 3
+
+ // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS
+ // Server option, as per RFC 8106 section 5.1.
+ NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25
+ // NDPDNSSearchListOptionType is the type of the DNS Search List option,
+ // as per RFC 8106 section 5.2.
+ NDPDNSSearchListOptionType = 31
+)
+
+const (
// NDPLinkLayerAddressSize is the size of a Source or Target Link Layer
// Address option for an Ethernet address.
NDPLinkLayerAddressSize = 8
- // NDPPrefixInformationType is the type of the Prefix Information
- // option, as per RFC 4861 section 4.6.2.
- NDPPrefixInformationType = 3
-
// ndpPrefixInformationLength is the expected length, in bytes, of the
// body of an NDP Prefix Information option, as per RFC 4861 section
// 4.6.2 which specifies that the Length field is 4. Given this, the
@@ -91,10 +106,6 @@ const (
// within an NDPPrefixInformation.
ndpPrefixInformationPrefixOffset = 14
- // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS
- // Server option, as per RFC 8106 section 5.1.
- NDPRecursiveDNSServerOptionType = 25
-
// ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte
// Lifetime field within an NDPRecursiveDNSServer.
ndpRecursiveDNSServerLifetimeOffset = 2
@@ -103,10 +114,31 @@ const (
// for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer.
ndpRecursiveDNSServerAddressesOffset = 6
- // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS
- // Server option's length field value when it contains at least one
- // IPv6 address.
- minNDPRecursiveDNSServerLength = 3
+ // minNDPRecursiveDNSServerBodySize is the minimum NDP Recursive DNS
+ // Server option's body size when it contains at least one IPv6 address,
+ // as per RFC 8106 section 5.3.1.
+ minNDPRecursiveDNSServerBodySize = 22
+
+ // ndpDNSSearchListLifetimeOffset is the start of the 4-byte
+ // Lifetime field within an NDPDNSSearchList.
+ ndpDNSSearchListLifetimeOffset = 2
+
+ // ndpDNSSearchListDomainNamesOffset is the start of the DNS search list
+ // domain names within an NDPDNSSearchList.
+ ndpDNSSearchListDomainNamesOffset = 6
+
+ // minNDPDNSSearchListBodySize is the minimum NDP DNS Search List option's
+ // body size when it contains at least one domain name, as per RFC 8106
+ // section 5.3.1.
+ minNDPDNSSearchListBodySize = 14
+
+ // maxDomainNameLabelLength is the maximum length of a domain name
+ // label, as per RFC 1035 section 3.1.
+ maxDomainNameLabelLength = 63
+
+ // maxDomainNameLength is the maximum length of a domain name, including
+ // label AND label length octet, as per RFC 1035 section 3.1.
+ maxDomainNameLength = 255
// lengthByteUnits is the multiplier factor for the Length field of an
// NDP option. That is, the length field for NDP options is in units of
@@ -132,16 +164,13 @@ var (
// few NDPOption then modify the backing NDPOptions so long as the
// NDPOptionIterator obtained before modification is no longer used.
type NDPOptionIterator struct {
- // The NDPOptions this NDPOptionIterator is iterating over.
- opts NDPOptions
+ opts *bytes.Buffer
}
// Potential errors when iterating over an NDPOptions.
var (
- ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted")
- ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field")
- ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body")
- ErrNDPInvalidLength = errors.New("NDP option's Length value is invalid as per relevant RFC")
+ ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body")
+ ErrNDPOptMalformedHeader = errors.New("NDP option has a malformed header")
)
// Next returns the next element in the backing NDPOptions, or true if we are
@@ -152,48 +181,50 @@ var (
func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
for {
// Do we still have elements to look at?
- if len(i.opts) == 0 {
+ if i.opts.Len() == 0 {
return nil, true, nil
}
- // Do we have enough bytes for an NDP option that has a Length
- // field of at least 1? Note, 0 in the Length field is invalid.
- if len(i.opts) < lengthByteUnits {
- return nil, true, ErrNDPOptBufExhausted
- }
-
// Get the Type field.
- t := i.opts[0]
-
- // Get the Length field.
- l := i.opts[1]
+ temp, err := i.opts.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading the option's Type field: %s", err))
+ }
- // This would indicate an erroneous NDP option as the Length
- // field should never be 0.
- if l == 0 {
- return nil, true, ErrNDPOptZeroLength
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
+ // we start parsing an option; we expect the buffer to contain enough
+ // bytes for the whole option.
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF)
}
+ kind := NDPOptionIdentifier(temp)
- // How many bytes are in the option body?
- numBytes := int(l) * lengthByteUnits
- numBodyBytes := numBytes - 2
-
- potentialBody := i.opts[2:]
+ // Get the Length field.
+ length, err := i.opts.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ panic(fmt.Sprintf("unexpected error when reading the option's Length field for %s: %s", kind, err))
+ }
- // This would indicate an erroenous NDPOptions buffer as we ran
- // out of the buffer in the middle of an NDP option.
- if left := len(potentialBody); left < numBodyBytes {
- return nil, true, ErrNDPOptBufExhausted
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Length field for %s: %w", kind, io.ErrUnexpectedEOF)
}
- // Get only the options body, leaving the rest of the options
- // buffer alone.
- body := potentialBody[:numBodyBytes]
+ // This would indicate an erroneous NDP option as the Length field should
+ // never be 0.
+ if length == 0 {
+ return nil, true, fmt.Errorf("zero valued Length field for %s: %w", kind, ErrNDPOptMalformedHeader)
+ }
- // Update opts with the remaining options body.
- i.opts = i.opts[numBytes:]
+ // Get the body.
+ numBytes := int(length) * lengthByteUnits
+ numBodyBytes := numBytes - 2
+ body := i.opts.Next(numBodyBytes)
+ if len(body) < numBodyBytes {
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Body for %s: %w", kind, io.ErrUnexpectedEOF)
+ }
- switch t {
+ switch kind {
case NDPSourceLinkLayerAddressOptionType:
return NDPSourceLinkLayerAddressOption(body), false, nil
@@ -205,22 +236,23 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
// body is ndpPrefixInformationLength, as per RFC 4861
// section 4.6.2.
if numBodyBytes != ndpPrefixInformationLength {
- return nil, true, ErrNDPOptMalformedBody
+ return nil, true, fmt.Errorf("got %d bytes for NDP Prefix Information option's body, expected %d bytes: %w", numBodyBytes, ndpPrefixInformationLength, ErrNDPOptMalformedBody)
}
return NDPPrefixInformation(body), false, nil
case NDPRecursiveDNSServerOptionType:
- // RFC 8106 section 5.3.1 outlines that the RDNSS option
- // must have a minimum length of 3 so it contains at
- // least one IPv6 address.
- if l < minNDPRecursiveDNSServerLength {
- return nil, true, ErrNDPInvalidLength
+ opt := NDPRecursiveDNSServer(body)
+ if err := opt.checkAddresses(); err != nil {
+ return nil, true, err
}
- opt := NDPRecursiveDNSServer(body)
- if len(opt.Addresses()) == 0 {
- return nil, true, ErrNDPOptMalformedBody
+ return opt, false, nil
+
+ case NDPDNSSearchListOptionType:
+ opt := NDPDNSSearchList(body)
+ if err := opt.checkDomainNames(); err != nil {
+ return nil, true, err
}
return opt, false, nil
@@ -247,10 +279,16 @@ type NDPOptions []byte
//
// See NDPOptionIterator for more information.
func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) {
- it := NDPOptionIterator{opts: b}
+ it := NDPOptionIterator{
+ opts: bytes.NewBuffer(b),
+ }
if check {
- for it2 := it; true; {
+ it2 := NDPOptionIterator{
+ opts: bytes.NewBuffer(b),
+ }
+
+ for {
if _, done, err := it2.Next(); err != nil || done {
return it, err
}
@@ -278,7 +316,7 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
continue
}
- b[0] = o.Type()
+ b[0] = byte(o.Type())
// We know this safe because paddedLength would have returned
// 0 if o had an invalid length (> 255 * lengthByteUnits).
@@ -304,7 +342,7 @@ type NDPOption interface {
fmt.Stringer
// Type returns the type of the receiver.
- Type() uint8
+ Type() NDPOptionIdentifier
// Length returns the length of the body of the receiver, in bytes.
Length() int
@@ -386,7 +424,7 @@ func (b NDPOptionsSerializer) Length() int {
type NDPSourceLinkLayerAddressOption tcpip.LinkAddress
// Type implements NDPOption.Type.
-func (o NDPSourceLinkLayerAddressOption) Type() uint8 {
+func (o NDPSourceLinkLayerAddressOption) Type() NDPOptionIdentifier {
return NDPSourceLinkLayerAddressOptionType
}
@@ -426,7 +464,7 @@ func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
type NDPTargetLinkLayerAddressOption tcpip.LinkAddress
// Type implements NDPOption.Type.
-func (o NDPTargetLinkLayerAddressOption) Type() uint8 {
+func (o NDPTargetLinkLayerAddressOption) Type() NDPOptionIdentifier {
return NDPTargetLinkLayerAddressOptionType
}
@@ -466,7 +504,7 @@ func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
type NDPPrefixInformation []byte
// Type implements NDPOption.Type.
-func (o NDPPrefixInformation) Type() uint8 {
+func (o NDPPrefixInformation) Type() NDPOptionIdentifier {
return NDPPrefixInformationType
}
@@ -590,7 +628,7 @@ type NDPRecursiveDNSServer []byte
// Type returns the type of an NDP Recursive DNS Server option.
//
// Type implements NDPOption.Type.
-func (NDPRecursiveDNSServer) Type() uint8 {
+func (NDPRecursiveDNSServer) Type() NDPOptionIdentifier {
return NDPRecursiveDNSServerOptionType
}
@@ -613,7 +651,12 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
// String implements fmt.Stringer.String.
func (o NDPRecursiveDNSServer) String() string {
- return fmt.Sprintf("%T(%s valid for %s)", o, o.Addresses(), o.Lifetime())
+ lt := o.Lifetime()
+ addrs, err := o.Addresses()
+ if err != nil {
+ return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err)
+ }
+ return fmt.Sprintf("%T(%s valid for %s)", o, addrs, lt)
}
// Lifetime returns the length of time that the DNS server addresses
@@ -632,29 +675,225 @@ func (o NDPRecursiveDNSServer) Lifetime() time.Duration {
// Addresses returns the recursive DNS server IPv6 addresses that may be
// used for name resolution.
//
-// Note, some of the addresses returned MAY be link-local addresses.
+// Note, the addresses MAY be link-local addresses.
+func (o NDPRecursiveDNSServer) Addresses() ([]tcpip.Address, error) {
+ var addrs []tcpip.Address
+ return addrs, o.iterAddresses(func(addr tcpip.Address) { addrs = append(addrs, addr) })
+}
+
+// checkAddresses iterates over the addresses in an NDP Recursive DNS Server
+// option and returns any error it encounters.
+func (o NDPRecursiveDNSServer) checkAddresses() error {
+ return o.iterAddresses(nil)
+}
+
+// iterAddresses iterates over the addresses in an NDP Recursive DNS Server
+// option and calls a function with each valid unicast IPv6 address.
//
-// Addresses may panic if o does not hold valid IPv6 addresses.
-func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address {
- l := len(o)
- if l < ndpRecursiveDNSServerAddressesOffset {
- return nil
+// Note, the addresses MAY be link-local addresses.
+func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error {
+ if l := len(o); l < minNDPRecursiveDNSServerBodySize {
+ return fmt.Errorf("got %d bytes for NDP Recursive DNS Server option's body, expected at least %d bytes: %w", l, minNDPRecursiveDNSServerBodySize, io.ErrUnexpectedEOF)
}
- l -= ndpRecursiveDNSServerAddressesOffset
+ o = o[ndpRecursiveDNSServerAddressesOffset:]
+ l := len(o)
if l%IPv6AddressSize != 0 {
- return nil
+ return fmt.Errorf("NDP Recursive DNS Server option's body ends in the middle of an IPv6 address (addresses body size = %d bytes): %w", l, ErrNDPOptMalformedBody)
}
- buf := o[ndpRecursiveDNSServerAddressesOffset:]
- var addrs []tcpip.Address
- for len(buf) > 0 {
- addr := tcpip.Address(buf[:IPv6AddressSize])
+ for i := 0; len(o) != 0; i++ {
+ addr := tcpip.Address(o[:IPv6AddressSize])
if !IsV6UnicastAddress(addr) {
- return nil
+ return fmt.Errorf("%d-th address (%s) in NDP Recursive DNS Server option is not a valid unicast IPv6 address: %w", i, addr, ErrNDPOptMalformedBody)
+ }
+
+ if fn != nil {
+ fn(addr)
}
- addrs = append(addrs, addr)
- buf = buf[IPv6AddressSize:]
+
+ o = o[IPv6AddressSize:]
}
- return addrs
+
+ return nil
+}
+
+// NDPDNSSearchList is the NDP DNS Search List option, as defined by
+// RFC 8106 section 5.2.
+type NDPDNSSearchList []byte
+
+// Type implements NDPOption.Type.
+func (o NDPDNSSearchList) Type() NDPOptionIdentifier {
+ return NDPDNSSearchListOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPDNSSearchList) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPDNSSearchList) serializeInto(b []byte) int {
+ used := copy(b, o)
+
+ // Zero out the reserved bytes that are before the Lifetime field.
+ for i := 0; i < ndpDNSSearchListLifetimeOffset; i++ {
+ b[i] = 0
+ }
+
+ return used
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPDNSSearchList) String() string {
+ lt := o.Lifetime()
+ domainNames, err := o.DomainNames()
+ if err != nil {
+ return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err)
+ }
+ return fmt.Sprintf("%T(%s valid for %s)", o, domainNames, lt)
+}
+
+// Lifetime returns the length of time that the DNS search list of domain names
+// in this option may be used for name resolution.
+//
+// Note, a value of 0 implies the domain names should no longer be used,
+// and a value of infinity/forever is represented by NDPInfiniteLifetime.
+func (o NDPDNSSearchList) Lifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 8106 section 5.2.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpDNSSearchListLifetimeOffset:]))
+}
+
+// DomainNames returns a DNS search list of domain names.
+//
+// DomainNames will parse the backing buffer as outlined by RFC 1035 section
+// 3.1 and return a list of strings, with all domain names in lower case.
+func (o NDPDNSSearchList) DomainNames() ([]string, error) {
+ var domainNames []string
+ return domainNames, o.iterDomainNames(func(domainName string) { domainNames = append(domainNames, domainName) })
+}
+
+// checkDomainNames iterates over the domain names in an NDP DNS Search List
+// option and returns any error it encounters.
+func (o NDPDNSSearchList) checkDomainNames() error {
+ return o.iterDomainNames(nil)
+}
+
+// iterDomainNames iterates over the domain names in an NDP DNS Search List
+// option and calls a function with each valid domain name.
+func (o NDPDNSSearchList) iterDomainNames(fn func(string)) error {
+ if l := len(o); l < minNDPDNSSearchListBodySize {
+ return fmt.Errorf("got %d bytes for NDP DNS Search List option's body, expected at least %d bytes: %w", l, minNDPDNSSearchListBodySize, io.ErrUnexpectedEOF)
+ }
+
+ var searchList bytes.Reader
+ searchList.Reset(o[ndpDNSSearchListDomainNamesOffset:])
+
+ var scratch [maxDomainNameLength]byte
+ domainName := bytes.NewBuffer(scratch[:])
+
+ // Parse the domain names, as per RFC 1035 section 3.1.
+ for searchList.Len() != 0 {
+ domainName.Reset()
+
+ // Parse a label within a domain name, as per RFC 1035 section 3.1.
+ for {
+ // The first byte is the label length.
+ labelLenByte, err := searchList.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading a label's length: %s", err))
+ }
+
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected
+ // once we start parsing a domain name; we expect the buffer to contain
+ // enough bytes for the whole domain name.
+ return fmt.Errorf("unexpected exhausted buffer while parsing a new label for a domain from NDP Search List option: %w", io.ErrUnexpectedEOF)
+ }
+ labelLen := int(labelLenByte)
+
+ // A zero-length label implies the end of a domain name.
+ if labelLen == 0 {
+ // If the domain name is empty or we have no callback function, do
+ // nothing further with the current domain name.
+ if domainName.Len() == 0 || fn == nil {
+ break
+ }
+
+ // Ignore the trailing period in the parsed domain name.
+ domainName.Truncate(domainName.Len() - 1)
+ fn(domainName.String())
+ break
+ }
+
+ // The label's length must not exceed the maximum length for a label.
+ if labelLen > maxDomainNameLabelLength {
+ return fmt.Errorf("label length of %d bytes is greater than the max label length of %d bytes for an NDP Search List option: %w", labelLen, maxDomainNameLabelLength, ErrNDPOptMalformedBody)
+ }
+
+ // The label (and trailing period) must not make the domain name too long.
+ if labelLen+1 > domainName.Cap()-domainName.Len() {
+ return fmt.Errorf("label would make an NDP Search List option's domain name longer than the max domain name length of %d bytes: %w", maxDomainNameLength, ErrNDPOptMalformedBody)
+ }
+
+ // Copy the label and add a trailing period.
+ for i := 0; i < labelLen; i++ {
+ b, err := searchList.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ panic(fmt.Sprintf("unexpected error when reading domain name's label: %s", err))
+ }
+
+ return fmt.Errorf("read %d out of %d bytes for a domain name's label from NDP Search List option: %w", i, labelLen, io.ErrUnexpectedEOF)
+ }
+
+ // As per RFC 1035 section 2.3.1:
+ // 1) the label must only contain ASCII letters, digits and
+ // hyphens
+ // 2) the first character in a label must be a letter
+ // 3) the last character in a label must be a letter or digit
+
+ if !isLetter(b) {
+ if i == 0 {
+ return fmt.Errorf("first character of a domain name's label in an NDP Search List option must be a letter, got character code = %d: %w", b, ErrNDPOptMalformedBody)
+ }
+
+ if b == '-' {
+ if i == labelLen-1 {
+ return fmt.Errorf("last character of a domain name's label in an NDP Search List option must not be a hyphen (-): %w", ErrNDPOptMalformedBody)
+ }
+ } else if !isDigit(b) {
+ return fmt.Errorf("domain name's label in an NDP Search List option may only contain letters, digits and hyphens, got character code = %d: %w", b, ErrNDPOptMalformedBody)
+ }
+ }
+
+ // If b is an upper case character, make it lower case.
+ if isUpperLetter(b) {
+ b = b - 'A' + 'a'
+ }
+
+ if err := domainName.WriteByte(b); err != nil {
+ panic(fmt.Sprintf("unexpected error writing label to domain name buffer: %s", err))
+ }
+ }
+ if err := domainName.WriteByte('.'); err != nil {
+ panic(fmt.Sprintf("unexpected error writing trailing period to domain name buffer: %s", err))
+ }
+ }
+ }
+
+ return nil
+}
+
+func isLetter(b byte) bool {
+ return b >= 'a' && b <= 'z' || isUpperLetter(b)
+}
+
+func isUpperLetter(b byte) bool {
+ return b >= 'A' && b <= 'Z'
+}
+
+func isDigit(b byte) bool {
+ return b >= '0' && b <= '9'
}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 1cb9f5dc8..dc4591253 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -16,6 +16,10 @@ package header
import (
"bytes"
+ "errors"
+ "fmt"
+ "io"
+ "regexp"
"testing"
"time"
@@ -115,7 +119,7 @@ func TestNDPNeighborAdvert(t *testing.T) {
// Make sure flags got updated in the backing buffer.
if got := b[ndpNAFlagsOffset]; got != 64 {
- t.Errorf("got flags byte = %d, want = 64")
+ t.Errorf("got flags byte = %d, want = 64", got)
}
}
@@ -543,8 +547,12 @@ func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) {
want := []tcpip.Address{
"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
}
- if got := opt.Addresses(); !cmp.Equal(got, want) {
- t.Errorf("got Addresses = %v, want = %v", got, want)
+ addrs, err := opt.Addresses()
+ if err != nil {
+ t.Errorf("opt.Addresses() = %s", err)
+ }
+ if diff := cmp.Diff(addrs, want); diff != "" {
+ t.Errorf("mismatched addresses (-want +got):\n%s", diff)
}
// Iterator should not return anything else.
@@ -638,8 +646,12 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) {
if got := opt.Lifetime(); got != test.lifetime {
t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
}
- if got := opt.Addresses(); !cmp.Equal(got, test.addrs) {
- t.Errorf("got Addresses = %v, want = %v", got, test.addrs)
+ addrs, err := opt.Addresses()
+ if err != nil {
+ t.Errorf("opt.Addresses() = %s", err)
+ }
+ if diff := cmp.Diff(addrs, test.addrs); diff != "" {
+ t.Errorf("mismatched addresses (-want +got):\n%s", diff)
}
// Iterator should not return anything else.
@@ -657,42 +669,513 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) {
}
}
+// TestNDPDNSSearchListOption tests the getters of NDPDNSSearchList.
+func TestNDPDNSSearchListOption(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ lifetime time.Duration
+ domainNames []string
+ err error
+ }{
+ {
+ name: "Valid1Label",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: []string{
+ "abc",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid2Label",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 5,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 0,
+ 0, 0, 0, 0, 0, 0,
+ },
+ lifetime: 5 * time.Second,
+ domainNames: []string{
+ "abc.abcd",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid3Label",
+ buf: []byte{
+ 0, 0,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ 0, 0, 0, 0,
+ },
+ lifetime: 16777216 * time.Second,
+ domainNames: []string{
+ "abc.abcd.e",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid2Domains",
+ buf: []byte{
+ 0, 0,
+ 1, 2, 3, 4,
+ 3, 'a', 'b', 'c',
+ 0,
+ 2, 'd', 'e',
+ 3, 'x', 'y', 'z',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: 16909060 * time.Second,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid3DomainsMixedCase",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 'B', 'c',
+ 0,
+ 2, 'd', 'E',
+ 3, 'X', 'y', 'z',
+ 0,
+ 1, 'J',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ "j",
+ },
+ err: nil,
+ },
+ {
+ name: "ValidDomainAfterNULL",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 'B', 'c',
+ 0, 0, 0, 0,
+ 2, 'd', 'E',
+ 3, 'X', 'y', 'z',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid0Domains",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 0,
+ 0, 0, 0, 0, 0, 0, 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: nil,
+ },
+ {
+ name: "NoTrailingNull",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 7, 'a', 'b', 'c', 'd', 'e', 'f', 'g',
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "IncorrectLength",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 8, 'a', 'b', 'c', 'd', 'e', 'f', 'g',
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "IncorrectLengthWithNULL",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 7, 'a', 'b', 'c', 'd', 'e', 'f',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "LabelOfLength63",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk",
+ },
+ err: nil,
+ },
+ {
+ name: "LabelOfLength64",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 64, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "DomainNameOfLength255",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghij",
+ },
+ err: nil,
+ },
+ {
+ name: "DomainNameOfLength256",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "StartingDigitForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, '9', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "StartingHyphenForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, '-', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "EndingHyphenForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', '-',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "EndingDigitForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', '9',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: []string{
+ "ab9",
+ },
+ err: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opt := NDPDNSSearchList(test.buf)
+
+ if got := opt.Lifetime(); got != test.lifetime {
+ t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
+ }
+ domainNames, err := opt.DomainNames()
+ if !errors.Is(err, test.err) {
+ t.Errorf("opt.DomainNames() = %s", err)
+ }
+ if diff := cmp.Diff(domainNames, test.domainNames); diff != "" {
+ t.Errorf("mismatched domain names (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) {
+ for r := rune(0); r <= 255; r++ {
+ t.Run(fmt.Sprintf("RuneVal=%d", r), func(t *testing.T) {
+ buf := []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 0 /* will be replaced */, 'c',
+ 0,
+ 0, 0, 0,
+ }
+ buf[8] = uint8(r)
+ opt := NDPDNSSearchList(buf)
+
+ // As per RFC 1035 section 2.3.1, the label must only include ASCII
+ // letters, digits and hyphens (a-z, A-Z, 0-9, -).
+ var expectedErr error
+ re := regexp.MustCompile(`[a-zA-Z0-9-]`)
+ if !re.Match([]byte{byte(r)}) {
+ expectedErr = ErrNDPOptMalformedBody
+ }
+
+ if domainNames, err := opt.DomainNames(); !errors.Is(err, expectedErr) {
+ t.Errorf("got opt.DomainNames() = (%s, %v), want = (_, %v)", domainNames, err, ErrNDPOptMalformedBody)
+ }
+ })
+ }
+}
+
+func TestNDPDNSSearchListOptionSerialize(t *testing.T) {
+ b := []byte{
+ 9, 8,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ }
+ targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+ expected := []byte{
+ 31, 3, 0, 0,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ 0, 0, 0, 0,
+ }
+ opts := NDPOptions(targetBuf)
+ serializer := NDPOptionsSerializer{
+ NDPDNSSearchList(b),
+ }
+ if got, want := opts.Serialize(serializer), len(expected); got != want {
+ t.Errorf("got Serialize = %d, want = %d", got, want)
+ }
+ if !bytes.Equal(targetBuf, expected) {
+ t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPDNSSearchListOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType)
+ }
+
+ opt, ok := next.(NDPDNSSearchList)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next)
+ }
+ if got := opt.Type(); got != 31 {
+ t.Errorf("got Type = %d, want = 31", got)
+ }
+ if got := opt.Length(); got != 22 {
+ t.Errorf("got Length = %d, want = 22", got)
+ }
+ if got, want := opt.Lifetime(), 16777216*time.Second; got != want {
+ t.Errorf("got Lifetime = %s, want = %s", got, want)
+ }
+ domainNames, err := opt.DomainNames()
+ if err != nil {
+ t.Errorf("opt.DomainNames() = %s", err)
+ }
+ if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" {
+ t.Errorf("domain names mismatch (-want +got):\n%s", diff)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
+
// TestNDPOptionsIterCheck tests that Iter will return an error if the
// NDPOptions the iterator was returned for is malformed.
func TestNDPOptionsIterCheck(t *testing.T) {
tests := []struct {
- name string
- buf []byte
- expected error
+ name string
+ buf []byte
+ expectedErr error
}{
{
- "ZeroLengthField",
- []byte{0, 0, 0, 0, 0, 0, 0, 0},
- ErrNDPOptZeroLength,
+ name: "ZeroLengthField",
+ buf: []byte{0, 0, 0, 0, 0, 0, 0, 0},
+ expectedErr: ErrNDPOptMalformedHeader,
},
{
- "ValidSourceLinkLayerAddressOption",
- []byte{1, 1, 1, 2, 3, 4, 5, 6},
- nil,
+ name: "ValidSourceLinkLayerAddressOption",
+ buf: []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ expectedErr: nil,
},
{
- "TooSmallSourceLinkLayerAddressOption",
- []byte{1, 1, 1, 2, 3, 4, 5},
- ErrNDPOptBufExhausted,
+ name: "TooSmallSourceLinkLayerAddressOption",
+ buf: []byte{1, 1, 1, 2, 3, 4, 5},
+ expectedErr: io.ErrUnexpectedEOF,
},
{
- "ValidTargetLinkLayerAddressOption",
- []byte{2, 1, 1, 2, 3, 4, 5, 6},
- nil,
+ name: "ValidTargetLinkLayerAddressOption",
+ buf: []byte{2, 1, 1, 2, 3, 4, 5, 6},
+ expectedErr: nil,
},
{
- "TooSmallTargetLinkLayerAddressOption",
- []byte{2, 1, 1, 2, 3, 4, 5},
- ErrNDPOptBufExhausted,
+ name: "TooSmallTargetLinkLayerAddressOption",
+ buf: []byte{2, 1, 1, 2, 3, 4, 5},
+ expectedErr: io.ErrUnexpectedEOF,
},
{
- "ValidPrefixInformation",
- []byte{
+ name: "ValidPrefixInformation",
+ buf: []byte{
3, 4, 43, 64,
1, 2, 3, 4,
5, 6, 7, 8,
@@ -702,11 +1185,11 @@ func TestNDPOptionsIterCheck(t *testing.T) {
17, 18, 19, 20,
21, 22, 23, 24,
},
- nil,
+ expectedErr: nil,
},
{
- "TooSmallPrefixInformation",
- []byte{
+ name: "TooSmallPrefixInformation",
+ buf: []byte{
3, 4, 43, 64,
1, 2, 3, 4,
5, 6, 7, 8,
@@ -716,11 +1199,11 @@ func TestNDPOptionsIterCheck(t *testing.T) {
17, 18, 19, 20,
21, 22, 23,
},
- ErrNDPOptBufExhausted,
+ expectedErr: io.ErrUnexpectedEOF,
},
{
- "InvalidPrefixInformationLength",
- []byte{
+ name: "InvalidPrefixInformationLength",
+ buf: []byte{
3, 3, 43, 64,
1, 2, 3, 4,
5, 6, 7, 8,
@@ -728,11 +1211,11 @@ func TestNDPOptionsIterCheck(t *testing.T) {
9, 10, 11, 12,
13, 14, 15, 16,
},
- ErrNDPOptMalformedBody,
+ expectedErr: ErrNDPOptMalformedBody,
},
{
- "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation",
- []byte{
+ name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation",
+ buf: []byte{
// Source Link-Layer Address.
1, 1, 1, 2, 3, 4, 5, 6,
@@ -749,11 +1232,11 @@ func TestNDPOptionsIterCheck(t *testing.T) {
17, 18, 19, 20,
21, 22, 23, 24,
},
- nil,
+ expectedErr: nil,
},
{
- "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
- []byte{
+ name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
+ buf: []byte{
// Source Link-Layer Address.
1, 1, 1, 2, 3, 4, 5, 6,
@@ -775,52 +1258,153 @@ func TestNDPOptionsIterCheck(t *testing.T) {
17, 18, 19, 20,
21, 22, 23, 24,
},
- nil,
+ expectedErr: nil,
},
{
- "InvalidRecursiveDNSServerCutsOffAddress",
- []byte{
+ name: "InvalidRecursiveDNSServerCutsOffAddress",
+ buf: []byte{
25, 4, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 1, 2, 3, 4, 5, 6, 7,
},
- ErrNDPOptMalformedBody,
+ expectedErr: ErrNDPOptMalformedBody,
},
{
- "InvalidRecursiveDNSServerInvalidLengthField",
- []byte{
+ name: "InvalidRecursiveDNSServerInvalidLengthField",
+ buf: []byte{
25, 2, 0, 0,
0, 0, 0, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8,
},
- ErrNDPInvalidLength,
+ expectedErr: io.ErrUnexpectedEOF,
},
{
- "RecursiveDNSServerTooSmall",
- []byte{
+ name: "RecursiveDNSServerTooSmall",
+ buf: []byte{
25, 1, 0, 0,
0, 0, 0,
},
- ErrNDPOptBufExhausted,
+ expectedErr: io.ErrUnexpectedEOF,
},
{
- "RecursiveDNSServerMulticast",
- []byte{
+ name: "RecursiveDNSServerMulticast",
+ buf: []byte{
25, 3, 0, 0,
0, 0, 0, 0,
255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
},
- ErrNDPOptMalformedBody,
+ expectedErr: ErrNDPOptMalformedBody,
},
{
- "RecursiveDNSServerUnspecified",
- []byte{
+ name: "RecursiveDNSServerUnspecified",
+ buf: []byte{
25, 3, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
},
- ErrNDPOptMalformedBody,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "DNSSearchListLargeCompliantRFC1035",
+ buf: []byte{
+ 31, 33, 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j',
+ 0,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "DNSSearchListNonCompliantRFC1035",
+ buf: []byte{
+ 31, 33, 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "DNSSearchListValidSmall",
+ buf: []byte{
+ 31, 2, 0, 0,
+ 0, 0, 0, 0,
+ 6, 'a', 'b', 'c', 'd', 'e', 'f',
+ 0,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "DNSSearchListTooSmall",
+ buf: []byte{
+ 31, 1, 0, 0,
+ 0, 0, 0,
+ },
+ expectedErr: io.ErrUnexpectedEOF,
},
}
@@ -828,8 +1412,8 @@ func TestNDPOptionsIterCheck(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := NDPOptions(test.buf)
- if _, err := opts.Iter(true); err != test.expected {
- t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected)
+ if _, err := opts.Iter(true); !errors.Is(err, test.expectedErr) {
+ t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expectedErr)
}
// test.buf may be malformed but we chose not to check
diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go
new file mode 100644
index 000000000..6fe9a336b
--- /dev/null
+++ b/pkg/tcpip/header/ndpoptionidentifier_string.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT.
+
+package header
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[NDPSourceLinkLayerAddressOptionType-1]
+ _ = x[NDPTargetLinkLayerAddressOptionType-2]
+ _ = x[NDPPrefixInformationType-3]
+ _ = x[NDPRecursiveDNSServerOptionType-25]
+}
+
+const (
+ _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType"
+ _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType"
+)
+
+var (
+ _NDPOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94}
+)
+
+func (i NDPOptionIdentifier) String() string {
+ switch {
+ case 1 <= i && i <= 3:
+ i -= 1
+ return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]]
+ case i == 25:
+ return _NDPOptionIdentifier_name_1
+ default:
+ return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+}
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 74412c894..9339d637f 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -99,6 +99,11 @@ func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
+// SetLength sets the "length" field of the udp header.
+func (b UDP) SetLength(length uint16) {
+ binary.BigEndian.PutUint16(b[udpLength:], length)
+}
+
// CalculateChecksum calculates the checksum of the udp packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index a8d6653ce..9bf67686d 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -28,7 +28,7 @@ import (
// PacketInfo holds all the information about an outbound packet.
type PacketInfo struct {
- Pkt stack.PacketBuffer
+ Pkt *stack.PacketBuffer
Proto tcpip.NetworkProtocolNumber
GSO *stack.GSO
Route stack.Route
@@ -50,13 +50,11 @@ type NotificationHandle struct {
}
type queue struct {
+ // c is the outbound packet channel.
+ c chan PacketInfo
// mu protects fields below.
- mu sync.RWMutex
- // c is the outbound packet channel. Sending to c should hold mu.
- c chan PacketInfo
- numWrite int
- numRead int
- notify []*NotificationHandle
+ mu sync.RWMutex
+ notify []*NotificationHandle
}
func (q *queue) Close() {
@@ -64,11 +62,8 @@ func (q *queue) Close() {
}
func (q *queue) Read() (PacketInfo, bool) {
- q.mu.Lock()
- defer q.mu.Unlock()
select {
case p := <-q.c:
- q.numRead++
return p, true
default:
return PacketInfo{}, false
@@ -76,15 +71,8 @@ func (q *queue) Read() (PacketInfo, bool) {
}
func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
- // We have to receive from channel without holding the lock, since it can
- // block indefinitely. This will cause a window that numWrite - numRead
- // produces a larger number, but won't go to negative. numWrite >= numRead
- // still holds.
select {
case pkt := <-q.c:
- q.mu.Lock()
- defer q.mu.Unlock()
- q.numRead++
return pkt, true
case <-ctx.Done():
return PacketInfo{}, false
@@ -93,16 +81,12 @@ func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
func (q *queue) Write(p PacketInfo) bool {
wrote := false
-
- // It's important to make sure nobody can see numWrite until we increment it,
- // so numWrite >= numRead holds.
- q.mu.Lock()
select {
case q.c <- p:
wrote = true
- q.numWrite++
default:
}
+ q.mu.Lock()
notify := q.notify
q.mu.Unlock()
@@ -116,13 +100,7 @@ func (q *queue) Write(p PacketInfo) bool {
}
func (q *queue) Num() int {
- q.mu.RLock()
- defer q.mu.RUnlock()
- n := q.numWrite - q.numRead
- if n < 0 {
- panic("numWrite < numRead")
- }
- return n
+ return len(q.c)
}
func (q *queue) AddNotify(notify Notification) *NotificationHandle {
@@ -257,7 +235,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
route := r.Clone()
route.Release()
p := PacketInfo{
- Pkt: pkt,
+ Pkt: &pkt,
Proto: protocol,
GSO: gso,
Route: route,
@@ -269,21 +247,15 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
// WritePackets stores outbound packets into the channel.
-func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
// Clone r then release its resource so we only get the relevant fields from
// stack.Route without holding a reference to a NIC's endpoint.
route := r.Clone()
route.Release()
- payloadView := pkts[0].Data.ToView()
n := 0
- for _, pkt := range pkts {
- off := pkt.DataOffset
- size := pkt.DataSize
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
p := PacketInfo{
- Pkt: stack.PacketBuffer{
- Header: pkt.Header,
- Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(),
- },
+ Pkt: pkt,
Proto: protocol,
GSO: gso,
Route: route,
@@ -301,7 +273,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := PacketInfo{
- Pkt: stack.PacketBuffer{Data: vv},
+ Pkt: &stack.PacketBuffer{Data: vv},
Proto: 0,
GSO: nil,
}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 3b3b6909b..b857ce9d0 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -91,7 +91,7 @@ func (p PacketDispatchMode) String() string {
case PacketMMap:
return "PacketMMap"
default:
- return fmt.Sprintf("unknown packet dispatch mode %v", p)
+ return fmt.Sprintf("unknown packet dispatch mode '%d'", p)
}
}
@@ -441,118 +441,106 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// WritePackets writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- var ethHdrBuf []byte
- // hdr + data
- iovLen := 2
- if e.hdrSize > 0 {
- // Add ethernet header if needed.
- ethHdrBuf = make([]byte, header.EthernetMinimumSize)
- eth := header.Ethernet(ethHdrBuf)
- ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
- Type: protocol,
- }
-
- // Preserve the src address if it's set in the route.
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
- } else {
- ethHdr.SrcAddr = e.addr
- }
- eth.Encode(ethHdr)
- iovLen++
- }
+//
+// NOTE: This API uses sendmmsg to batch packets. As a result the underlying FD
+// picked to write the packet out has to be the same for all packets in the
+// list. In other words all packets in the batch should belong to the same
+// flow.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ n := pkts.Len()
- n := len(pkts)
-
- views := pkts[0].Data.Views()
- /*
- * Each boundary in views can add one more iovec.
- *
- * payload | | | |
- * -----------------------------
- * packets | | | | | | |
- * -----------------------------
- * iovecs | | | | | | | | |
- */
- iovec := make([]syscall.Iovec, n*iovLen+len(views)-1)
mmsgHdrs := make([]rawfile.MMsgHdr, n)
+ i := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ var ethHdrBuf []byte
+ iovLen := 0
+ if e.hdrSize > 0 {
+ // Add ethernet header if needed.
+ ethHdrBuf = make([]byte, header.EthernetMinimumSize)
+ eth := header.Ethernet(ethHdrBuf)
+ ethHdr := &header.EthernetFields{
+ DstAddr: r.RemoteLinkAddress,
+ Type: protocol,
+ }
- iovecIdx := 0
- viewIdx := 0
- viewOff := 0
- off := 0
- nextOff := 0
- for i := range pkts {
- // TODO(b/134618279): Different packets may have different data
- // in the future. We should handle this.
- if !viewsEqual(pkts[i].Data.Views(), views) {
- panic("All packets in pkts should have the same Data.")
+ // Preserve the src address if it's set in the route.
+ if r.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = r.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+ iovLen++
}
- prevIovecIdx := iovecIdx
- mmsgHdr := &mmsgHdrs[i]
- mmsgHdr.Msg.Iov = &iovec[iovecIdx]
- packetSize := pkts[i].DataSize
- hdr := &pkts[i].Header
-
- off = pkts[i].DataOffset
- if off != nextOff {
- // We stop in a different point last time.
- size := packetSize
- viewIdx = 0
- viewOff = 0
- for size > 0 {
- if size >= len(views[viewIdx]) {
- viewIdx++
- viewOff = 0
- size -= len(views[viewIdx])
- } else {
- viewOff = size
- size = 0
+ var vnetHdrBuf []byte
+ vnetHdr := virtioNetHdr{}
+ if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if gso != nil {
+ vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
+ if gso.NeedsCsum {
+ vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
+ vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
+ vnetHdr.csumOffset = gso.CsumOffset
+ }
+ if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS {
+ switch gso.Type {
+ case stack.GSOTCPv4:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
+ case stack.GSOTCPv6:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ default:
+ panic(fmt.Sprintf("Unknown gso type: %v", gso.Type))
+ }
+ vnetHdr.gsoSize = gso.MSS
}
}
+ vnetHdrBuf = vnetHdrToByteSlice(&vnetHdr)
+ iovLen++
}
- nextOff = off + packetSize
+ iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views()))
+ mmsgHdr := &mmsgHdrs[i]
+ mmsgHdr.Msg.Iov = &iovecs[0]
+ iovecIdx := 0
+ if vnetHdrBuf != nil {
+ v := &iovecs[iovecIdx]
+ v.Base = &vnetHdrBuf[0]
+ v.Len = uint64(len(vnetHdrBuf))
+ iovecIdx++
+ }
if ethHdrBuf != nil {
- v := &iovec[iovecIdx]
+ v := &iovecs[iovecIdx]
v.Base = &ethHdrBuf[0]
v.Len = uint64(len(ethHdrBuf))
iovecIdx++
}
-
- v := &iovec[iovecIdx]
+ pktSize := uint64(0)
+ // Encode L3 Header
+ v := &iovecs[iovecIdx]
+ hdr := &pkt.Header
hdrView := hdr.View()
v.Base = &hdrView[0]
v.Len = uint64(len(hdrView))
+ pktSize += v.Len
iovecIdx++
- for packetSize > 0 {
- vec := &iovec[iovecIdx]
+ // Now encode the Transport Payload.
+ pktViews := pkt.Data.Views()
+ for i := range pktViews {
+ vec := &iovecs[iovecIdx]
iovecIdx++
-
- v := views[viewIdx]
- vec.Base = &v[viewOff]
- s := len(v) - viewOff
- if s <= packetSize {
- viewIdx++
- viewOff = 0
- } else {
- s = packetSize
- viewOff += s
- }
- vec.Len = uint64(s)
- packetSize -= s
+ vec.Base = &pktViews[i][0]
+ vec.Len = uint64(len(pktViews[i]))
+ pktSize += vec.Len
}
-
- mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx)
+ mmsgHdr.Msg.Iovlen = uint64(iovecIdx)
+ i++
}
packets := 0
for packets < n {
- fd := e.fds[pkts[packets].Hash%uint32(len(e.fds))]
+ fd := e.fds[pkts.Front().Hash%uint32(len(e.fds))]
sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs)
if err != nil {
return packets, err
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 4039753b7..1e2255bfa 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index f5973066d..a5478ce17 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -87,7 +87,7 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber,
// WritePackets writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if
// r.RemoteAddress has a route registered in this endpoint.
-func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
endpoint, ok := m.routes[r.RemoteAddress]
if !ok {
return 0, tcpip.ErrNoRoute
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 6461d0108..0796d717e 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -214,7 +214,7 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.Netw
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 0a6b8945c..be2537a82 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -21,11 +21,9 @@
package sniffer
import (
- "bytes"
"encoding/binary"
"fmt"
"io"
- "os"
"sync/atomic"
"time"
@@ -42,12 +40,12 @@ import (
// LogPackets must be accessed atomically.
var LogPackets uint32 = 1
-// LogPacketsToFile is a flag used to enable or disable logging packets to a
-// pcap file. Valid values are 0 or 1. A file must have been specified when the
+// LogPacketsToPCAP is a flag used to enable or disable logging packets to a
+// pcap writer. Valid values are 0 or 1. A writer must have been specified when the
// sniffer was created for this flag to have effect.
//
-// LogPacketsToFile must be accessed atomically.
-var LogPacketsToFile uint32 = 1
+// LogPacketsToPCAP must be accessed atomically.
+var LogPacketsToPCAP uint32 = 1
var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.TransportProtocolNumber]int{
header.ICMPv4ProtocolNumber: header.IPv4MinimumSize,
@@ -59,7 +57,7 @@ var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.
type endpoint struct {
dispatcher stack.NetworkDispatcher
lower stack.LinkEndpoint
- file *os.File
+ writer io.Writer
maxPCAPLen uint32
}
@@ -99,23 +97,22 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error {
})
}
-// NewWithFile creates a new sniffer link-layer endpoint. It wraps around
-// another endpoint and logs packets and they traverse the endpoint.
+// NewWithWriter creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets as they traverse the endpoint.
//
-// Packets can be logged to file in the pcap format. A sniffer created
-// with this function will not emit packets using the standard log
-// package.
+// Packets are logged to writer in the pcap format. A sniffer created with this
+// function will not emit packets using the standard log package.
//
// snapLen is the maximum amount of a packet to be saved. Packets with a length
-// less than or equal too snapLen will be saved in their entirety. Longer
+// less than or equal to snapLen will be saved in their entirety. Longer
// packets will be truncated to snapLen.
-func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) {
- if err := writePCAPHeader(file, snapLen); err != nil {
+func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (stack.LinkEndpoint, error) {
+ if err := writePCAPHeader(writer, snapLen); err != nil {
return nil, err
}
return &endpoint{
lower: lower,
- file: file,
+ writer: writer,
maxPCAPLen: snapLen,
}, nil
}
@@ -124,36 +121,7 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
- if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("recv", protocol, pkt.Data.First(), nil)
- }
- if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
- vs := pkt.Data.Views()
- length := pkt.Data.Size()
- if length > int(e.maxPCAPLen) {
- length = int(e.maxPCAPLen)
- }
-
- buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
- if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(pkt.Data.Size()))); err != nil {
- panic(err)
- }
- for _, v := range vs {
- if length == 0 {
- break
- }
- if len(v) > length {
- v = v[:length]
- }
- if _, err := buf.Write([]byte(v)); err != nil {
- panic(err)
- }
- length -= len(v)
- }
- if _, err := e.file.Write(buf.Bytes()); err != nil {
- panic(err)
- }
- }
+ e.dumpPacket("recv", nil, protocol, &pkt)
e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt)
}
@@ -200,31 +168,43 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
- if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("send", protocol, pkt.Header.View(), gso)
+func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ writer := e.writer
+ if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
+ first := pkt.Header.View()
+ if len(first) == 0 {
+ first = pkt.Data.First()
+ }
+ logPacket(prefix, protocol, first, gso)
}
- if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
- hdrBuf := pkt.Header.View()
- length := len(hdrBuf) + pkt.Data.Size()
- if length > int(e.maxPCAPLen) {
- length = int(e.maxPCAPLen)
+ if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
+ totalLength := pkt.Header.UsedLength() + pkt.Data.Size()
+ length := totalLength
+ if max := int(e.maxPCAPLen); length > max {
+ length = max
}
-
- buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
- if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+pkt.Data.Size()))); err != nil {
+ if err := binary.Write(writer, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(totalLength))); err != nil {
panic(err)
}
- if len(hdrBuf) > length {
- hdrBuf = hdrBuf[:length]
- }
- if _, err := buf.Write(hdrBuf); err != nil {
- panic(err)
+ write := func(b []byte) {
+ if len(b) > length {
+ b = b[:length]
+ }
+ for len(b) != 0 {
+ n, err := writer.Write(b)
+ if err != nil {
+ panic(err)
+ }
+ b = b[n:]
+ length -= n
+ }
}
- length -= len(hdrBuf)
- logVectorisedView(pkt.Data, length, buf)
- if _, err := e.file.Write(buf.Bytes()); err != nil {
- panic(err)
+ write(pkt.Header.View())
+ for _, view := range pkt.Data.Views() {
+ if length == 0 {
+ break
+ }
+ write(view)
}
}
}
@@ -233,68 +213,30 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumb
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
- e.dumpPacket(gso, protocol, pkt)
+ e.dumpPacket("send", gso, protocol, &pkt)
return e.lower.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- view := pkts[0].Data.ToView()
- for _, pkt := range pkts {
- e.dumpPacket(gso, protocol, stack.PacketBuffer{
- Header: pkt.Header,
- Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(),
- })
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.dumpPacket("send", gso, protocol, pkt)
}
return e.lower.WritePackets(r, gso, pkts, protocol)
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */)
- }
- if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
- length := vv.Size()
- if length > int(e.maxPCAPLen) {
- length = int(e.maxPCAPLen)
- }
-
- buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
- if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil {
- panic(err)
- }
- logVectorisedView(vv, length, buf)
- if _, err := e.file.Write(buf.Bytes()); err != nil {
- panic(err)
- }
- }
+ e.dumpPacket("send", nil, 0, &stack.PacketBuffer{
+ Data: vv,
+ })
return e.lower.WriteRawPacket(vv)
}
-func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) {
- if length <= 0 {
- return
- }
- for _, v := range vv.Views() {
- if len(v) > length {
- v = v[:length]
- }
- n, err := buf.Write(v)
- if err != nil {
- panic(err)
- }
- length -= n
- if length == 0 {
- return
- }
- }
-}
-
// Wait implements stack.LinkEndpoint.Wait.
-func (*endpoint) Wait() {}
+func (e *endpoint) Wait() { e.lower.Wait() }
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
// Figure out the network layer info.
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 52fe397bf..2b3741276 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -112,9 +112,9 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
if !e.writeGate.Enter() {
- return len(pkts), nil
+ return pkts.Len(), nil
}
n, err := e.lower.WritePackets(r, gso, pkts, protocol)
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 88224e494..54eb5322b 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -71,9 +71,9 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcp
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- e.writeCount += len(pkts)
- return len(pkts), nil
+func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ e.writeCount += pkts.Len()
+ return pkts.Len(), nil
}
func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 255098372..7acbfa0a8 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
return 0, tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index b3e239ac7..1646d9cde 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -138,7 +138,8 @@ func TestDirectRequest(t *testing.T) {
// Sleep tests are gross, but this will only potentially flake
// if there's a bug. If there is no bug this will reliably
// succeed.
- ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
if pkt, ok := c.linkEP.ReadContext(ctx); ok {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4950d69fc..4c20301c6 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -172,7 +172,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index a7d9a8b25..104aafbed 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -280,28 +280,47 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
}
if r.Loop&stack.PacketOut == 0 {
- return len(pkts), nil
+ return pkts.Len(), nil
+ }
+
+ for pkt := pkts.Front(); pkt != nil; {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
+ pkt = pkt.Next()
}
// iptables filtering. All packets that reach here are locally
// generated.
ipt := e.stack.IPTables()
- for i := range pkts {
- if ok := ipt.Check(stack.Output, pkts[i]); !ok {
- // iptables is telling us to drop the packet.
+ dropped := ipt.CheckPackets(stack.Output, pkts)
+ if len(dropped) == 0 {
+ // Fast path: If no packets are to be dropped then we can just invoke the
+ // faster WritePackets API directly.
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+
+ // Slow path: as we are dropping some packets in the batch, degrade to
+ // emitting one packet at a time.
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if _, ok := dropped[pkt]; ok {
continue
}
- ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params)
- pkts[i].NetworkHeader = buffer.View(ip)
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil {
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+ n++
}
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ return n, nil
}
// WriteHeaderIncludedPacket writes a packet already containing a network
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index a93a7621a..3f71fc520 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -31,6 +31,7 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index e0dd5afd3..b68983d10 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -62,7 +62,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer) {
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
@@ -79,32 +79,22 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
// Only the first view in vv is accounted for by h. To account for the
// rest of vv, a shallow copy is made and the first view is removed.
// This copy is used as extra payload during the checksum calculation.
- payload := pkt.Data
+ payload := pkt.Data.Clone(nil)
payload.RemoveFirst()
if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
received.Invalid.Increment()
return
}
- // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
- // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
- // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not
- // set to 0.
- switch h.Type() {
- case header.ICMPv6NeighborSolicit,
- header.ICMPv6NeighborAdvert,
- header.ICMPv6RouterSolicit,
- header.ICMPv6RouterAdvert,
- header.ICMPv6RedirectMsg:
- if iph.HopLimit() != header.NDPHopLimit {
- received.Invalid.Increment()
- return
- }
-
- if h.Code() != 0 {
- received.Invalid.Increment()
- return
- }
+ isNDPValid := func() bool {
+ // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
+ // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
+ // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not
+ // set to 0.
+ //
+ // As per RFC 6980 section 5, nodes MUST silently drop NDP messages if the
+ // packet includes a fragmentation header.
+ return !hasFragmentHeader && iph.HopLimit() == header.NDPHopLimit && h.Code() == 0
}
// TODO(b/112892170): Meaningfully handle all ICMP types.
@@ -133,7 +123,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
+ if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
received.Invalid.Increment()
return
}
@@ -148,53 +138,48 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
targetAddr := ns.TargetAddress()
s := r.Stack()
- rxNICID := r.NICID()
- if isTentative, err := s.IsAddrTentative(rxNICID, targetAddr); err != nil {
- // We will only get an error if rxNICID is unrecognized,
- // which should not happen. For now short-circuit this
- // packet.
+ if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ // We will only get an error if the NIC is unrecognized, which should not
+ // happen. For now, drop this packet.
//
// TODO(b/141002840): Handle this better?
return
} else if isTentative {
- // If the target address is tentative and the source
- // of the packet is a unicast (specified) address, then
- // the source of the packet is attempting to perform
- // address resolution on the target. In this case, the
- // solicitation is silently ignored, as per RFC 4862
- // section 5.4.3.
+ // If the target address is tentative and the source of the packet is a
+ // unicast (specified) address, then the source of the packet is
+ // attempting to perform address resolution on the target. In this case,
+ // the solicitation is silently ignored, as per RFC 4862 section 5.4.3.
//
- // If the target address is tentative and the source of
- // the packet is the unspecified address (::), then we
- // know another node is also performing DAD for the
- // same address (since targetAddr is tentative for us,
- // we know we are also performing DAD on it). In this
- // case we let the stack know so it can handle such a
- // scenario and do nothing further with the NDP NS.
- if iph.SourceAddress() == header.IPv6Any {
- s.DupTentativeAddrDetected(rxNICID, targetAddr)
+ // If the target address is tentative and the source of the packet is the
+ // unspecified address (::), then we know another node is also performing
+ // DAD for the same address (since the target address is tentative for us,
+ // we know we are also performing DAD on it). In this case we let the
+ // stack know so it can handle such a scenario and do nothing further with
+ // the NS.
+ if r.RemoteAddress == header.IPv6Any {
+ s.DupTentativeAddrDetected(e.nicID, targetAddr)
}
- // Do not handle neighbor solicitations targeted
- // to an address that is tentative on the received
- // NIC any further.
+ // Do not handle neighbor solicitations targeted to an address that is
+ // tentative on the NIC any further.
return
}
- // At this point we know that targetAddr is not tentative on
- // rxNICID so the packet is processed as defined in RFC 4861,
- // as per RFC 4862 section 5.4.3.
+ // At this point we know that the target address is not tentative on the NIC
+ // so the packet is processed as defined in RFC 4861, as per RFC 4862
+ // section 5.4.3.
+ // Is the NS targeting us?
if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
- // We don't have a useful answer; the best we can do is ignore the request.
return
}
- // If the NS message has the source link layer option, update the link
- // address cache with the link address for the sender of the message.
+ // If the NS message contains the Source Link-Layer Address option, update
+ // the link address cache with the value of the option.
//
// TODO(b/148429853): Properly process the NS message and do Neighbor
// Unreachability Detection.
+ var sourceLinkAddr tcpip.LinkAddress
for {
opt, done, err := it.Next()
if err != nil {
@@ -207,22 +192,36 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
switch opt := opt.(type) {
case header.NDPSourceLinkLayerAddressOption:
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, opt.EthernetAddress())
+ // No RFCs define what to do when an NS message has multiple Source
+ // Link-Layer Address options. Since no interface can have multiple
+ // link-layer addresses, we consider such messages invalid.
+ if len(sourceLinkAddr) != 0 {
+ received.Invalid.Increment()
+ return
+ }
+
+ sourceLinkAddr = opt.EthernetAddress()
}
}
- optsSerializer := header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]),
+ unspecifiedSource := r.RemoteAddress == header.IPv6Any
+
+ // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST
+ // NOT be included when the source IP address is the unspecified address.
+ // Otherwise, on link layers that have addresses this option MUST be
+ // included in multicast solicitations and SHOULD be included in unicast
+ // solicitations.
+ if len(sourceLinkAddr) == 0 {
+ if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource {
+ received.Invalid.Increment()
+ return
+ }
+ } else if unspecifiedSource {
+ received.Invalid.Increment()
+ return
+ } else {
+ e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
}
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
- packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
- packet.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(packet.NDPPayload())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(true)
- na.SetTargetAddress(targetAddr)
- opts := na.Options()
- opts.Serialize(optsSerializer)
// ICMPv6 Neighbor Solicit messages are always sent to
// specially crafted IPv6 multicast addresses. As a result, the
@@ -235,6 +234,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
r := r.Clone()
defer r.Release()
r.LocalAddress = targetAddr
+
+ // As per RFC 4861 section 7.2.4, if the source of the solicitation is
+ // the unspecified address, the node MUST set the Solicited flag to zero and
+ // multicast the advertisement to the all-nodes address.
+ solicited := true
+ if unspecifiedSource {
+ solicited = false
+ r.RemoteAddress = header.IPv6AllNodesMulticastAddress
+ }
+
+ // If the NS has a source link-layer option, use the link address it
+ // specifies as the remote link address for the response instead of the
+ // source link address of the packet.
+ //
+ // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link
+ // address cache for the right destination link address instead of manually
+ // patching the route with the remote link address if one is specified in a
+ // Source Link-Layer Address option.
+ if len(sourceLinkAddr) != 0 {
+ r.RemoteLinkAddress = sourceLinkAddr
+ }
+
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress),
+ }
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
+ packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ packet.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na.SetSolicitedFlag(solicited)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(targetAddr)
+ opts := na.Options()
+ opts.Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
@@ -253,7 +286,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if len(v) < header.ICMPv6NeighborAdvertSize {
+ if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
received.Invalid.Increment()
return
}
@@ -268,40 +301,38 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
targetAddr := na.TargetAddress()
stack := r.Stack()
- rxNICID := r.NICID()
- if isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr); err != nil {
- // We will only get an error if rxNICID is unrecognized,
- // which should not happen. For now short-circuit this
- // packet.
+ if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ // We will only get an error if the NIC is unrecognized, which should not
+ // happen. For now short-circuit this packet.
//
// TODO(b/141002840): Handle this better?
return
} else if isTentative {
- // We just got an NA from a node that owns an address we
- // are performing DAD on, implying the address is not
- // unique. In this case we let the stack know so it can
- // handle such a scenario and do nothing furthur with
+ // We just got an NA from a node that owns an address we are performing
+ // DAD on, implying the address is not unique. In this case we let the
+ // stack know so it can handle such a scenario and do nothing further with
// the NDP NA.
- stack.DupTentativeAddrDetected(rxNICID, targetAddr)
+ stack.DupTentativeAddrDetected(e.nicID, targetAddr)
return
}
- // At this point we know that the targetAddress is not tentative
- // on rxNICID. However, targetAddr may still be assigned to
- // rxNICID but not tentative (it could be permanent). Such a
- // scenario is beyond the scope of RFC 4862. As such, we simply
- // ignore such a scenario for now and proceed as normal.
+ // At this point we know that the target address is not tentative on the
+ // NIC. However, the target address may still be assigned to the NIC but not
+ // tentative (it could be permanent). Such a scenario is beyond the scope of
+ // RFC 4862. As such, we simply ignore such a scenario for now and proceed
+ // as normal.
//
+ // TODO(b/143147598): Handle the scenario described above. Also inform the
+ // netstack integration that a duplicate address was detected outside of
+ // DAD.
+
// If the NA message has the target link layer option, update the link
// address cache with the link address for the target of the message.
//
- // TODO(b/143147598): Handle the scenario described above. Also
- // inform the netstack integration that a duplicate address was
- // detected outside of DAD.
- //
// TODO(b/148429853): Properly process the NA message and do Neighbor
// Unreachability Detection.
+ var targetLinkAddr tcpip.LinkAddress
for {
opt, done, err := it.Next()
if err != nil {
@@ -314,10 +345,22 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
switch opt := opt.(type) {
case header.NDPTargetLinkLayerAddressOption:
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, opt.EthernetAddress())
+ // No RFCs define what to do when an NA message has multiple Target
+ // Link-Layer Address options. Since no interface can have multiple
+ // link-layer addresses, we consider such messages invalid.
+ if len(targetLinkAddr) != 0 {
+ received.Invalid.Increment()
+ return
+ }
+
+ targetLinkAddr = opt.EthernetAddress()
}
}
+ if len(targetLinkAddr) != 0 {
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ }
+
case header.ICMPv6EchoRequest:
received.EchoRequest.Increment()
if len(v) < header.ICMPv6EchoMinimumSize {
@@ -355,8 +398,20 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6RouterSolicit:
received.RouterSolicit.Increment()
+ if !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
case header.ICMPv6RouterAdvert:
+ received.RouterAdvert.Increment()
+
+ p := h.NDPPayload()
+ if len(p) < header.NDPRAMinimumSize || !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
routerAddr := iph.SourceAddress()
//
@@ -370,16 +425,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
return
}
- p := h.NDPPayload()
-
- // Is the NDP payload of sufficient size to hold a Router
- // Advertisement?
- if len(p) < header.NDPRAMinimumSize {
- // ...No, silently drop the packet.
- received.Invalid.Increment()
- return
- }
-
ra := header.NDPRouterAdvert(p)
opts := ra.Options()
@@ -395,8 +440,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
// as RFC 4861 section 6.1.2 is concerned.
//
- received.RouterAdvert.Increment()
-
// Tell the NIC to handle the RA.
stack := r.Stack()
rxNICID := r.NICID()
@@ -404,6 +447,10 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6RedirectMsg:
received.RedirectMsg.Increment()
+ if !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
default:
received.Invalid.Increment()
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index bae09ed94..bd099a7f8 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -32,7 +32,8 @@ import (
const (
linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
)
var (
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 685239017..331b0817b 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -143,19 +143,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
}
if r.Loop&stack.PacketOut == 0 {
- return len(pkts), nil
+ return pkts.Len(), nil
}
- for i := range pkts {
- hdr := &pkts[i].Header
- size := pkts[i].DataSize
- ip := e.addIPHeader(r, hdr, size, params)
- pkts[i].NetworkHeader = buffer.View(ip)
+ for pb := pkts.Front(); pb != nil; pb = pb.Next() {
+ ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params)
+ pb.NetworkHeader = buffer.View(ip)
}
n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
@@ -185,6 +183,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
pkt.Data.CapLength(int(h.PayloadLength()))
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data)
+ hasFragmentHeader := false
for firstHeader := true; ; firstHeader = false {
extHdr, done, err := it.Next()
@@ -257,6 +256,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
}
case header.IPv6FragmentExtHdr:
+ hasFragmentHeader = true
+
fragmentOffset := extHdr.FragmentOffset()
more := extHdr.More()
if !more && fragmentOffset == 0 {
@@ -269,7 +270,55 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
continue
}
- rawPayload := it.AsRawHeader()
+ // Don't consume the iterator if we have the first fragment because we
+ // will use it to validate that the first fragment holds the upper layer
+ // header.
+ rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */)
+
+ if fragmentOffset == 0 {
+ // Check that the iterator ends with a raw payload as the first fragment
+ // should include all headers up to and including any upper layer
+ // headers, as per RFC 8200 section 4.5; only upper layer data
+ // (non-headers) should follow the fragment extension header.
+ var lastHdr header.IPv6PayloadHeader
+
+ for {
+ it, done, err := it.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ lastHdr = it
+ }
+
+ // If the last header is a raw header, then the last portion of the IPv6
+ // payload is not a known IPv6 extension header. Note, this does not
+ // mean that the last portion is an upper layer header or not an
+ // extension header because:
+ // 1) we do not yet support all extension headers
+ // 2) we do not validate the upper layer header before reassembling.
+ //
+ // This check makes sure that a known IPv6 extension header is not
+ // present after the Fragment extension header in a non-initial
+ // fragment.
+ //
+ // TODO(#2196): Support IPv6 Authentication and Encapsulated
+ // Security Payload extension headers.
+ // TODO(#2333): Validate that the upper layer header is valid.
+ switch lastHdr.(type) {
+ case header.IPv6RawPayloadHeader:
+ default:
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ }
+
fragmentPayloadLen := rawPayload.Buf.Size()
if fragmentPayloadLen == 0 {
// Drop the packet as it's marked as a fragment but has no payload.
@@ -344,7 +393,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
pkt.Data = extHdr.Buf
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
- e.handleICMP(r, headerView, pkt)
+ e.handleICMP(r, headerView, pkt, hasFragmentHeader)
} else {
r.Stats().IP.PacketsDelivered.Increment()
// TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 37f7e53ce..841a0cb7a 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -34,6 +34,7 @@ const (
// The least significant 3 bytes are the same as addr2 so both addr2 and
// addr3 will have the same solicited-node address.
addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
+ addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
// Tests use the extension header identifier values as uint8 instead of
// header.IPv6ExtensionHeaderIdentifier.
@@ -167,6 +168,8 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
// packets destined to the IPv6 solicited-node address of an assigned IPv6
// address.
func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
+ const nicID = 1
+
tests := []struct {
name string
protocolFactory stack.TransportProtocol
@@ -184,50 +187,61 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
})
- e := channel.New(10, 1280, linkAddr1)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ e := channel.New(1, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- // Should not receive a packet destined to the solicited
- // node address of addr2/addr3 yet as we haven't added
- // those addresses.
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // Should not receive a packet destined to the solicited node address of
+ // addr2/addr3 yet as we haven't added those addresses.
test.rxf(t, s, e, addr1, snmc, 0)
- if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
- // Should receive a packet destined to the solicited
- // node address of addr2/addr3 now that we have added
- // added addr2.
+ // Should receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have added addr2.
test.rxf(t, s, e, addr1, snmc, 1)
- if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
}
- // Should still receive a packet destined to the
- // solicited node address of addr2/addr3 now that we
- // have added addr3.
+ // Should still receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have added addr3.
test.rxf(t, s, e, addr1, snmc, 2)
- if err := s.RemoveAddress(1, addr2); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err)
+ if err := s.RemoveAddress(nicID, addr2); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err)
}
- // Should still receive a packet destined to the
- // solicited node address of addr2/addr3 now that we
- // have removed addr2.
+ // Should still receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have removed addr2.
test.rxf(t, s, e, addr1, snmc, 3)
- if err := s.RemoveAddress(1, addr3); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err)
+ // Make sure addr3's endpoint does not get removed from the NIC by
+ // incrementing its reference count with a route.
+ r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ if err := s.RemoveAddress(nicID, addr3); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err)
}
- // Should not receive a packet destined to the solicited
- // node address of addr2/addr3 yet as both of them got
- // removed.
+ // Should not receive a packet destined to the solicited node address of
+ // addr2/addr3 yet as both of them got removed, even though a route using
+ // addr3 exists.
test.rxf(t, s, e, addr1, snmc, 3)
})
}
@@ -1014,7 +1028,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: nil,
},
{
name: "Two fragments with routing header with non-zero segments left across fragments",
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index f924ed9e1..12b70f7e9 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -173,6 +174,257 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
}
}
+func TestNeighorSolicitationResponse(t *testing.T) {
+ const nicID = 1
+ nicAddr := lladdr0
+ remoteAddr := lladdr1
+ nicAddrSNMC := header.SolicitedNodeAddr(nicAddr)
+ nicLinkAddr := linkAddr0
+ remoteLinkAddr0 := linkAddr1
+ remoteLinkAddr1 := linkAddr2
+
+ tests := []struct {
+ name string
+ nsOpts header.NDPOptionsSerializer
+ nsSrcLinkAddr tcpip.LinkAddress
+ nsSrc tcpip.Address
+ nsDst tcpip.Address
+ nsInvalid bool
+ naDstLinkAddr tcpip.LinkAddress
+ naSolicited bool
+ naSrc tcpip.Address
+ naDst tcpip.Address
+ }{
+ {
+ name: "Unspecified source to multicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: false,
+ naSrc: nicAddr,
+ naDst: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "Unspecified source with source ll option to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+ {
+ name: "Unspecified source to unicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: false,
+ naSrc: nicAddr,
+ naDst: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "Unspecified source with source ll option to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddr,
+ nsInvalid: true,
+ },
+
+ {
+ name: "Specified source with 1 source ll to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll different from route to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr1,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source to multicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+ {
+ name: "Specified source with 2 source ll to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+
+ {
+ name: "Specified source to unicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll different from route to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr1,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 2 source ll to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(1, 1280, nicLinkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(nicAddr)
+ opts := ns.Options()
+ opts.Serialize(test.nsOpts)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ if test.nsInvalid {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+
+ if p, got := e.Read(); got {
+ t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
+ }
+
+ // If we expected the NS to be invalid, we have nothing else to check.
+ return
+ }
+
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ p, got := e.Read()
+ if !got {
+ t.Fatal("expected an NDP NA response")
+ }
+
+ if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ }
+
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(test.naSrc),
+ checker.DstAddr(test.naDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNA(
+ checker.NDPNASolicitedFlag(test.naSolicited),
+ checker.NDPNATargetAddress(nicAddr),
+ checker.NDPNAOptions([]header.NDPOption{
+ header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
+ }),
+ ))
+ })
+ }
+}
+
// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a
// valid NDP NA message with the Target Link Layer Address option results in a
// new entry in the link address cache for the target of the message.
@@ -197,6 +449,13 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
name: "Invalid Length",
optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
},
+ {
+ name: "Multiple",
+ optsBuf: []byte{
+ 2, 1, 2, 3, 4, 5, 6, 7,
+ 2, 1, 2, 3, 4, 5, 6, 8,
+ },
+ },
}
for _, test := range tests {
@@ -276,9 +535,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
}
-// TestHopLimitValidation is a test that makes sure that NDP packets are only
-// received if their IP header's hop limit is set to 255.
-func TestHopLimitValidation(t *testing.T) {
+func TestNDPValidation(t *testing.T) {
setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
t.Helper()
@@ -294,12 +551,19 @@ func TestHopLimitValidation(t *testing.T) {
return s, ep, r
}
- handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) {
+ handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ nextHdr := uint8(header.ICMPv6ProtocolNumber)
+ if atomicFragment {
+ bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength)
+ bytes[0] = nextHdr
+ nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ }
+
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ NextHeader: nextHdr,
HopLimit: hopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
@@ -364,61 +628,93 @@ func TestHopLimitValidation(t *testing.T) {
},
}
+ subTests := []struct {
+ name string
+ atomicFragment bool
+ hopLimit uint8
+ code uint8
+ valid bool
+ }{
+ {
+ name: "Valid",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: true,
+ },
+ {
+ name: "Fragmented",
+ atomicFragment: true,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid hop limit",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit - 1,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid ICMPv6 code",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 1,
+ valid: false,
+ },
+ }
+
for _, typ := range types {
t.Run(typ.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- // Should not have received any ICMPv6 packets with
- // type = typ.typ.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Receive the NDP packet with an invalid hop limit
- // value.
- handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r)
-
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
-
- // Rx count of NDP packet of type typ.typ should not
- // have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Receive the NDP packet with a valid hop limit value.
- handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r)
-
- // Rx count of NDP packet of type typ.typ should have
- // increased.
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
-
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
+ for _, test := range subTests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
+ pkt := header.ICMPv6(hdr.Prepend(typ.size))
+ pkt.SetType(typ.typ)
+ pkt.SetCode(test.code)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
+
+ // Rx count of the NDP message should initially be 0.
+ if got := typStat.Value(); got != 0 {
+ t.Errorf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r)
+
+ // Rx count of the NDP packet should have increased.
+ if got := typStat.Value(); got != 1 {
+ t.Errorf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ want := uint64(0)
+ if !test.valid {
+ // Invalid count should have increased.
+ want = 1
+ }
+ if got := invalid.Value(); got != want {
+ t.Errorf("got invalid = %d, want = %d", got, want)
+ }
+ })
}
})
}
@@ -592,21 +888,18 @@ func TestRouterAdvertValidation(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
})
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
+ }
+
if test.expectedSuccess {
if got := invalid.Value(); got != 0 {
t.Fatalf("got invalid = %d, want = 0", got)
}
- if got := rxRA.Value(); got != 1 {
- t.Fatalf("got rxRA = %d, want = 1", got)
- }
-
} else {
if got := invalid.Value(); got != 1 {
t.Fatalf("got invalid = %d, want = 1", got)
}
- if got := rxRA.Value(); got != 0 {
- t.Fatalf("got rxRA = %d, want = 0", got)
- }
}
})
}
diff --git a/pkg/tcpip/seqnum/seqnum.go b/pkg/tcpip/seqnum/seqnum.go
index b40a3c212..d3bea7de4 100644
--- a/pkg/tcpip/seqnum/seqnum.go
+++ b/pkg/tcpip/seqnum/seqnum.go
@@ -46,11 +46,6 @@ func (v Value) InWindow(first Value, size Size) bool {
return v.InRange(first, first.Add(size))
}
-// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y).
-func Overlap(a Value, b Size, x Value, y Size) bool {
- return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
-}
-
// Add calculates the sequence number following the [v, v+s) window.
func (v Value) Add(s Size) Value {
return v + Value(s)
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 8d80e9cee..5e963a4af 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -15,6 +15,18 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "packet_buffer_list",
+ out = "packet_buffer_list.go",
+ package = "stack",
+ prefix = "PacketBuffer",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*PacketBuffer",
+ "Linker": "*PacketBuffer",
+ },
+)
+
go_library(
name = "stack",
srcs = [
@@ -29,7 +41,7 @@ go_library(
"ndp.go",
"nic.go",
"packet_buffer.go",
- "packet_buffer_state.go",
+ "packet_buffer_list.go",
"rand.go",
"registration.go",
"route.go",
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index c45c43d21..e9c652042 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -101,7 +101,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
+func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
panic("not implemented")
}
@@ -260,10 +260,10 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw
}
// WritePackets stores outbound packets into the channel.
-func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
n := 0
- for _, pkt := range pkts {
- e.WritePacket(r, gso, protocol, pkt)
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.WritePacket(r, gso, protocol, *pkt)
n++
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 37907ae24..6c0a4b24d 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -209,6 +209,23 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool {
return true
}
+// CheckPackets runs pkts through the rules for hook and returns a map of packets that
+// should not go forward.
+//
+// NOTE: unlike the Check API the returned map contains packets that should be
+// dropped.
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if ok := it.Check(hook, *pkt); !ok {
+ if drop == nil {
+ drop = make(map[*PacketBuffer]struct{})
+ }
+ drop[pkt] = struct{}{}
+ }
+ }
+ return drop
+}
+
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 7c9fc48d1..193a9dfde 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -241,6 +241,16 @@ type NDPDispatcher interface {
// call functions on the stack itself.
OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+ // OnDNSSearchListOption will be called when an NDP option with a DNS
+ // search list has been received.
+ //
+ // It is up to the caller to use the domain names in the search list
+ // for only their valid lifetime. OnDNSSearchListOption may be called
+ // with new or already known domain names. If called with known domain
+ // names, their valid lifetimes must be refreshed to lifetime (it may
+ // be increased, decreased or completely invalidated when lifetime = 0).
+ OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration)
+
// OnDHCPv6Configuration will be called with an updated configuration that is
// available via DHCPv6 for a specified NIC.
//
@@ -305,6 +315,15 @@ type NDPConfigurations struct {
// lifetime(s) of the generated address changes; this option only
// affects the generation of new addresses as part of SLAAC.
AutoGenGlobalAddresses bool
+
+ // AutoGenAddressConflictRetries determines how many times to attempt to retry
+ // generation of a permanent auto-generated address in response to DAD
+ // conflicts.
+ //
+ // If the method used to generate the address does not support creating
+ // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's
+ // MAC address), then no attempt will be made to resolve the conflict.
+ AutoGenAddressConflictRetries uint8
}
// DefaultNDPConfigurations returns an NDPConfigurations populated with
@@ -411,8 +430,23 @@ type slaacPrefixState struct {
// Nonzero only when the address is not valid forever.
validUntil time.Time
+ // Nonzero only when the address is not preferred forever.
+ preferredUntil time.Time
+
// The prefix's permanent address endpoint.
+ //
+ // May only be nil when a SLAAC address is being (re-)generated. Otherwise,
+ // must not be nil as all SLAAC prefixes must have a SLAAC address.
ref *referencedNetworkEndpoint
+
+ // The number of times a permanent address has been generated for the prefix.
+ //
+ // Addresses may be regenerated in response to DAD conflicts.
+ generationAttempts uint8
+
+ // The maximum number of times to attempt regeneration of a permanent SLAAC
+ // address in response to DAD conflicts.
+ maxGenerationAttempts uint8
}
// startDuplicateAddressDetection performs Duplicate Address Detection.
@@ -687,7 +721,16 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
continue
}
- ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime())
+ addrs, _ := opt.Addresses()
+ ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime())
+
+ case header.NDPDNSSearchList:
+ if ndp.nic.stack.ndpDisp == nil {
+ continue
+ }
+
+ domainNames, _ := opt.DomainNames()
+ ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -935,60 +978,83 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
return
}
- // If the preferred lifetime is zero, then the prefix should be considered
- // deprecated.
- deprecated := pl == 0
- ref := ndp.addSLAACAddr(prefix, deprecated)
- if ref == nil {
- // We were unable to generate a permanent address for prefix so do nothing
- // further as there is no reason to maintain state for a SLAAC prefix we
- // cannot generate a permanent address for.
- return
- }
-
state := slaacPrefixState{
deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
- prefixState, ok := ndp.slaacPrefixes[prefix]
+ state, ok := ndp.slaacPrefixes[prefix]
if !ok {
- panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the SLAAC prefix %s", prefix))
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
}
- ndp.deprecateSLAACAddress(prefixState.ref)
+ ndp.deprecateSLAACAddress(state.ref)
}),
invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
- ndp.invalidateSLAACPrefix(prefix, true)
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
+ }
+
+ ndp.invalidateSLAACPrefix(prefix, state)
}),
- ref: ref,
+ maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1,
+ }
+
+ now := time.Now()
+
+ // The time an address is preferred until is needed to properly generate the
+ // address.
+ if pl < header.NDPInfiniteLifetime {
+ state.preferredUntil = now.Add(pl)
+ }
+
+ if !ndp.generateSLAACAddr(prefix, &state) {
+ // We were unable to generate an address for the prefix, so we do nothing
+ // further as there is no reason to maintain state or timers for a prefix we
+ // do not have an address for.
+ return
}
// Setup the initial timers to deprecate and invalidate prefix.
- if !deprecated && pl < header.NDPInfiniteLifetime {
+ if pl < header.NDPInfiniteLifetime && pl != 0 {
state.deprecationTimer.Reset(pl)
}
if vl < header.NDPInfiniteLifetime {
state.invalidationTimer.Reset(vl)
- state.validUntil = time.Now().Add(vl)
+ state.validUntil = now.Add(vl)
}
ndp.slaacPrefixes[prefix] = state
}
-// addSLAACAddr adds a SLAAC address for prefix.
+// generateSLAACAddr generates a SLAAC address for prefix.
+//
+// Returns true if an address was successfully generated.
+//
+// Panics if the prefix is not a SLAAC prefix or it already has an address.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referencedNetworkEndpoint {
+func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool {
+ if r := state.ref; r != nil {
+ panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix()))
+ }
+
+ // If we have already reached the maximum address generation attempts for the
+ // prefix, do not generate another address.
+ if state.generationAttempts == state.maxGenerationAttempts {
+ return false
+ }
+
addrBytes := []byte(prefix.ID())
if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
addrBytes = header.AppendOpaqueInterfaceIdentifier(
addrBytes[:header.IIDOffsetInIPv6Address],
prefix,
oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name),
- 0, /* dadCounter */
+ state.generationAttempts,
oIID.SecretKey,
)
- } else {
+ } else if state.generationAttempts == 0 {
// Only attempt to generate an interface-specific IID if we have a valid
// link address.
//
@@ -996,12 +1062,16 @@ func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referen
// LinkEndpoint.LinkAddress) before reaching this point.
linkAddr := ndp.nic.linkEP.LinkAddress()
if !header.IsValidUnicastEthernetAddress(linkAddr) {
- return nil
+ return false
}
// Generate an address within prefix from the modified EUI-64 of ndp's NIC's
// Ethernet MAC address.
header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ } else {
+ // We have no way to regenerate an address when addresses are not generated
+ // with opaque IIDs.
+ return false
}
generatedAddr := tcpip.ProtocolAddress{
@@ -1014,26 +1084,52 @@ func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referen
// If the nic already has this address, do nothing further.
if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) {
- return nil
+ return false
}
// Inform the integrator that we have a new SLAAC address.
ndpDisp := ndp.nic.stack.ndpDisp
if ndpDisp == nil {
- return nil
+ return false
}
if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) {
// Informed by the integrator not to add the address.
- return nil
+ return false
}
+ deprecated := time.Since(state.preferredUntil) >= 0
ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
if err != nil {
panic(fmt.Sprintf("ndp: error when adding address %+v: %s", generatedAddr, err))
}
- return ref
+ state.generationAttempts++
+ state.ref = ref
+ return true
+}
+
+// regenerateSLAACAddr regenerates an address for a SLAAC prefix.
+//
+// If generating a new address for the prefix fails, the prefix will be
+// invalidated.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix))
+ }
+
+ if ndp.generateSLAACAddr(prefix, &state) {
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ // We were unable to generate a permanent address for the SLAAC prefix so
+ // invalidate the prefix as there is no reason to maintain state for a
+ // SLAAC prefix we do not have an address for.
+ ndp.invalidateSLAACPrefix(prefix, state)
}
// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix.
@@ -1060,9 +1156,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl tim
// deprecation timer so it can be reset.
prefixState.deprecationTimer.StopLocked()
+ now := time.Now()
+
// Reset the deprecation timer if prefix has a finite preferred lifetime.
- if !deprecated && pl < header.NDPInfiniteLifetime {
- prefixState.deprecationTimer.Reset(pl)
+ if pl < header.NDPInfiniteLifetime {
+ if !deprecated {
+ prefixState.deprecationTimer.Reset(pl)
+ }
+ prefixState.preferredUntil = now.Add(pl)
+ } else {
+ prefixState.preferredUntil = time.Time{}
}
// As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
@@ -1105,7 +1208,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl tim
prefixState.invalidationTimer.StopLocked()
prefixState.invalidationTimer.Reset(effectiveVl)
- prefixState.validUntil = time.Now().Add(effectiveVl)
+ prefixState.validUntil = now.Add(effectiveVl)
}
// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP
@@ -1121,48 +1224,60 @@ func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) {
ref.deprecated = true
if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{
- Address: ref.ep.ID().LocalAddress,
- PrefixLen: ref.ep.PrefixLen(),
- })
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix())
}
}
// invalidateSLAACPrefix invalidates a SLAAC prefix.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, removeAddr bool) {
- state, ok := ndp.slaacPrefixes[prefix]
- if !ok {
- return
+func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) {
+ if r := state.ref; r != nil {
+ // Since we are already invalidating the prefix, do not invalidate the
+ // prefix when removing the address.
+ if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACPrefixInvalidation */); err != nil {
+ panic(fmt.Sprintf("ndp: removePermanentIPv6EndpointLocked(%s, false): %s", r.addrWithPrefix(), err))
+ }
}
- state.deprecationTimer.StopLocked()
- state.invalidationTimer.StopLocked()
- delete(ndp.slaacPrefixes, prefix)
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+}
- addr := state.ref.ep.ID().LocalAddress
+// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's
+// resources.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ }
- if removeAddr {
- if err := ndp.nic.removePermanentAddressLocked(addr); err != nil {
- panic(fmt.Sprintf("ndp: removePermanentAddressLocked(%s): %s", addr, err))
- }
+ prefix := addr.Subnet()
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok || state.ref == nil || addr.Address != state.ref.ep.ID().LocalAddress {
+ return
}
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: state.ref.ep.PrefixLen(),
- })
+ if !invalidatePrefix {
+ // If the prefix is not being invalidated, disassociate the address from the
+ // prefix and do nothing further.
+ state.ref = nil
+ ndp.slaacPrefixes[prefix] = state
+ return
}
+
+ ndp.cleanupSLAACPrefixResources(prefix, state)
}
-// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC
-// address's resources from ndp.
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's timers and entry.
+//
+// Panics if the SLAAC prefix is not known.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix) {
- ndp.invalidateSLAACPrefix(addr.Subnet(), false)
+func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) {
+ state.deprecationTimer.StopLocked()
+ state.invalidationTimer.StopLocked()
+ delete(ndp.slaacPrefixes, prefix)
}
// cleanupState cleans up ndp's state.
@@ -1181,7 +1296,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
func (ndp *ndpState) cleanupState(hostOnly bool) {
linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
linkLocalPrefixes := 0
- for prefix := range ndp.slaacPrefixes {
+ for prefix, state := range ndp.slaacPrefixes {
// RFC 4862 section 5 states that routers are also expected to generate a
// link-local address so we do not invalidate them if we are cleaning up
// host-only state.
@@ -1190,7 +1305,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
continue
}
- ndp.invalidateSLAACPrefix(prefix, true)
+ ndp.invalidateSLAACPrefix(prefix, state)
}
if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 06edd05b6..6dd460984 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -133,6 +133,12 @@ type ndpRDNSSEvent struct {
rdnss ndpRDNSS
}
+type ndpDNSSLEvent struct {
+ nicID tcpip.NICID
+ domainNames []string
+ lifetime time.Duration
+}
+
type ndpDHCPv6Event struct {
nicID tcpip.NICID
configuration stack.DHCPv6ConfigurationFromNDPRA
@@ -150,6 +156,8 @@ type ndpDispatcher struct {
rememberPrefix bool
autoGenAddrC chan ndpAutoGenAddrEvent
rdnssC chan ndpRDNSSEvent
+ dnsslC chan ndpDNSSLEvent
+ routeTable []tcpip.Route
dhcpv6ConfigurationC chan ndpDHCPv6Event
}
@@ -257,6 +265,17 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc
}
}
+// Implements stack.NDPDispatcher.OnDNSSearchListOption.
+func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) {
+ if n.dnsslC != nil {
+ n.dnsslC <- ndpDNSSLEvent{
+ nicID,
+ domainNames,
+ lifetime,
+ }
+ }
+}
+
// Implements stack.NDPDispatcher.OnDHCPv6Configuration.
func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) {
if c := n.dhcpv6ConfigurationC; c != nil {
@@ -406,8 +425,7 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
}
- // Address should not be considered bound to the NIC yet
- // (DAD ongoing).
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
@@ -416,10 +434,9 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
- // Wait for the remaining time - some delta (500ms), to
- // make sure the address is still not resolved.
- const delta = 500 * time.Millisecond
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
+ // Make sure the address does not resolve before the resolution time has
+ // passed.
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout)
addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
@@ -430,13 +447,7 @@ func TestDADResolve(t *testing.T) {
// Wait for DAD to resolve.
select {
- case <-time.After(2 * delta):
- // We should get a resolution event after 500ms
- // (delta) since we wait for 500ms less than the
- // expected resolution time above to make sure
- // that the address did not yet resolve. Waiting
- // for 1s (2x delta) without a resolution event
- // means something is wrong.
+ case <-time.After(2 * defaultAsyncEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
@@ -476,7 +487,7 @@ func TestDADResolve(t *testing.T) {
// As per RFC 4861 section 4.3, a possible option is the Source Link
// Layer option, but this option MUST NOT be included when the source
// address of the packet is the unspecified address.
- checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(),
+ checker.IPv6(t, p.Pkt.Header.View(),
checker.SrcAddr(header.IPv6Any),
checker.DstAddr(snmc),
checker.TTL(header.NDPHopLimit),
@@ -631,6 +642,12 @@ func TestDADFail(t *testing.T) {
if want := (tcpip.AddressWithPrefix{}); addr != want {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
+
+ // Attempting to add the address again should not fail if the address's
+ // state was cleaned up when DAD failed.
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
})
}
}
@@ -1034,8 +1051,6 @@ func TestNoRouterDiscovery(t *testing.T) {
forwarding := i&4 == 0
t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
}
@@ -1074,8 +1089,6 @@ func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) str
// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
// remember a discovered router when the dispatcher asks it not to.
func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
}
@@ -1116,8 +1129,6 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
}
func TestRouterDiscovery(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
rememberRouter: true,
@@ -1219,8 +1230,6 @@ func TestRouterDiscovery(t *testing.T) {
// TestRouterDiscoveryMaxRouters tests that only
// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
func TestRouterDiscoveryMaxRouters(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
rememberRouter: true,
@@ -1287,8 +1296,6 @@ func TestNoPrefixDiscovery(t *testing.T) {
forwarding := i&4 == 0
t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
prefixC: make(chan ndpPrefixEvent, 1),
}
@@ -1328,8 +1335,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st
// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
// remember a discovered on-link prefix when the dispatcher asks it not to.
func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
- t.Parallel()
-
prefix, subnet, _ := prefixSubnetAddr(0, "")
ndpDisp := ndpDispatcher{
@@ -1373,8 +1378,6 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
}
func TestPrefixDiscovery(t *testing.T) {
- t.Parallel()
-
prefix1, subnet1, _ := prefixSubnetAddr(0, "")
prefix2, subnet2, _ := prefixSubnetAddr(1, "")
prefix3, subnet3, _ := prefixSubnetAddr(2, "")
@@ -1563,8 +1566,6 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// TestPrefixDiscoveryMaxRouters tests that only
// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
rememberPrefix: true,
@@ -1659,8 +1660,6 @@ func TestNoAutoGenAddr(t *testing.T) {
forwarding := i&4 == 0
t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
@@ -1985,7 +1984,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// addr2 is deprecated but if explicitly requested, it should be used.
fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address)
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
}
// Another PI w/ 0 preferred lifetime should not result in a deprecation
@@ -1998,7 +1997,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
}
expectPrimaryAddr(addr1)
if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address)
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
}
// Refresh lifetimes of addr generated from prefix2.
@@ -2110,7 +2109,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// addr1 is deprecated but if explicitly requested, it should be used.
fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
}
// Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
@@ -2123,7 +2122,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
}
expectPrimaryAddr(addr2)
if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
}
// Refresh lifetimes for addr of prefix1.
@@ -2147,7 +2146,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// addr2 should be the primary endpoint now since it is not deprecated.
expectPrimaryAddr(addr2)
if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
}
// Wait for addr of prefix1 to be invalidated.
@@ -2410,8 +2409,6 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
},
}
- const delta = 500 * time.Millisecond
-
// This Run will not return until the parallel tests finish.
//
// We need this because we need to do some teardown work after the
@@ -2464,24 +2461,21 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
// to test.evl.
//
- // Make sure we do not get any invalidation
- // events until atleast 500ms (delta) before
- // test.evl.
+ // The address should not be invalidated until the effective valid
+ // lifetime has passed.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(time.Duration(test.evl)*time.Second - delta):
+ case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncEventTimeout):
}
- // Wait for another second (2x delta), but now
- // we expect the invalidation event.
+ // Wait for the invalidation event.
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
-
- case <-time.After(2 * delta):
+ case <-time.After(2 * defaultAsyncEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -2493,8 +2487,6 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
// by the user, its resources will be cleaned up and an invalidation event will
// be sent to the integrator.
func TestAutoGenAddrRemoval(t *testing.T) {
- t.Parallel()
-
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
ndpDisp := ndpDispatcher{
@@ -2551,8 +2543,6 @@ func TestAutoGenAddrRemoval(t *testing.T) {
// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously
// assigned to the NIC but is in the permanentExpired state.
func TestAutoGenAddrAfterRemoval(t *testing.T) {
- t.Parallel()
-
const nicID = 1
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
@@ -2599,7 +2589,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
AddressWithPrefix: addr2,
}
if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
}
// addr2 should be more preferred now since it is at the front of the primary
// list.
@@ -2664,8 +2654,6 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
// is already assigned to the NIC, the static address remains.
func TestAutoGenAddrStaticConflict(t *testing.T) {
- t.Parallel()
-
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
ndpDisp := ndpDispatcher{
@@ -2721,8 +2709,6 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use
// opaque interface identifiers when configured to do so.
func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
- t.Parallel()
-
const nicID = 1
const nicName = "nic1"
var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
@@ -2822,12 +2808,465 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
}
}
+// TestAutoGenAddrWithOpaqueIIDDADRetries tests the regeneration of an
+// auto-generated IPv6 address in response to a DAD conflict.
+func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) {
+ const nicID = 1
+ const nicName = "nic"
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const maxMaxRetries = 3
+ const lifetimeSeconds = 10
+
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ {
+ for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ {
+ addrTypes := []struct {
+ name string
+ ndpConfigs stack.NDPConfigurations
+ autoGenLinkLocal bool
+ subnet tcpip.Subnet
+ triggerSLAACFn func(e *channel.Endpoint)
+ }{
+ {
+ name: "Global address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ subnet: subnet,
+ triggerSLAACFn: func(e *channel.Endpoint) {
+ // Receive an RA with prefix in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+
+ },
+ },
+ {
+ name: "LinkLocal address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ autoGenLinkLocal: true,
+ subnet: header.IPv6LinkLocalPrefix.Subnet(),
+ triggerSLAACFn: func(e *channel.Endpoint) {},
+ },
+ }
+
+ for _, addrType := range addrTypes {
+ maxRetries := maxRetries
+ numFailures := numFailures
+ addrType := addrType
+
+ t.Run(fmt.Sprintf("%s with %d max retries and %d failures", addrType.name, maxRetries, numFailures), func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ addrType.triggerSLAACFn(e)
+
+ // Simulate DAD conflicts so the address is regenerated.
+ for i := uint8(0); i < numFailures; i++ {
+ addrBytes := []byte(addrType.subnet.ID())
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, i, secretKey)),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Should not have any addresses assigned to the NIC.
+ mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); mainAddr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want)
+ }
+
+ // Simulate a DAD conflict.
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+
+ // Attempting to add the address manually should not fail if the
+ // address's state was cleaned up when DAD failed.
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
+ }
+ if err := s.RemoveAddress(nicID, addr.Address); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
+ }
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+ }
+
+ // Should not have any addresses assigned to the NIC.
+ mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); mainAddr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want)
+ }
+
+ // If we had fewer failures than generation attempts, we should have an
+ // address after DAD resolves.
+ if maxRetries+1 > numFailures {
+ addrBytes := []byte(addrType.subnet.ID())
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, numFailures, secretKey)),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+
+ mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+ if mainAddr != addr {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, addr)
+ }
+ }
+
+ // Should not attempt address regeneration again.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
+ case <-time.After(defaultAsyncEventTimeout):
+ }
+ })
+ }
+ }
+ }
+}
+
+// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is
+// not made for SLAAC addresses generated with an IID based on the NIC's link
+// address.
+func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
+ const nicID = 1
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const maxRetries = 3
+ const lifetimeSeconds = 10
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ addrTypes := []struct {
+ name string
+ ndpConfigs stack.NDPConfigurations
+ autoGenLinkLocal bool
+ subnet tcpip.Subnet
+ triggerSLAACFn func(e *channel.Endpoint)
+ }{
+ {
+ name: "Global address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ subnet: subnet,
+ triggerSLAACFn: func(e *channel.Endpoint) {
+ // Receive an RA with prefix in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+
+ },
+ },
+ {
+ name: "LinkLocal address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ autoGenLinkLocal: true,
+ subnet: header.IPv6LinkLocalPrefix.Subnet(),
+ triggerSLAACFn: func(e *channel.Endpoint) {},
+ },
+ }
+
+ for _, addrType := range addrTypes {
+ addrType := addrType
+
+ t.Run(addrType.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ addrType.triggerSLAACFn(e)
+
+ addrBytes := []byte(addrType.subnet.ID())
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:])
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Simulate a DAD conflict.
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+
+ // Should not attempt address regeneration.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
+ case <-time.After(defaultAsyncEventTimeout):
+ }
+ })
+ }
+}
+
+// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address
+// generation in response to DAD conflicts does not refresh the lifetimes.
+func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
+ const nicID = 1
+ const nicName = "nic"
+ const dadTransmits = 1
+ const retransmitTimer = 2 * time.Second
+ const failureTimer = time.Second
+ const maxRetries = 1
+ const lifetimeSeconds = 5
+
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+
+ addrBytes := []byte(subnet.ID())
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Simulate a DAD conflict after some time has passed.
+ time.Sleep(failureTimer)
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+
+ // Let the next address resolve.
+ addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey))
+ expectAutoGenAddrEvent(addr, newAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+
+ // Address should be deprecated/invalidated after the lifetime expires.
+ //
+ // Note, the remaining lifetime is calculated from when the PI was first
+ // processed. Since we wait for some time before simulating a DAD conflict
+ // and more time for the new address to resolve, the new address is only
+ // expected to be valid for the remaining time. The DAD conflict should
+ // not have reset the lifetimes.
+ //
+ // We expect either just the invalidation event or the deprecation event
+ // followed by the invalidation event.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if e.eventType == deprecatedAddr {
+ if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation")
+ }
+ } else {
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for auto gen addr event")
+ }
+}
+
// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
// to the integrator when an RA is received with the NDP Recursive DNS Server
// option with at least one valid address.
func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
- t.Parallel()
-
tests := []struct {
name string
opt header.NDPRecursiveDNSServer
@@ -2919,11 +3358,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
for _, test := range tests {
- test := test
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
// We do not expect more than a single RDNSS
// event at any time for this test.
@@ -2970,11 +3405,115 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
}
+// TestNDPDNSSearchListDispatch tests that the integrator is informed when an
+// NDP DNS Search List option is received with at least one domain name in the
+// search list.
+func TestNDPDNSSearchListDispatch(t *testing.T) {
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dnsslC: make(chan ndpDNSSLEvent, 3),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ optSer := header.NDPOptionsSerializer{
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 2, 'h', 'i',
+ 0,
+ }),
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 1, 'i',
+ 0,
+ 2, 'a', 'm',
+ 2, 'm', 'e',
+ 0,
+ }),
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 1, 0,
+ 3, 'x', 'y', 'z',
+ 0,
+ 5, 'h', 'e', 'l', 'l', 'o',
+ 5, 'w', 'o', 'r', 'l', 'd',
+ 0,
+ 4, 't', 'h', 'i', 's',
+ 2, 'i', 's',
+ 1, 'a',
+ 4, 't', 'e', 's', 't',
+ 0,
+ }),
+ }
+ expected := []struct {
+ domainNames []string
+ lifetime time.Duration
+ }{
+ {
+ domainNames: []string{
+ "hi",
+ },
+ lifetime: 0,
+ },
+ {
+ domainNames: []string{
+ "i",
+ "am.me",
+ },
+ lifetime: time.Second,
+ },
+ {
+ domainNames: []string{
+ "xyz",
+ "hello.world",
+ "this.is.a.test",
+ },
+ lifetime: 256 * time.Second,
+ },
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
+
+ for i, expected := range expected {
+ select {
+ case dnssl := <-ndpDisp.dnsslC:
+ if dnssl.nicID != nicID {
+ t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID)
+ }
+ if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" {
+ t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff)
+ }
+ if dnssl.lifetime != expected.lifetime {
+ t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime)
+ }
+ default:
+ t.Fatal("expected a DNSSL event")
+ }
+ }
+
+ // Should have no more DNSSL options.
+ select {
+ case <-ndpDisp.dnsslC:
+ t.Fatal("unexpectedly got a DNSSL event")
+ default:
+ }
+}
+
// TestCleanupNDPState tests that all discovered routers and prefixes, and
// auto-generated addresses are invalidated when a NIC becomes a router.
func TestCleanupNDPState(t *testing.T) {
- t.Parallel()
-
const (
lifetimeSeconds = 5
maxRouterAndPrefixEvents = 4
@@ -3417,8 +3956,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// TestRouterSolicitation tests the initial Router Solicitations that are sent
// when a NIC newly becomes enabled.
func TestRouterSolicitation(t *testing.T) {
- t.Parallel()
-
const nicID = 1
tests := []struct {
@@ -3435,13 +3972,22 @@ func TestRouterSolicitation(t *testing.T) {
effectiveMaxRtrSolicitDelay time.Duration
}{
{
- name: "Single RS with delay",
+ name: "Single RS with 2s delay and interval",
expectedSrcAddr: header.IPv6Any,
maxRtrSolicit: 1,
- rtrSolicitInt: time.Second,
- effectiveRtrSolicitInt: time.Second,
- maxRtrSolicitDelay: time.Second,
- effectiveMaxRtrSolicitDelay: time.Second,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
+ maxRtrSolicitDelay: 2 * time.Second,
+ effectiveMaxRtrSolicitDelay: 2 * time.Second,
+ },
+ {
+ name: "Single RS with 4s delay and interval",
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 1,
+ rtrSolicitInt: 4 * time.Second,
+ effectiveRtrSolicitInt: 4 * time.Second,
+ maxRtrSolicitDelay: 4 * time.Second,
+ effectiveMaxRtrSolicitDelay: 4 * time.Second,
},
{
name: "Two RS with delay",
@@ -3449,8 +3995,8 @@ func TestRouterSolicitation(t *testing.T) {
nicAddr: llAddr1,
expectedSrcAddr: llAddr1,
maxRtrSolicit: 2,
- rtrSolicitInt: time.Second,
- effectiveRtrSolicitInt: time.Second,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
maxRtrSolicitDelay: 500 * time.Millisecond,
effectiveMaxRtrSolicitDelay: 500 * time.Millisecond,
},
@@ -3464,8 +4010,8 @@ func TestRouterSolicitation(t *testing.T) {
header.NDPSourceLinkLayerAddressOption(linkAddr1),
},
maxRtrSolicit: 1,
- rtrSolicitInt: time.Second,
- effectiveRtrSolicitInt: time.Second,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
maxRtrSolicitDelay: 0,
effectiveMaxRtrSolicitDelay: 0,
},
@@ -3515,6 +4061,7 @@ func TestRouterSolicitation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
+
e := channelLinkWithHeaderLength{
Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
headerLength: test.linkHeaderLen,
@@ -3522,7 +4069,8 @@ func TestRouterSolicitation(t *testing.T) {
e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
waitForPkt := func(timeout time.Duration) {
t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), timeout)
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
p, ok := e.ReadContext(ctx)
if !ok {
t.Fatal("timed out waiting for packet")
@@ -3552,7 +4100,8 @@ func TestRouterSolicitation(t *testing.T) {
}
waitForNothing := func(timeout time.Duration) {
t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), timeout)
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet")
}
@@ -3583,15 +4132,19 @@ func TestRouterSolicitation(t *testing.T) {
}
for ; remaining > 0; remaining-- {
- waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout)
- waitForPkt(defaultAsyncEventTimeout)
+ if test.effectiveRtrSolicitInt > defaultAsyncEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncEventTimeout)
+ waitForPkt(2 * defaultAsyncEventTimeout)
+ } else {
+ waitForPkt(test.effectiveRtrSolicitInt * defaultAsyncEventTimeout)
+ }
}
// Make sure no more RS.
if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultTimeout)
+ waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncEventTimeout)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultTimeout)
+ waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
}
// Make sure the counter got properly
@@ -3605,11 +4158,9 @@ func TestRouterSolicitation(t *testing.T) {
}
func TestStopStartSolicitingRouters(t *testing.T) {
- t.Parallel()
-
const nicID = 1
+ const delay = 0
const interval = 500 * time.Millisecond
- const delay = time.Second
const maxRtrSolicitations = 3
tests := []struct {
@@ -3684,7 +4235,6 @@ func TestStopStartSolicitingRouters(t *testing.T) {
p, ok := e.ReadContext(ctx)
if !ok {
t.Fatal("timed out waiting for packet")
- return
}
if p.Proto != header.IPv6ProtocolNumber {
@@ -3710,11 +4260,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stop soliciting routers.
test.stopFn(t, s, true /* first */)
- ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
- // A single RS may have been sent before forwarding was enabled.
- ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout)
+ // A single RS may have been sent before solicitations were stopped.
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
defer cancel()
if _, ok = e.ReadContext(ctx); ok {
t.Fatal("should not have sent more than one RS message")
@@ -3724,7 +4274,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stopping router solicitations after it has already been stopped should
// do nothing.
test.stopFn(t, s, false /* first */)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
@@ -3740,7 +4290,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
waitForPkt(delay + defaultAsyncEventTimeout)
waitForPkt(interval + defaultAsyncEventTimeout)
waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
@@ -3749,7 +4299,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Starting router solicitations after it has already completed should do
// nothing.
test.startFn(t, s)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after finishing router solicitations")
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 4835251bc..016dbe15e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1012,29 +1012,31 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadLocalAddress
}
- isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr)
+ switch r.protocol {
+ case header.IPv6ProtocolNumber:
+ return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACPrefixInvalidation */)
+ default:
+ r.expireLocked()
+ return nil
+ }
+}
+
+func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACPrefixInvalidation bool) *tcpip.Error {
+ addr := r.addrWithPrefix()
+
+ isIPv6Unicast := header.IsV6UnicastAddress(addr.Address)
if isIPv6Unicast {
- // If we are removing a tentative IPv6 unicast address, stop DAD.
- if kind == permanentTentative {
- n.mu.ndp.stopDuplicateAddressDetection(addr)
- }
+ n.mu.ndp.stopDuplicateAddressDetection(addr.Address)
// If we are removing an address generated via SLAAC, cleanup
// its SLAAC resources and notify the integrator.
if r.configType == slaac {
- n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: r.ep.PrefixLen(),
- })
+ n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACPrefixInvalidation)
}
}
- r.setKind(permanentExpired)
- if !r.decRefLocked() {
- // The endpoint still has references to it.
- return nil
- }
+ r.expireLocked()
// At this point the endpoint is deleted.
@@ -1044,7 +1046,7 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
// We ignore the tcpip.ErrBadLocalAddress error because the solicited-node
// multicast group may be left by user action.
if isIPv6Unicast {
- snmc := header.SolicitedNodeAddr(addr)
+ snmc := header.SolicitedNodeAddr(addr.Address)
if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
@@ -1425,10 +1427,12 @@ func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
return ref.getKind() == permanentTentative
}
-// dupTentativeAddrDetected attempts to inform n that a tentative addr
-// is a duplicate on a link.
+// dupTentativeAddrDetected attempts to inform n that a tentative addr is a
+// duplicate on a link.
//
-// dupTentativeAddrDetected will delete the tentative address if it exists.
+// dupTentativeAddrDetected will remove the tentative address if it exists. If
+// the address was generated via SLAAC, an attempt will be made to generate a
+// new address.
func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -1442,7 +1446,17 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- return n.removePermanentAddressLocked(addr)
+ // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a
+ // new address will be generated for it.
+ if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACPrefixInvalidation */); err != nil {
+ return err
+ }
+
+ if ref.configType == slaac {
+ n.mu.ndp.regenerateSLAACAddr(ref.addrWithPrefix().Subnet())
+ }
+
+ return nil
}
// setNDPConfigs sets the NDP configurations for n.
@@ -1570,6 +1584,13 @@ type referencedNetworkEndpoint struct {
deprecated bool
}
+func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix {
+ return tcpip.AddressWithPrefix{
+ Address: r.ep.ID().LocalAddress,
+ PrefixLen: r.ep.PrefixLen(),
+ }
+}
+
func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind)))
}
@@ -1597,6 +1618,13 @@ func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing)
}
+// expireLocked decrements the reference count and marks the permanent endpoint
+// as expired.
+func (r *referencedNetworkEndpoint) expireLocked() {
+ r.setKind(permanentExpired)
+ r.decRefLocked()
+}
+
// decRef decrements the ref count and cleans up the endpoint once it reaches
// zero.
func (r *referencedNetworkEndpoint) decRef() {
@@ -1606,14 +1634,11 @@ func (r *referencedNetworkEndpoint) decRef() {
}
// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
-// locked. Returns true if the endpoint was removed.
-func (r *referencedNetworkEndpoint) decRefLocked() bool {
+// locked.
+func (r *referencedNetworkEndpoint) decRefLocked() {
if atomic.AddInt32(&r.refs, -1) == 0 {
r.nic.removeEndpointLocked(r)
- return true
}
-
- return false
}
// incRef increments the ref count. It must only be called when the caller is
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 9367de180..dc125f25e 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -23,9 +23,11 @@ import (
// As a PacketBuffer traverses up the stack, it may be necessary to pass it to
// multiple endpoints. Clone() should be called in such cases so that
// modifications to the Data field do not affect other copies.
-//
-// +stateify savable
type PacketBuffer struct {
+ // PacketBufferEntry is used to build an intrusive list of
+ // PacketBuffers.
+ PacketBufferEntry
+
// Data holds the payload of the packet. For inbound packets, it also
// holds the headers, which are consumed as the packet moves up the
// stack. Headers are guaranteed not to be split across views.
@@ -34,14 +36,6 @@ type PacketBuffer struct {
// or otherwise modified.
Data buffer.VectorisedView
- // DataOffset is used for GSO output. It is the offset into the Data
- // field where the payload of this packet starts.
- DataOffset int
-
- // DataSize is used for GSO output. It is the size of this packet's
- // payload.
- DataSize int
-
// Header holds the headers of outbound packets. As a packet is passed
// down the stack, each layer adds to Header.
Header buffer.Prependable
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index ac043b722..23ca9ee03 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -246,7 +246,7 @@ type NetworkEndpoint interface {
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length.
- WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
@@ -393,7 +393,7 @@ type LinkEndpoint interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, it may
// require to change syscall filters.
- WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
// WriteRawPacket writes a packet directly to the link. The packet
// should already have an ethernet header.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 9fbe8a411..a0e5e0300 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -168,23 +168,26 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuff
return err
}
-// WritePackets writes the set of packets through the given route.
-func (r *Route) WritePackets(gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
+// WritePackets writes a list of n packets through the given route and returns
+// the number of packets written.
+func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
if !r.ref.isValidForOutgoing() {
return 0, tcpip.ErrInvalidEndpointState
}
n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
}
r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
- payloadSize := 0
- for i := 0; i < n; i++ {
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength()))
- payloadSize += pkts[i].DataSize
+
+ writtenBytes := 0
+ for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
+ writtenBytes += pb.Header.UsedLength()
+ writtenBytes += pb.Data.Size()
}
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize))
+
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
return n, err
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 555fcd92f..c7634ceb1 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -153,7 +153,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
panic("not implemented")
}
@@ -1445,19 +1445,19 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}}
if err := s.AddProtocolAddress(1, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err)
+ t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
- t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
// If the NIC doesn't exist, it won't work.
if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
}
}
@@ -1483,12 +1483,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
}
nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr}
if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err)
+ t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
}
nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr}
if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
- t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err)
+ t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err)
}
// Set the initial route table.
@@ -1503,10 +1503,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// When an interface is given, the route for a broadcast goes through it.
r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
- t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
// When an interface is not given, it consults the route table.
@@ -2399,7 +2399,7 @@ func TestNICContextPreservation(t *testing.T) {
t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos)
}
if got, want := nicinfo.Context == test.want, true; got != want {
- t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
+ t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
}
})
}
@@ -2768,7 +2768,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
if err != nil {
- t.Fatalf("NewSubnet failed:", err)
+ t.Fatalf("NewSubnet failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
@@ -2782,11 +2782,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// permanentExpired kind.
r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false)
if err != nil {
- t.Fatal("FindRoute failed:", err)
+ t.Fatalf("FindRoute failed: %v", err)
}
defer r.Release()
if err := s.RemoveAddress(1, "\x01"); err != nil {
- t.Fatalf("RemoveAddress failed:", err)
+ t.Fatalf("RemoveAddress failed: %v", err)
}
//
@@ -2798,7 +2798,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// Add some other address with peb set to
// FirstPrimaryEndpoint.
if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
+ t.Fatalf("AddAddressWithOptions failed: %v", err)
}
@@ -2806,7 +2806,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// make sure the new peb was respected.
// (The address should just be promoted now).
if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
+ t.Fatalf("AddAddressWithOptions failed: %v", err)
}
var primaryAddrs []tcpip.Address
for _, pa := range s.NICInfo()[1].ProtocolAddresses {
@@ -2839,11 +2839,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// GetMainNICAddress; else, our original address
// should be returned.
if err := s.RemoveAddress(1, "\x03"); err != nil {
- t.Fatalf("RemoveAddress failed:", err)
+ t.Fatalf("RemoveAddress failed: %v", err)
}
addr, err = s.GetMainNICAddress(1, fakeNetNumber)
if err != nil {
- t.Fatal("s.GetMainNICAddress failed:", err)
+ t.Fatalf("s.GetMainNICAddress failed: %v", err)
}
if ps == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
@@ -3176,8 +3176,6 @@ func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
// was disabled have DAD performed on them when the NIC is enabled.
func TestDoDADWhenNICEnabled(t *testing.T) {
- t.Parallel()
-
const dadTransmits = 1
const retransmitTimer = time.Second
const nicID = 1
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index c65b0c632..2474a7db3 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -206,7 +206,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
// the distribution of packets received matches expectations.
func TestBindToDeviceDistribution(t *testing.T) {
type endpointSockopts struct {
- reuse int
+ reuse bool
bindToDevice tcpip.NICID
}
for _, test := range []struct {
@@ -221,11 +221,11 @@ func TestBindToDeviceDistribution(t *testing.T) {
"BindPortReuse",
// 5 endpoints that all have reuse set.
[]endpointSockopts{
- {reuse: 1, bindToDevice: 0},
- {reuse: 1, bindToDevice: 0},
- {reuse: 1, bindToDevice: 0},
- {reuse: 1, bindToDevice: 0},
- {reuse: 1, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
@@ -236,9 +236,9 @@ func TestBindToDeviceDistribution(t *testing.T) {
"BindToDevice",
// 3 endpoints with various bindings.
[]endpointSockopts{
- {reuse: 0, bindToDevice: 1},
- {reuse: 0, bindToDevice: 2},
- {reuse: 0, bindToDevice: 3},
+ {reuse: false, bindToDevice: 1},
+ {reuse: false, bindToDevice: 2},
+ {reuse: false, bindToDevice: 3},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
@@ -253,12 +253,12 @@ func TestBindToDeviceDistribution(t *testing.T) {
"ReuseAndBindToDevice",
// 6 endpoints with various bindings.
[]endpointSockopts{
- {reuse: 1, bindToDevice: 1},
- {reuse: 1, bindToDevice: 1},
- {reuse: 1, bindToDevice: 2},
- {reuse: 1, bindToDevice: 2},
- {reuse: 1, bindToDevice: 2},
- {reuse: 1, bindToDevice: 0},
+ {reuse: true, bindToDevice: 1},
+ {reuse: true, bindToDevice: 1},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 0},
},
map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
@@ -309,9 +309,8 @@ func TestBindToDeviceDistribution(t *testing.T) {
}(ep)
defer ep.Close()
- reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
- if err := ep.SetSockOpt(reusePortOption); err != nil {
- t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", reusePortOption, i, err)
+ if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil {
+ t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
}
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 2ef3271f1..1ca4088c9 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -520,34 +520,90 @@ type WriteOptions struct {
type SockOptBool int
const (
+ // BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether
+ // datagram sockets are allowed to send packets to a broadcast address.
+ BroadcastOption SockOptBool = iota
+
+ // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
+ // held until segments are full by the TCP transport protocol.
+ CorkOption
+
+ // DelayOption is used by SetSockOpt/GetSockOpt to specify if data
+ // should be sent out immediately by the transport protocol. For TCP,
+ // it determines if the Nagle algorithm is on or off.
+ DelayOption
+
+ // KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
+ // TCP keepalive is enabled for this socket.
+ KeepaliveEnabledOption
+
+ // MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether
+ // multicast packets sent over a non-loopback interface will be looped back.
+ MulticastLoopOption
+
+ // PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
+ // SCM_CREDENTIALS socket control messages are enabled.
+ //
+ // Only supported on Unix sockets.
+ PasscredOption
+
+ // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
+ QuickAckOption
+
// ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the
// IPV6_TCLASS ancillary message is passed with incoming packets.
- ReceiveTClassOption SockOptBool = iota
+ ReceiveTClassOption
// ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS
// ancillary message is passed with incoming packets.
ReceiveTOSOption
- // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
- // socket is to be restricted to sending and receiving IPv6 packets only.
- V6OnlyOption
-
// ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify
// if more information is provided with incoming packets such
// as interface index and address.
ReceiveIPPacketInfoOption
- // TODO(b/146901447): convert existing bool socket options to be handled via
- // Get/SetSockOptBool
+ // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
+ // should allow reuse of local address.
+ ReuseAddressOption
+
+ // ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets
+ // to be bound to an identical socket address.
+ ReusePortOption
+
+ // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
+ // socket is to be restricted to sending and receiving IPv6 packets only.
+ V6OnlyOption
)
// SockOptInt represents socket options which values have the int type.
type SockOptInt int
const (
+ // KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number
+ // of un-ACKed TCP keepalives that will be sent before the connection is
+ // closed.
+ KeepaliveCountOption SockOptInt = iota
+
+ // IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS
+ // for all subsequent outgoing IPv4 packets from the endpoint.
+ IPv4TOSOption
+
+ // IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS
+ // for all subsequent outgoing IPv6 packets from the endpoint.
+ IPv6TrafficClassOption
+
+ // MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current
+ // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option.
+ MaxSegOption
+
+ // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
+ // TTL value for multicast messages. The default is 1.
+ MulticastTTLOption
+
// ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
// number of unread bytes in the input buffer should be returned.
- ReceiveQueueSizeOption SockOptInt = iota
+ ReceiveQueueSizeOption
// SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
// specify the send buffer size option.
@@ -561,44 +617,21 @@ const (
// number of unread bytes in the output buffer should be returned.
SendQueueSizeOption
- // DelayOption is used by SetSockOpt/GetSockOpt to specify if data
- // should be sent out immediately by the transport protocol. For TCP,
- // it determines if the Nagle algorithm is on or off.
- DelayOption
-
- // TODO(b/137664753): convert all int socket options to be handled via
- // GetSockOptInt.
+ // TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop
+ // limit value for unicast messages. The default is protocol specific.
+ //
+ // A zero value indicates the default.
+ TTLOption
)
// ErrorOption is used in GetSockOpt to specify that the last error reported by
// the endpoint should be cleared and returned.
type ErrorOption struct{}
-// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
-// held until segments are full by the TCP transport protocol.
-type CorkOption int
-
-// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
-// should allow reuse of local address.
-type ReuseAddressOption int
-
-// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets
-// to be bound to an identical socket address.
-type ReusePortOption int
-
// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
// should bind only on a specific NIC.
type BindToDeviceOption NICID
-// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
-type QuickAckOption int
-
-// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
-// SCM_CREDENTIALS socket control messages are enabled.
-//
-// Only supported on Unix sockets.
-type PasscredOption int
-
// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
//
// TODO(b/64800844): Add and populate stat fields.
@@ -607,10 +640,6 @@ type TCPInfoOption struct {
RTTVar time.Duration
}
-// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
-// TCP keepalive is enabled for this socket.
-type KeepaliveEnabledOption int
-
// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a
// connection must remain idle before the first TCP keepalive packet is sent.
// Once this time is reached, KeepaliveIntervalOption is used instead.
@@ -620,11 +649,6 @@ type KeepaliveIdleOption time.Duration
// interval between sending TCP keepalive packets.
type KeepaliveIntervalOption time.Duration
-// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number
-// of un-ACKed TCP keepalives that will be sent before the connection is
-// closed.
-type KeepaliveCountOption int
-
// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user
// specified timeout for a given TCP connection.
// See: RFC5482 for details.
@@ -638,20 +662,9 @@ type CongestionControlOption string
// control algorithms.
type AvailableCongestionControlOption string
-// ModerateReceiveBufferOption allows the caller to enable/disable TCP receive
// buffer moderation.
type ModerateReceiveBufferOption bool
-// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current
-// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option.
-type MaxSegOption int
-
-// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop
-// limit value for unicast messages. The default is protocol specific.
-//
-// A zero value indicates the default.
-type TTLOption uint8
-
// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state
// before being marked closed.
@@ -668,9 +681,14 @@ type TCPTimeWaitTimeoutOption time.Duration
// for a handshake till the specified timeout until a segment with data arrives.
type TCPDeferAcceptOption time.Duration
-// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
-// TTL value for multicast messages. The default is 1.
-type MulticastTTLOption uint8
+// TCPMinRTOOption is used by SetSockOpt/GetSockOpt to allow overriding
+// default MinRTO used by the Stack.
+type TCPMinRTOOption time.Duration
+
+// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
+// the number of endpoints that can be in SYN-RCVD state before the stack
+// switches to using SYN cookies.
+type TCPSynRcvdCountThresholdOption uint64
// MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a
// default interface for multicast.
@@ -679,10 +697,6 @@ type MulticastInterfaceOption struct {
InterfaceAddr Address
}
-// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether
-// multicast packets sent over a non-loopback interface will be looped back.
-type MulticastLoopOption bool
-
// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
// AddMembershipOption and RemoveMembershipOption.
type MembershipOption struct {
@@ -705,22 +719,10 @@ type RemoveMembershipOption MembershipOption
// TCP out-of-band data is delivered along with the normal in-band data.
type OutOfBandInlineOption int
-// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether
-// datagram sockets are allowed to send packets to a broadcast address.
-type BroadcastOption int
-
// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify
// a default TTL.
type DefaultTTLOption uint8
-// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS
-// for all subsequent outgoing IPv4 packets from the endpoint.
-type IPv4TOSOption uint8
-
-// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS
-// for all subsequent outgoing IPv6 packets from the endpoint.
-type IPv6TrafficClassOption uint8
-
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index 8c0aacffa..1c8e2bc34 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -218,7 +218,7 @@ func TestAddressWithPrefixSubnet(t *testing.T) {
gotSubnet := ap.Subnet()
wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask)
if err != nil {
- t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err)
+ t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err)
continue
}
if gotSubnet != wantSubnet {
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index b007302fb..feef8dca0 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -348,13 +348,6 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(o)
- e.mu.Unlock()
- }
-
return nil
}
@@ -365,12 +358,25 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt sets a socket option. Only tcpip.TTLOption is handled; other options are ignored.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
+ }
return nil
}
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -397,26 +403,23 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
+ case tcpip.TTLOption:
+ e.rcvMu.Lock()
+ v := int(e.ttl)
+ e.rcvMu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
}
- return -1, tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
+ switch opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
- case *tcpip.TTLOption:
- e.rcvMu.Lock()
- *o = tcpip.TTLOption(e.ttl)
- e.rcvMu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 337bc1c71..eee754a5a 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -533,14 +533,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
+ switch opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -548,7 +544,13 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -576,9 +578,9 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
}
-
- return -1, tcpip.ErrUnknownProtocolOption
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 7f94f9646..61426623c 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -87,7 +87,9 @@ go_test(
"tcp_timestamp_test.go",
],
# FIXME(b/68809571)
- tags = ["flaky"],
+ tags = [
+ "flaky",
+ ],
deps = [
":tcp",
"//pkg/sync",
@@ -104,5 +106,16 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp/testing/context",
"//pkg/waiter",
+ "//runsc/testutil",
+ ],
+)
+
+go_test(
+ name = "rcv_test",
+ size = "small",
+ srcs = ["rcv_test.go"],
+ deps = [
+ ":tcp",
+ "//pkg/tcpip/seqnum",
],
)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 7a9dea4ac..e6a23c978 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -17,6 +17,7 @@ package tcp
import (
"crypto/sha1"
"encoding/binary"
+ "fmt"
"hash"
"io"
"time"
@@ -25,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -49,17 +49,14 @@ const (
// timestamp and the current timestamp. If the difference is greater
// than maxTSDiff, the cookie is expired.
maxTSDiff = 2
-)
-var (
- // SynRcvdCountThreshold is the global maximum number of connections
- // that are allowed to be in SYN-RCVD state before TCP starts using SYN
- // cookies to accept connections.
- //
- // It is an exported variable only for testing, and should not otherwise
- // be used by importers of this package.
+ // SynRcvdCountThreshold is the default global maximum number of
+ // connections that are allowed to be in SYN-RCVD state before TCP
+ // starts using SYN cookies to accept connections.
SynRcvdCountThreshold uint64 = 1000
+)
+var (
// mssTable is a slice containing the possible MSS values that we
// encode in the SYN cookie with two bits.
mssTable = []uint16{536, 1300, 1440, 1460}
@@ -74,29 +71,42 @@ func encodeMSS(mss uint16) uint32 {
return 0
}
-// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
-// protected by a mutex so that we can increment only when it's guaranteed not
-// to go above a threshold.
-var synRcvdCount struct {
- sync.Mutex
- value uint64
- pending sync.WaitGroup
-}
-
// listenContext is used by a listening endpoint to store state used while
// listening for connections. This struct is allocated by the listen goroutine
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
- stack *stack.Stack
- rcvWnd seqnum.Size
- nonce [2][sha1.BlockSize]byte
+ stack *stack.Stack
+
+ // synRcvdCount is a reference to the stack level synRcvdCount.
+ synRcvdCount *synRcvdCounter
+
+ // rcvWnd is the receive window that is sent by this listening context
+ // in the initial SYN-ACK.
+ rcvWnd seqnum.Size
+
+ // nonce are random bytes that are initialized once when the context
+ // is created and used to seed the hash function when generating
+ // the SYN cookie.
+ nonce [2][sha1.BlockSize]byte
+
+ // listenEP is a reference to the listening endpoint associated with
+ // this context. Can be nil if the context is created by the forwarder.
listenEP *endpoint
+ // hasherMu protects hasher.
hasherMu sync.Mutex
- hasher hash.Hash
- v6only bool
+ // hasher is the hash function used to generate a SYN cookie.
+ hasher hash.Hash
+
+ // v6Only is true if listenEP is a dual stack socket and has the
+ // IPV6_V6ONLY option set.
+ v6Only bool
+
+ // netProto indicates the network protocol(IPv4/v6) for the listening
+ // endpoint.
netProto tcpip.NetworkProtocolNumber
+
// pendingMu protects pendingEndpoints. This should only be accessed
// by the listening endpoint's worker goroutine.
//
@@ -115,55 +125,22 @@ func timeStamp() uint32 {
return uint32(time.Now().Unix()>>6) & tsMask
}
-// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD
-// state. It succeeds if the increment doesn't make the count go beyond the
-// threshold, and fails otherwise.
-func incSynRcvdCount() bool {
- synRcvdCount.Lock()
-
- if synRcvdCount.value >= SynRcvdCountThreshold {
- synRcvdCount.Unlock()
- return false
- }
-
- synRcvdCount.pending.Add(1)
- synRcvdCount.value++
-
- synRcvdCount.Unlock()
- return true
-}
-
-// decSynRcvdCount atomically decrements the global number of endpoints in
-// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount
-// succeeded.
-func decSynRcvdCount() {
- synRcvdCount.Lock()
-
- synRcvdCount.value--
- synRcvdCount.pending.Done()
- synRcvdCount.Unlock()
-}
-
-// synCookiesInUse() returns true if the synRcvdCount is greater than
-// SynRcvdCountThreshold.
-func synCookiesInUse() bool {
- synRcvdCount.Lock()
- v := synRcvdCount.value
- synRcvdCount.Unlock()
- return v >= SynRcvdCountThreshold
-}
-
// newListenContext creates a new listen context.
-func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
stack: stk,
rcvWnd: rcvWnd,
hasher: sha1.New(),
- v6only: v6only,
+ v6Only: v6Only,
netProto: netProto,
listenEP: listenEP,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
+ p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol)
+ if !ok {
+ panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk))
+ }
+ l.synRcvdCount = p.SynRcvdCounter()
rand.Read(l.nonce[0][:])
rand.Read(l.nonce[1][:])
@@ -230,7 +207,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
netProto = s.route.NetProto
}
n := newEndpoint(l.stack, netProto, queue)
- n.v6only = l.v6only
+ n.v6only = l.v6Only
n.ID = s.id
n.boundNICID = s.route.NICID()
n.route = s.route.Clone()
@@ -316,7 +293,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
}
// Perform the 3-way handshake.
- h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
+ h := newPassiveHandshake(ep, ep.rcv.rcvWnd, isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.mu.Unlock()
ep.Close()
@@ -330,6 +307,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
+
+ ep.drainClosingSegmentQueue()
+
return nil, err
}
ep.isConnectNotified = true
@@ -378,7 +358,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
for {
if e.acceptedChan == nil {
e.acceptMu.Unlock()
- n.Close()
+ n.notifyProtocolGoroutine(notifyReset)
return
}
select {
@@ -407,7 +387,7 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// A limited number of these goroutines are allowed before TCP starts using SYN
// cookies to accept connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
- defer decSynRcvdCount()
+ defer ctx.synRcvdCount.dec()
defer func() {
e.mu.Lock()
e.decSynRcvdCount()
@@ -452,19 +432,16 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
- if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) {
+ e.rcvListMu.Lock()
+ rcvClosed := e.rcvClosed
+ e.rcvListMu.Unlock()
+ if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) {
+ // If the endpoint is shutdown, reply with reset.
+ //
// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
// must be sent in response to a SYN-ACK while in the listen
// state to prevent completing a handshake from an old SYN.
- e.sendTCP(&s.route, tcpFields{
- id: s.id,
- ttl: e.ttl,
- tos: e.sendTOS,
- flags: header.TCPFlagRst,
- seq: s.ackNumber,
- ack: 0,
- rcvWnd: 0,
- }, buffer.VectorisedView{}, nil)
+ replyWithReset(s, e.sendTOS, e.ttl)
return
}
@@ -474,7 +451,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
switch {
case s.flags == header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
- if incSynRcvdCount() {
+ if ctx.synRcvdCount.inc() {
// Only handle the syn if the following conditions hold
// - accept queue is not full.
// - number of connections in synRcvd state is less than the
@@ -484,7 +461,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
return
}
- decSynRcvdCount()
+ ctx.synRcvdCount.dec()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -537,7 +514,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
- if !synCookiesInUse() {
+ if !ctx.synRcvdCount.synCookiesInUse() {
// When not using SYN cookies, as per RFC 793, section 3.9, page 64:
// Any acknowledgment is bad if it arrives on a connection still in
// the LISTEN state. An acceptable reset segment should be formed
@@ -553,7 +530,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// The only time we should reach here when a connection
// was opened and closed really quickly and a delayed
// ACK was received from the sender.
- replyWithReset(s)
+ replyWithReset(s, e.sendTOS, e.ttl)
return
}
@@ -636,8 +613,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
- v6only := e.v6only
- ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto)
+ v6Only := e.v6only
+ ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
@@ -656,6 +633,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
}
e.mu.Unlock()
+ e.drainClosingSegmentQueue()
+
// Notify waiters that the endpoint is shutdown.
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 3239a5911..76e27bf26 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -105,24 +105,11 @@ type handshake struct {
}
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
- rcvWndScale := ep.rcvWndScaleForHandshake()
-
- // Round-down the rcvWnd to a multiple of wndScale. This ensures that the
- // window offered in SYN won't be reduced due to the loss of precision if
- // window scaling is enabled after the handshake.
- rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale)
-
- // Ensure we can always accept at least 1 byte if the scale specified
- // was too high for the provided rcvWnd.
- if rcvWnd == 0 {
- rcvWnd = 1
- }
-
h := handshake{
ep: ep,
active: true,
rcvWnd: rcvWnd,
- rcvWndScale: int(rcvWndScale),
+ rcvWndScale: ep.rcvWndScaleForHandshake(),
}
h.resetState()
return h
@@ -756,8 +743,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV
func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
optLen := len(tf.opts)
hdr := &pkt.Header
- packetSize := pkt.DataSize
- off := pkt.DataOffset
+ packetSize := pkt.Data.Size()
// Initialize the header.
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
pkt.TransportHeader = buffer.View(tcp)
@@ -782,12 +768,18 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
} else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize)
+ xsum = header.ChecksumVV(pkt.Data, xsum)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
}
func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+ // We need to shallow clone the VectorisedView here as ReadToView will
+ // split the VectorisedView and Trim underlying views as it splits. Not
+ // doing the clone here will cause the underlying views of data itself
+ // to be altered.
+ data = data.Clone(nil)
+
optLen := len(tf.opts)
if tf.rcvWnd > 0xffff {
tf.rcvWnd = 0xffff
@@ -796,31 +788,25 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
mss := int(gso.MSS)
n := (data.Size() + mss - 1) / mss
- // Allocate one big slice for all the headers.
- hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen
- buf := make([]byte, n*hdrSize)
- pkts := make([]stack.PacketBuffer, n)
- for i := range pkts {
- pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize])
- }
-
size := data.Size()
- off := 0
+ hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen
+ var pkts stack.PacketBufferList
for i := 0; i < n; i++ {
packetSize := mss
if packetSize > size {
packetSize = size
}
size -= packetSize
- pkts[i].DataOffset = off
- pkts[i].DataSize = packetSize
- pkts[i].Data = data
- pkts[i].Hash = tf.txHash
- pkts[i].Owner = owner
- buildTCPHdr(r, tf, &pkts[i], gso)
- off += packetSize
+ var pkt stack.PacketBuffer
+ pkt.Header = buffer.NewPrependable(hdrSize)
+ pkt.Hash = tf.txHash
+ pkt.Owner = owner
+ data.ReadToVV(&pkt.Data, packetSize)
+ buildTCPHdr(r, tf, &pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
+ pkts.PushBack(&pkt)
}
+
if tf.ttl == 0 {
tf.ttl = r.DefaultTTL()
}
@@ -845,12 +831,10 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac
}
pkt := stack.PacketBuffer{
- Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
- DataOffset: 0,
- DataSize: data.Size(),
- Data: data,
- Hash: tf.txHash,
- Owner: owner,
+ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
+ Data: data,
+ Hash: tf.txHash,
+ Owner: owner,
}
buildTCPHdr(r, tf, &pkt, gso)
@@ -1056,15 +1040,34 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
}
if ep == nil {
- replyWithReset(s)
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
s.decRef()
return
}
+
+ if e == ep {
+ panic("current endpoint not removed from demuxer, enqueing segments to itself")
+ }
+
if ep.(*endpoint).enqueueSegment(s) {
ep.(*endpoint).newSegmentWaker.Assert()
}
}
+// drainClosingSegmentQueue drains the endpoint's segment queue and tries to
+// re-match each dequeued segment to a different endpoint. This is used when the
+// current endpoint has transitioned to StateClose and has been unregistered
+// from the transport demuxer.
+func (e *endpoint) drainClosingSegmentQueue() {
+ for {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ break
+ }
+
+ e.tryDeliverSegmentFromClosedEndpoint(s)
+ }
+}
+
func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
if e.rcv.acceptable(s.sequenceNumber, 0) {
// RFC 793, page 37 states that "in all states
@@ -1318,6 +1321,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
e.mu.Unlock()
+
+ e.drainClosingSegmentQueue()
+
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -1568,19 +1574,6 @@ loop:
// Lock released below.
epilogue()
- // epilogue removes the endpoint from the transport-demuxer and
- // unlocks e.mu. Now that no new segments can get enqueued to this
- // endpoint, try to re-match the segment to a different endpoint
- // as the current endpoint is closed.
- for {
- s := e.segmentQueue.dequeue()
- if s == nil {
- break
- }
-
- e.tryDeliverSegmentFromClosedEndpoint(s)
- }
-
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
if reuseTW != nil {
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 4f361b226..804e95aea 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -568,11 +568,10 @@ func TestV4AcceptOnV4(t *testing.T) {
func testV4ListenClose(t *testing.T, c *context.Context) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 0
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err)
+ }
+
const n = uint16(32)
// Start listening.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 9b123e968..45f2aa78b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -821,7 +821,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
var de DelayEnabled
if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
- e.SetSockOptInt(tcpip.DelayOption, 1)
+ e.SetSockOptBool(tcpip.DelayOption, true)
}
var tcpLT tcpip.TCPLingerTimeoutOption
@@ -980,25 +980,22 @@ func (e *endpoint) closeNoShutdownLocked() {
// Mark endpoint as closed.
e.closed = true
+
+ switch e.EndpointState() {
+ case StateClose, StateError:
+ return
+ }
+
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
- switch e.EndpointState() {
- // Sockets in StateSynRecv state(passive connections) are closed when
- // the handshake fails or if the listening socket is closed while
- // handshake was in progress. In such cases the handshake goroutine
- // is already gone by the time Close is called and we need to cleanup
- // here.
- case StateInitial, StateBound, StateSynRecv:
- e.cleanupLocked()
- e.setEndpointState(StateClose)
- case StateError, StateClose:
- // do nothing.
- default:
+ if e.workerRunning {
e.workerCleanup = true
tcpip.AddDanglingEndpoint(e)
// Worker will remove the dangling endpoint when the endpoint
// goroutine terminates.
e.notifyProtocolGoroutine(notifyClose)
+ } else {
+ e.transitionToStateCloseLocked()
}
}
@@ -1010,13 +1007,18 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Unlock()
return
}
-
close(e.acceptedChan)
+ ch := e.acceptedChan
e.acceptedChan = nil
e.acceptCond.Broadcast()
e.acceptMu.Unlock()
- // Wait for all pending endpoints to close.
+ // Reset all connections that are waiting to be accepted.
+ for n := range ch {
+ n.notifyProtocolGoroutine(notifyReset)
+ }
+ // Wait for reset of all endpoints that are still waiting to be delivered to
+ // the now closed acceptedChan.
e.pendingAccepted.Wait()
}
@@ -1060,6 +1062,19 @@ func (e *endpoint) initialReceiveWindow() int {
if rcvWnd > routeWnd {
rcvWnd = routeWnd
}
+ rcvWndScale := e.rcvWndScaleForHandshake()
+
+ // Round-down the rcvWnd to a multiple of wndScale. This ensures that the
+ // window offered in SYN won't be reduced due to the loss of precision if
+ // window scaling is enabled after the handshake.
+ rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale)
+
+ // Ensure we can always accept at least 1 byte if the scale specified
+ // was too high for the provided rcvWnd.
+ if rcvWnd == 0 {
+ rcvWnd = 1
+ }
+
return rcvWnd
}
@@ -1409,10 +1424,58 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo
// SetSockOptBool sets a socket option.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- e.LockUser()
- defer e.UnlockUser()
-
switch opt {
+
+ case tcpip.BroadcastOption:
+ e.LockUser()
+ e.broadcast = v
+ e.UnlockUser()
+
+ case tcpip.CorkOption:
+ e.LockUser()
+ if !v {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ e.UnlockUser()
+
+ case tcpip.DelayOption:
+ if v {
+ atomic.StoreUint32(&e.delay, 1)
+ } else {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ }
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ e.keepalive.enabled = v
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.QuickAckOption:
+ o := uint32(1)
+ if v {
+ o = 0
+ }
+ atomic.StoreUint32(&e.slowAck, o)
+
+ case tcpip.ReuseAddressOption:
+ e.LockUser()
+ e.reuseAddr = v
+ e.UnlockUser()
+
+ case tcpip.ReusePortOption:
+ e.LockUser()
+ e.reusePort = v
+ e.UnlockUser()
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -1424,7 +1487,9 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+ e.LockUser()
e.v6only = v
+ e.UnlockUser()
}
return nil
@@ -1432,18 +1497,50 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt sets a socket option.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ // Lower 2 bits represents ECN bits. RFC 3168, section 23.1
+ const inetECNMask = 3
+
switch opt {
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ e.keepalive.count = v
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.IPv4TOSOption:
+ e.LockUser()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.UnlockUser()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.LockUser()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.UnlockUser()
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.LockUser()
+ e.userMSS = uint16(userMSS)
+ e.UnlockUser()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
var rs ReceiveBufferSizeOption
- size := int(v)
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
- if size < rs.Min {
- size = rs.Min
+ if v < rs.Min {
+ v = rs.Min
}
- if size > rs.Max {
- size = rs.Max
+ if v > rs.Max {
+ v = rs.Max
}
}
@@ -1458,17 +1555,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
if e.rcv != nil {
scale = e.rcv.rcvWndScale
}
- if size>>scale == 0 {
- size = 1 << scale
+ if v>>scale == 0 {
+ v = 1 << scale
}
// Make sure 2*size doesn't overflow.
- if size > math.MaxInt32/2 {
- size = math.MaxInt32 / 2
+ if v > math.MaxInt32/2 {
+ v = math.MaxInt32 / 2
}
availBefore := e.receiveBufferAvailableLocked()
- e.rcvBufSize = size
+ e.rcvBufSize = v
availAfter := e.receiveBufferAvailableLocked()
e.rcvAutoParams.disabled = true
@@ -1483,71 +1580,36 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.rcvListMu.Unlock()
e.UnlockUser()
e.notifyProtocolGoroutine(mask)
- return nil
case tcpip.SendBufferSizeOption:
// Make sure the send buffer size is within the min and max
// allowed.
- size := int(v)
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
- if size < ss.Min {
- size = ss.Min
+ if v < ss.Min {
+ v = ss.Min
}
- if size > ss.Max {
- size = ss.Max
+ if v > ss.Max {
+ v = ss.Max
}
}
e.sndBufMu.Lock()
- e.sndBufSize = size
+ e.sndBufSize = v
e.sndBufMu.Unlock()
- return nil
-
- case tcpip.DelayOption:
- if v == 0 {
- atomic.StoreUint32(&e.delay, 0)
- // Handle delayed data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.delay, 1)
- }
- return nil
+ case tcpip.TTLOption:
+ e.LockUser()
+ e.ttl = uint8(v)
+ e.UnlockUser()
- default:
- return nil
}
+ return nil
}
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- // Lower 2 bits represents ECN bits. RFC 3168, section 23.1
- const inetECNMask = 3
switch v := opt.(type) {
- case tcpip.CorkOption:
- if v == 0 {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- return nil
-
- case tcpip.ReuseAddressOption:
- e.LockUser()
- e.reuseAddr = v != 0
- e.UnlockUser()
- return nil
-
- case tcpip.ReusePortOption:
- e.LockUser()
- e.reusePort = v != 0
- e.UnlockUser()
- return nil
-
case tcpip.BindToDeviceOption:
id := tcpip.NICID(v)
if id != 0 && !e.stack.HasNIC(id) {
@@ -1556,72 +1618,26 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.LockUser()
e.bindToDevice = id
e.UnlockUser()
- return nil
-
- case tcpip.QuickAckOption:
- if v == 0 {
- atomic.StoreUint32(&e.slowAck, 1)
- } else {
- atomic.StoreUint32(&e.slowAck, 0)
- }
- return nil
-
- case tcpip.MaxSegOption:
- userMSS := v
- if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
- }
- e.LockUser()
- e.userMSS = uint16(userMSS)
- e.UnlockUser()
- e.notifyProtocolGoroutine(notifyMSSChanged)
- return nil
-
- case tcpip.TTLOption:
- e.LockUser()
- e.ttl = uint8(v)
- e.UnlockUser()
- return nil
-
- case tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- e.keepalive.enabled = v != 0
- e.keepalive.Unlock()
- e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
case tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
e.keepalive.idle = time.Duration(v)
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
case tcpip.KeepaliveIntervalOption:
e.keepalive.Lock()
e.keepalive.interval = time.Duration(v)
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
- case tcpip.KeepaliveCountOption:
- e.keepalive.Lock()
- e.keepalive.count = int(v)
- e.keepalive.Unlock()
- e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
+ case tcpip.OutOfBandInlineOption:
+ // We don't currently support disabling this option.
case tcpip.TCPUserTimeoutOption:
e.LockUser()
e.userTimeout = time.Duration(v)
e.UnlockUser()
- return nil
-
- case tcpip.BroadcastOption:
- e.LockUser()
- e.broadcast = v != 0
- e.UnlockUser()
- return nil
case tcpip.CongestionControlOption:
// Query the available cc algorithms in the stack and
@@ -1652,22 +1668,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// control algorithm is specified.
return tcpip.ErrNoSuchFile
- case tcpip.IPv4TOSOption:
- e.LockUser()
- // TODO(gvisor.dev/issue/995): ECN is not currently supported,
- // ignore the bits for now.
- e.sendTOS = uint8(v) & ^uint8(inetECNMask)
- e.UnlockUser()
- return nil
-
- case tcpip.IPv6TrafficClassOption:
- e.LockUser()
- // TODO(gvisor.dev/issue/995): ECN is not currently supported,
- // ignore the bits for now.
- e.sendTOS = uint8(v) & ^uint8(inetECNMask)
- e.UnlockUser()
- return nil
-
case tcpip.TCPLingerTimeoutOption:
e.LockUser()
if v < 0 {
@@ -1688,7 +1688,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
e.tcpLingerTimeout = time.Duration(v)
e.UnlockUser()
- return nil
case tcpip.TCPDeferAcceptOption:
e.LockUser()
@@ -1697,11 +1696,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
e.deferAccept = time.Duration(v)
e.UnlockUser()
- return nil
default:
return nil
}
+ return nil
}
// readyReceiveSize returns the number of bytes ready to be received.
@@ -1723,6 +1722,43 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
+ case tcpip.BroadcastOption:
+ e.LockUser()
+ v := e.broadcast
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.CorkOption:
+ return atomic.LoadUint32(&e.cork) != 0, nil
+
+ case tcpip.DelayOption:
+ return atomic.LoadUint32(&e.delay) != 0, nil
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ v := e.keepalive.enabled
+ e.keepalive.Unlock()
+
+ return v, nil
+
+ case tcpip.QuickAckOption:
+ v := atomic.LoadUint32(&e.slowAck) == 0
+ return v, nil
+
+ case tcpip.ReuseAddressOption:
+ e.LockUser()
+ v := e.reuseAddr
+ e.UnlockUser()
+
+ return v, nil
+
+ case tcpip.ReusePortOption:
+ e.LockUser()
+ v := e.reusePort
+ e.UnlockUser()
+
+ return v, nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -1734,14 +1770,41 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.UnlockUser()
return v, nil
- }
- return false, tcpip.ErrUnknownProtocolOption
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ v := e.keepalive.count
+ e.keepalive.Unlock()
+ return v, nil
+
+ case tcpip.IPv4TOSOption:
+ e.LockUser()
+ v := int(e.sendTOS)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.LockUser()
+ v := int(e.sendTOS)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.MaxSegOption:
+ // This is just stubbed out. Linux never returns the user_mss
+ // value as it either returns the defaultMSS or returns the
+ // actual current MSS. Netstack just returns the defaultMSS
+ // always for now.
+ v := header.TCPDefaultMSS
+ return v, nil
+
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1757,12 +1820,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.rcvListMu.Unlock()
return v, nil
- case tcpip.DelayOption:
- var o int
- if v := atomic.LoadUint32(&e.delay); v != 0 {
- o = 1
- }
- return o, nil
+ case tcpip.TTLOption:
+ e.LockUser()
+ v := int(e.ttl)
+ e.UnlockUser()
+ return v, nil
default:
return -1, tcpip.ErrUnknownProtocolOption
@@ -1779,61 +1841,10 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.lastErrorMu.Unlock()
return err
- case *tcpip.MaxSegOption:
- // This is just stubbed out. Linux never returns the user_mss
- // value as it either returns the defaultMSS or returns the
- // actual current MSS. Netstack just returns the defaultMSS
- // always for now.
- *o = header.TCPDefaultMSS
- return nil
-
- case *tcpip.CorkOption:
- *o = 0
- if v := atomic.LoadUint32(&e.cork); v != 0 {
- *o = 1
- }
- return nil
-
- case *tcpip.ReuseAddressOption:
- e.LockUser()
- v := e.reuseAddr
- e.UnlockUser()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
- case *tcpip.ReusePortOption:
- e.LockUser()
- v := e.reusePort
- e.UnlockUser()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
case *tcpip.BindToDeviceOption:
e.LockUser()
*o = tcpip.BindToDeviceOption(e.bindToDevice)
e.UnlockUser()
- return nil
-
- case *tcpip.QuickAckOption:
- *o = 1
- if v := atomic.LoadUint32(&e.slowAck); v != 0 {
- *o = 0
- }
- return nil
-
- case *tcpip.TTLOption:
- e.LockUser()
- *o = tcpip.TTLOption(e.ttl)
- e.UnlockUser()
- return nil
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
@@ -1846,92 +1857,45 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
o.RTTVar = snd.rtt.rttvar
snd.rtt.Unlock()
}
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- v := e.keepalive.enabled
- e.keepalive.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
*o = tcpip.KeepaliveIdleOption(e.keepalive.idle)
e.keepalive.Unlock()
- return nil
case *tcpip.KeepaliveIntervalOption:
e.keepalive.Lock()
*o = tcpip.KeepaliveIntervalOption(e.keepalive.interval)
e.keepalive.Unlock()
- return nil
-
- case *tcpip.KeepaliveCountOption:
- e.keepalive.Lock()
- *o = tcpip.KeepaliveCountOption(e.keepalive.count)
- e.keepalive.Unlock()
- return nil
case *tcpip.TCPUserTimeoutOption:
e.LockUser()
*o = tcpip.TCPUserTimeoutOption(e.userTimeout)
e.UnlockUser()
- return nil
case *tcpip.OutOfBandInlineOption:
// We don't currently support disabling this option.
*o = 1
- return nil
-
- case *tcpip.BroadcastOption:
- e.LockUser()
- v := e.broadcast
- e.UnlockUser()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
case *tcpip.CongestionControlOption:
e.LockUser()
*o = e.cc
e.UnlockUser()
- return nil
-
- case *tcpip.IPv4TOSOption:
- e.LockUser()
- *o = tcpip.IPv4TOSOption(e.sendTOS)
- e.UnlockUser()
- return nil
-
- case *tcpip.IPv6TrafficClassOption:
- e.LockUser()
- *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
- e.UnlockUser()
- return nil
case *tcpip.TCPLingerTimeoutOption:
e.LockUser()
*o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout)
e.UnlockUser()
- return nil
case *tcpip.TCPDeferAcceptOption:
e.LockUser()
*o = tcpip.TCPDeferAcceptOption(e.deferAccept)
e.UnlockUser()
- return nil
default:
return tcpip.ErrUnknownProtocolOption
}
+ return nil
}
// checkV4MappedLocked determines the effective network protocol and converts
@@ -2146,7 +2110,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
switch {
case e.EndpointState().connected():
// Close for read.
- if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+ if e.shutdownFlags&tcpip.ShutdownRead != 0 {
// Mark read side as closed.
e.rcvListMu.Lock()
e.rcvClosed = true
@@ -2155,7 +2119,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
- if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 {
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
// Wake up worker to terminate loop.
e.notifyProtocolGoroutine(notifyTickleWorker)
@@ -2164,7 +2128,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Close for write.
- if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
e.sndBufMu.Lock()
if e.sndClosed {
// Already closed.
@@ -2187,12 +2151,23 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
return nil
case e.EndpointState() == StateListen:
- // Tell protocolListenLoop to stop.
- if flags&tcpip.ShutdownRead != 0 {
- e.notifyProtocolGoroutine(notifyClose)
+ if e.shutdownFlags&tcpip.ShutdownRead != 0 {
+ // Reset all connections from the accept queue and keep the
+ // worker running so that it can continue handling incoming
+ // segments by replying with RST.
+ //
+ // By not removing this endpoint from the demuxer mapping, we
+ // ensure that any other bind to the same port fails, as on Linux.
+ // TODO(gvisor.dev/issue/2468): We need to enable applications to
+ // start listening on this endpoint again similar to Linux.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ e.rcvListMu.Unlock()
+ e.closePendingAcceptableConnectionsLocked()
+ // Notify waiters that the endpoint is shutdown.
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}
return nil
-
default:
return tcpip.ErrNotConnected
}
@@ -2296,8 +2271,11 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
+ e.rcvListMu.Lock()
+ rcvClosed := e.rcvClosed
+ e.rcvListMu.Unlock()
// Endpoint must be in listen state before it can accept connections.
- if e.EndpointState() != StateListen {
+ if rcvClosed || e.EndpointState() != StateListen {
return nil, nil, tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 808410c92..704d01c64 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -130,7 +130,7 @@ func (r *ForwarderRequest) Complete(sendReset bool) {
// If the caller requested, send a reset.
if sendReset {
- replyWithReset(r.segment)
+ replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL())
}
// Release all resources.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index dce9a1652..cfd9a4e8e 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -94,6 +94,63 @@ const (
ccCubic = "cubic"
)
+// synRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
+// value is protected by a mutex so that we can increment only when it's
+// guaranteed not to go above a threshold.
+type synRcvdCounter struct {
+ sync.Mutex
+ value uint64
+ pending sync.WaitGroup
+ threshold uint64
+}
+
+// inc tries to increment the global number of endpoints in SYN-RCVD state. It
+// succeeds if the increment doesn't make the count go beyond the threshold, and
+// fails otherwise.
+func (s *synRcvdCounter) inc() bool {
+ s.Lock()
+ defer s.Unlock()
+ if s.value >= s.threshold {
+ return false
+ }
+
+ s.pending.Add(1)
+ s.value++
+
+ return true
+}
+
+// dec atomically decrements the global number of endpoints in SYN-RCVD
+// state. It must only be called if a previous call to inc succeeded.
+func (s *synRcvdCounter) dec() {
+ s.Lock()
+ defer s.Unlock()
+ s.value--
+ s.pending.Done()
+}
+
+// synCookiesInUse returns true if the current SYN-RCVD count is greater than
+// or equal to the counter's threshold.
+func (s *synRcvdCounter) synCookiesInUse() bool {
+ s.Lock()
+ defer s.Unlock()
+ return s.value >= s.threshold
+}
+
+// SetThreshold sets synRcvdCounter.threshold to the new threshold.
+func (s *synRcvdCounter) SetThreshold(threshold uint64) {
+ s.Lock()
+ defer s.Unlock()
+ s.threshold = threshold
+}
+
+// Threshold returns the current value of synRcvdCounter.threshold.
+func (s *synRcvdCounter) Threshold() uint64 {
+ s.Lock()
+ defer s.Unlock()
+ return s.threshold
+}
+
type protocol struct {
mu sync.RWMutex
sackEnabled bool
@@ -105,6 +162,8 @@ type protocol struct {
moderateReceiveBuffer bool
tcpLingerTimeout time.Duration
tcpTimeWaitTimeout time.Duration
+ minRTO time.Duration
+ synRcvdCount synRcvdCounter
dispatcher *dispatcher
}
@@ -164,12 +223,12 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo
return true
}
- replyWithReset(s)
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
return true
}
// replyWithReset replies to the given segment with a reset segment.
-func replyWithReset(s *segment) {
+func replyWithReset(s *segment, tos, ttl uint8) {
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
ack := seqnum.Value(0)
@@ -193,8 +252,8 @@ func replyWithReset(s *segment) {
}
sendTCP(&s.route, tcpFields{
id: s.id,
- ttl: s.route.DefaultTTL(),
- tos: stack.DefaultTOS,
+ ttl: ttl,
+ tos: tos,
flags: flags,
seq: seq,
ack: ack,
@@ -272,6 +331,21 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case tcpip.TCPMinRTOOption:
+ if v < 0 {
+ v = tcpip.TCPMinRTOOption(MinRTO)
+ }
+ p.mu.Lock()
+ p.minRTO = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.Lock()
+ p.synRcvdCount.SetThreshold(uint64(v))
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -334,6 +408,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.RUnlock()
return nil
+ case *tcpip.TCPMinRTOOption:
+ p.mu.RLock()
+ *v = tcpip.TCPMinRTOOption(p.minRTO)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.RLock()
+ *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold())
+ p.mu.RUnlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -349,6 +435,12 @@ func (p *protocol) Wait() {
p.dispatcher.wait()
}
+// SynRcvdCounter returns a reference to the synRcvdCount for this protocol
+// instance.
+func (p *protocol) SynRcvdCounter() *synRcvdCounter {
+ return &p.synRcvdCount
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{
@@ -358,6 +450,8 @@ func NewProtocol() stack.TransportProtocol {
availableCongestionControl: []string{ccReno, ccCubic},
tcpLingerTimeout: DefaultTCPLingerTimeout,
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
+ minRTO: MinRTO,
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index caf8977b3..a4b73b588 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -70,13 +70,24 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale
// acceptable checks if the segment sequence number range is acceptable
// according to the table on page 26 of RFC 793.
func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
- rcvWnd := r.rcvNxt.Size(r.rcvAcc)
- if rcvWnd == 0 {
- return segLen == 0 && segSeq == r.rcvNxt
- }
+ return Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc)
+}
- return segSeq.InWindow(r.rcvNxt, rcvWnd) ||
- seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen)
+// Acceptable checks if a segment that starts at segSeq and has length segLen is
+// "acceptable" for arriving in a receive window that starts at rcvNxt and ends
+// before rcvAcc, according to the tables on pages 26 and 69 of RFC 793.
+func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool {
+ if rcvNxt == rcvAcc {
+ return segLen == 0 && segSeq == rcvNxt
+ }
+ if segLen == 0 {
+ // The end of the window (rcvAcc) is extended by 1 because that is Linux's
+ // behavior despite the RFC.
+ return segSeq.InRange(rcvNxt, rcvAcc.Add(1))
+ }
+ // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming
+ // the payload, so we'll accept any payload that overlaps the receive window.
+ return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc)
}
// getSendParams returns the parameters needed by the sender when building
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
new file mode 100644
index 000000000..dc02729ce
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rcv_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+)
+
+func TestAcceptable(t *testing.T) {
+ for _, tt := range []struct {
+ segSeq seqnum.Value
+ segLen seqnum.Size
+ rcvNxt, rcvAcc seqnum.Value
+ want bool
+ }{
+ // The segment is smaller than the window.
+ {105, 2, 100, 104, false},
+ {105, 2, 101, 105, false},
+ {105, 2, 102, 106, true},
+ {105, 2, 103, 107, true},
+ {105, 2, 104, 108, true},
+ {105, 2, 105, 109, true},
+ {105, 2, 106, 110, true},
+ {105, 2, 107, 111, false},
+
+ // The segment is larger than the window.
+ {105, 4, 103, 105, false},
+ {105, 4, 104, 106, true},
+ {105, 4, 105, 107, true},
+ {105, 4, 106, 108, true},
+ {105, 4, 107, 109, true},
+ {105, 4, 108, 110, true},
+ {105, 4, 109, 111, false},
+ {105, 4, 110, 112, false},
+
+ // The segment has no width.
+ {105, 0, 100, 102, false},
+ {105, 0, 101, 103, false},
+ {105, 0, 102, 104, false},
+ {105, 0, 103, 105, true},
+ {105, 0, 104, 106, true},
+ {105, 0, 105, 107, true},
+ {105, 0, 106, 108, false},
+ {105, 0, 107, 109, false},
+
+ // The receive window has no width.
+ {105, 2, 103, 103, false},
+ {105, 2, 104, 104, false},
+ {105, 2, 105, 105, false},
+ {105, 2, 106, 106, false},
+ {105, 2, 107, 107, false},
+ {105, 2, 108, 108, false},
+ {105, 2, 109, 109, false},
+ } {
+ if got := tcp.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want {
+ t.Errorf("tcp.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index e6fe7985d..40461fd31 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -77,9 +77,11 @@ func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.V
id: id,
route: r.Clone(),
}
- s.views[0] = v
- s.data = buffer.NewVectorisedView(len(v), s.views[:1])
s.rcvdTime = time.Now()
+ if len(v) != 0 {
+ s.views[0] = v
+ s.data = buffer.NewVectorisedView(len(v), s.views[:1])
+ }
return s
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 6b7bac37d..d8cfe3115 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "fmt"
"math"
"sync/atomic"
"time"
@@ -149,6 +150,9 @@ type sender struct {
rtt rtt
rto time.Duration
+ // minRTO is the minimum permitted value for sender.rto.
+ minRTO time.Duration
+
// maxPayloadSize is the maximum size of the payload of a given segment.
// It is initialized on demand.
maxPayloadSize int
@@ -260,6 +264,13 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
// etc.
s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
+ // Get Stack wide minRTO.
+ var v tcpip.TCPMinRTOOption
+ if err := ep.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
+ panic(fmt.Sprintf("unable to get minRTO from stack: %s", err))
+ }
+ s.minRTO = time.Duration(v)
+
return s
}
@@ -394,8 +405,8 @@ func (s *sender) updateRTO(rtt time.Duration) {
s.rto = s.rtt.srtt + 4*s.rtt.rttvar
s.rtt.Unlock()
- if s.rto < MinRTO {
- s.rto = MinRTO
+ if s.rto < s.minRTO {
+ s.rto = s.minRTO
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 782d7b42c..359a75e73 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func TestFastRecovery(t *testing.T) {
@@ -40,7 +41,7 @@ func TestFastRecovery(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -86,16 +87,23 @@ func TestFastRecovery(t *testing.T) {
// Receive the retransmitted packet.
c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
- }
+ // Wait before checking metrics.
+ metricPollFn := func() error {
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
+ }
+ return nil
}
- if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
// Now send 7 mode duplicate acks. Each of these should cause a window
@@ -117,12 +125,18 @@ func TestFastRecovery(t *testing.T) {
// Receive the retransmit due to partial ack.
c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ // Wait before checking metrics.
+ metricPollFn = func() error {
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ }
+ return nil
}
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
- t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
// Receive the 10 extra packets that should have been released due to
@@ -192,7 +206,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -234,7 +248,7 @@ func TestCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -338,7 +352,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
@@ -447,7 +461,7 @@ func TestRetransmit(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -492,24 +506,33 @@ func TestRetransmit(t *testing.T) {
rtxOffset := bytesRead - maxPayload*expected
c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
- if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
- }
+ metricPollFn := func() error {
+ if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
+ }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
- }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ }
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
- t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
- }
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
+ return fmt.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
+ }
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ return nil
}
- if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
+ // Poll when checking metrics.
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
// Acknowledge half of the pending data.
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index afea124ec..1dd63dd61 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -149,21 +149,22 @@ func TestSackPermittedAccept(t *testing.T) {
{true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
{false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
+
for _, tc := range testCases {
t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- if tc.cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- } else {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }
for _, sackEnabled := range []bool{false, true} {
t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+
+ if tc.cookieEnabled {
+ // Set the SynRcvd threshold to
+ // zero to force a syn cookie
+ // based accept to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
setStackSACKPermitted(t, c, sackEnabled)
rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
@@ -222,21 +223,23 @@ func TestSackDisabledAccept(t *testing.T) {
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
{false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
+
for _, tc := range testCases {
t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- if tc.cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- } else {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }
for _, sackEnabled := range []bool{false, true} {
t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+
+ if tc.cookieEnabled {
+ // Set the SynRcvd threshold to
+ // zero to force a syn cookie
+ // based accept to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
+
setStackSACKPermitted(t, c, sackEnabled)
rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
@@ -387,7 +390,7 @@ func TestSACKRecovery(t *testing.T) {
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index ce3df7478..ab1014c7f 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -284,7 +284,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// are released instantly on Close.
tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
- t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err)
+ t.Fatalf("e.stack.SetTransportProtocolOption(%d, %v) = %v", tcp.ProtocolNumber, tcpTW, err)
}
c.EP.Close()
@@ -590,6 +590,10 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
),
)
+ // Give the stack a few ms to transition the endpoint out of ESTABLISHED
+ // state.
+ time.Sleep(10 * time.Millisecond)
+
if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
@@ -728,7 +732,7 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
tests := []struct {
name string
- setMSS uint16
+ setMSS int
expMSS uint16
}{
{
@@ -756,15 +760,14 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
c.Create(-1)
// Set the MSS socket option.
- opt := tcpip.MaxSegOption(test.setMSS)
- if err := c.EP.SetSockOpt(opt); err != nil {
- t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
}
// Get expected window size.
rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
}
ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
@@ -818,15 +821,14 @@ func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
c.CreateV6Endpoint(true)
// Set the MSS socket option.
- opt := tcpip.MaxSegOption(test.setMSS)
- if err := c.EP.SetSockOpt(opt); err != nil {
- t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
}
// Get expected window size.
rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
}
ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
@@ -1032,8 +1034,8 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) {
checker.SeqNum(200)))
}
-// TestListenShutdown tests for the listening endpoint not processing
-// any receive when it is on read shutdown.
+// TestListenShutdown tests for the listening endpoint replying with RST
+// on read shutdown.
func TestListenShutdown(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -1044,7 +1046,7 @@ func TestListenShutdown(t *testing.T) {
t.Fatal("Bind failed:", err)
}
- if err := c.EP.Listen(10 /* backlog */); err != nil {
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
t.Fatal("Listen failed:", err)
}
@@ -1052,9 +1054,6 @@ func TestListenShutdown(t *testing.T) {
t.Fatal("Shutdown failed:", err)
}
- // Wait for the endpoint state to be propagated.
- time.Sleep(10 * time.Millisecond)
-
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -1063,7 +1062,49 @@ func TestListenShutdown(t *testing.T) {
AckNum: 200,
})
- c.CheckNoPacket("Packet received when listening socket was shutdown")
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
+}
+
+// TestListenCloseWhileConnect tests for the listening endpoint to
+// drain the accept-queue when closed. This should reset all of the
+// pending connections that are waiting to be accepted.
+func TestListenCloseWhileConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventIn)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+ // Wait for the new endpoint created because of handshake to be delivered
+ // to the listening endpoint's accept queue.
+ <-notifyCh
+
+ // Close the listening endpoint.
+ c.EP.Close()
+
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
}
func TestTOSV4(t *testing.T) {
@@ -1077,17 +1118,17 @@ func TestTOSV4(t *testing.T) {
c.EP = ep
const tos = 0xC0
- if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
+ t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err)
}
- var v tcpip.IPv4TOSOption
- if err := c.EP.GetSockOpt(&v); err != nil {
- t.Errorf("GetSockopt failed: %s", err)
+ v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err)
}
- if want := tcpip.IPv4TOSOption(tos); v != want {
- t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if v != tos {
+ t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos)
}
testV4Connect(t, c, checker.TOS(tos, 0))
@@ -1125,17 +1166,17 @@ func TestTrafficClassV6(t *testing.T) {
c.CreateV6Endpoint(false)
const tos = 0xC0
- if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
- t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err)
+ if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil {
+ t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err)
}
- var v tcpip.IPv6TrafficClassOption
- if err := c.EP.GetSockOpt(&v); err != nil {
- t.Fatalf("GetSockopt failed: %s", err)
+ v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err)
}
- if want := tcpip.IPv6TrafficClassOption(tos); v != want {
- t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if v != tos {
+ t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos)
}
// Test the connection request.
@@ -1711,7 +1752,7 @@ func TestNoWindowShrinking(t *testing.T) {
c.CreateConnected(789, 30000, 10)
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
- t.Fatalf("SetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %v", err)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -1984,7 +2025,7 @@ func TestScaledWindowAccept(t *testing.T) {
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
@@ -2057,7 +2098,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
@@ -2221,10 +2262,10 @@ func TestSegmentMerging(t *testing.T) {
{
"cork",
func(ep tcpip.Endpoint) {
- ep.SetSockOpt(tcpip.CorkOption(1))
+ ep.SetSockOptBool(tcpip.CorkOption, true)
},
func(ep tcpip.Endpoint) {
- ep.SetSockOpt(tcpip.CorkOption(0))
+ ep.SetSockOptBool(tcpip.CorkOption, false)
},
},
}
@@ -2316,7 +2357,7 @@ func TestDelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptInt(tcpip.DelayOption, 1)
+ c.EP.SetSockOptBool(tcpip.DelayOption, true)
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
@@ -2364,7 +2405,7 @@ func TestUndelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptInt(tcpip.DelayOption, 1)
+ c.EP.SetSockOptBool(tcpip.DelayOption, true)
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
@@ -2397,7 +2438,7 @@ func TestUndelay(t *testing.T) {
// Check that we don't get the second packet yet.
c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
- c.EP.SetSockOptInt(tcpip.DelayOption, 0)
+ c.EP.SetSockOptBool(tcpip.DelayOption, false)
// Check that data is received.
second := c.GetPacket()
@@ -2434,8 +2475,8 @@ func TestMSSNotDelayed(t *testing.T) {
fn func(tcpip.Endpoint)
}{
{"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }},
- {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }},
+ {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }},
+ {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }},
}
for _, test := range tests {
@@ -2576,12 +2617,12 @@ func TestSetTTL(t *testing.T) {
t.Fatalf("NewEndpoint failed: %v", err)
}
- if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
- t.Fatalf("SetSockOpt failed: %v", err)
+ if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
+ t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("Unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -2621,7 +2662,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
@@ -2667,26 +2708,24 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 0
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
// Create EP and start listening.
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -2704,7 +2743,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2765,7 +2804,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
const rcvBufferSize = 0x20000
const wndScale = 2
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
}
// Start connection attempt.
@@ -3882,26 +3921,26 @@ func TestMinMaxBufferSizes(t *testing.T) {
// Set values below the min.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err)
}
checkRecvBufferSize(t, ep, 200)
if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err)
}
checkSendBufferSize(t, ep, 300)
// Set values above the max.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
}
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
@@ -4147,11 +4186,11 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
case "ipv4":
case "ipv6":
if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err)
+ t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err)
}
case "dual":
if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err)
+ t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err)
}
default:
t.Fatalf("unknown network: '%s'", network)
@@ -4474,11 +4513,11 @@ func TestKeepalive(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const keepAliveInterval = 10 * time.Millisecond
- c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
+ const keepAliveInterval = 3 * time.Second
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
- c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5))
- c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
+ c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5)
+ c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
// 5 unacked keepalives are sent. ACK each one, and check that the
// connection stays alive after 5.
@@ -4569,7 +4608,7 @@ func TestKeepalive(t *testing.T) {
// Sleep for a litte over the KeepAlive interval to make sure
// the timer has time to fire after the last ACK and close the
// close the socket.
- time.Sleep(keepAliveInterval + 5*time.Millisecond)
+ time.Sleep(keepAliveInterval + keepAliveInterval/2)
// The connection should be terminated after 5 unacked keepalives.
// Send an ACK to trigger a RST from the stack as the endpoint should
@@ -5104,25 +5143,23 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
}
func TestListenBacklogFullSynCookieInUse(t *testing.T) {
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 1
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err)
+ }
+
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -5130,7 +5167,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
listenBacklog := 1
portOffset := uint16(0)
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
executeHandshake(t, c, context.TestPort+portOffset, false)
@@ -5609,7 +5646,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
return
}
if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
- t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w, wantRcvWnd)
+ t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w)
}
},
))
@@ -5770,14 +5807,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
func TestDelayEnabled(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- checkDelayOption(t, c, false, 0) // Delay is disabled by default.
+ checkDelayOption(t, c, false, false) // Delay is disabled by default.
for _, v := range []struct {
delayEnabled tcp.DelayEnabled
- wantDelayOption int
+ wantDelayOption bool
}{
- {delayEnabled: false, wantDelayOption: 0},
- {delayEnabled: true, wantDelayOption: 1},
+ {delayEnabled: false, wantDelayOption: false},
+ {delayEnabled: true, wantDelayOption: true},
} {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5788,7 +5825,7 @@ func TestDelayEnabled(t *testing.T) {
}
}
-func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) {
+func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) {
t.Helper()
var gotDelayEnabled tcp.DelayEnabled
@@ -5803,12 +5840,12 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del
if err != nil {
t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err)
}
- gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption)
+ gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
if err != nil {
- t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err)
+ t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err)
}
if gotDelayOption != wantDelayOption {
- t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption)
+ t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption)
}
}
@@ -6617,14 +6654,17 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
- const keepAliveInterval = 10 * time.Millisecond
- c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
+ const keepAliveInterval = 3 * time.Second
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
- c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10))
- c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
-
- // Set userTimeout to be the duration for 3 keepalive probes.
- userTimeout := 30 * time.Millisecond
+ c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10)
+ c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
+
+ // Set userTimeout to the duration of 1 keepalive probe
+ // interval, which means that after the first probe is sent
+ // the second one should cause the connection to be
+ // closed due to userTimeout being hit.
+ userTimeout := 1 * keepAliveInterval
c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
// Check that the connection is still alive.
@@ -6632,28 +6672,23 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
}
- // Now receive 2 keepalives, but don't ACK them. The connection should
- // be reset when the 3rd one should be sent due to userTimeout being
- // 30ms and each keepalive probe should be sent 10ms apart as set above after
- // the connection has been idle for 10ms.
- for i := 0; i < 2; i++ {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
+ // Now receive 1 keepalive, but don't ACK it.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
// Sleep for a litte over the KeepAlive interval to make sure
// the timer has time to fire after the last ACK and close the
// close the socket.
- time.Sleep(keepAliveInterval + 5*time.Millisecond)
+ time.Sleep(keepAliveInterval + keepAliveInterval/2)
- // The connection should be terminated after 30ms.
+ // The connection should be closed with a timeout.
// Send an ACK to trigger a RST from the stack as the endpoint should
// be dead.
c.SendPacket(nil, &context.Headers{
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index a641e953d..8edbff964 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -127,16 +127,14 @@ func TestTimeStampDisabledConnect(t *testing.T) {
}
func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
if cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
}
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
tsVal := rand.Uint32()
@@ -148,7 +146,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ t.Fatalf("Unexpected error from Write: %s", err)
}
// Check that data is received and that the timestamp option TSEcr field
@@ -190,17 +188,15 @@ func TestTimeStampEnabledAccept(t *testing.T) {
}
func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
- if cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- }
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ if cookieEnabled {
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
+
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
@@ -211,7 +207,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ t.Fatalf("Unexpected error from Write: %s", err)
}
// Check that data is received and that the timestamp option is disabled
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index d4f6bc635..7b1d72cf4 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -152,6 +152,13 @@ func New(t *testing.T, mtu uint32) *Context {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
+ // Increase minimum RTO in tests to avoid test flakes due to early
+ // retransmit in case the test executors are overloaded and cause timers
+ // to fire earlier than expected.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil {
+ t.Fatalf("failed to set stack-wide minRTO: %s", err)
+ }
+
// Some of the congestion control tests send up to 640 packets, we so
// set the channel size to 1000.
ep := channel.New(1000, mtu, "")
@@ -217,7 +224,8 @@ func (c *Context) Stack() *stack.Stack {
func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
c.t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), wait)
+ ctx, cancel := context.WithTimeout(context.Background(), wait)
+ defer cancel()
if _, ok := c.linkEP.ReadContext(ctx); ok {
c.t.Fatal(errMsg)
}
@@ -235,7 +243,8 @@ func (c *Context) CheckNoPacket(errMsg string) {
func (c *Context) GetPacket() []byte {
c.t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
c.t.Fatalf("Packet wasn't written out")
@@ -415,6 +424,8 @@ func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlock
// verifies that the packet packet payload of packet matches the slice
// of data indicated by offset & size.
func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
+ c.t.Helper()
+
c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0)
}
@@ -423,6 +434,8 @@ func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
// data indicated by offset & size and skips optlen bytes in addition to the IP
// TCP headers when comparing the data.
func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) {
+ c.t.Helper()
+
b := c.GetPacket()
checker.IPv4(c.t, b,
checker.PayloadLen(size+header.TCPMinimumSize+optlen),
@@ -445,6 +458,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
// data indicated by offset & size. It returns true if a packet was received and
// processed.
func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
+ c.t.Helper()
+
b := c.GetPacketNonBlocking()
if b == nil {
return false
@@ -486,7 +501,8 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
func (c *Context) GetV6Packet() []byte {
c.t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
c.t.Fatalf("Packet wasn't written out")
@@ -567,6 +583,8 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf
//
// PreCondition: c.EP must already be created.
func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
+ c.t.Helper()
+
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 3ad6994a7..2025ff757 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
+ "//pkg/tcpip/transport/tcp",
],
)
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 93712cd45..30d05200f 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -20,6 +20,7 @@ package tcpconntrack
import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)
// Result is returned when the state of a TCB is updated in response to an
@@ -311,17 +312,7 @@ type stream struct {
// the window is zero, if it's a packet with no payload and sequence number
// equal to una.
func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
- wnd := s.una.Size(s.end)
- if wnd == 0 {
- return segLen == 0 && segSeq == s.una
- }
-
- // Make sure [segSeq, seqSeq+segLen) is non-empty.
- if segLen == 0 {
- segLen = 1
- }
-
- return seqnum.Overlap(s.una, wnd, segSeq, segLen)
+ return tcp.Acceptable(segSeq, segLen, s.una, s.end)
}
// closed determines if the stream has already been closed. This happens when
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 120d3baa3..edb54f0be 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -501,11 +501,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
+ case tcpip.BroadcastOption:
+ e.mu.Lock()
+ e.broadcast = v
+ e.mu.Unlock()
+
+ case tcpip.MulticastLoopOption:
+ e.mu.Lock()
+ e.multicastLoop = v
+ e.mu.Unlock()
+
case tcpip.ReceiveTOSOption:
e.mu.Lock()
e.receiveTOS = v
e.mu.Unlock()
- return nil
case tcpip.ReceiveTClassOption:
// We only support this option on v6 endpoints.
@@ -516,7 +525,18 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.mu.Lock()
e.receiveTClass = v
e.mu.Unlock()
- return nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.Lock()
+ e.receiveIPPacketInfo = v
+ e.mu.Unlock()
+
+ case tcpip.ReuseAddressOption:
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v
+ e.mu.Unlock()
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
@@ -533,13 +553,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
}
e.v6only = v
- return nil
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.Lock()
- e.receiveIPPacketInfo = v
- e.mu.Unlock()
- return nil
}
return nil
@@ -547,22 +560,38 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return nil
-}
+ switch opt {
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
-// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
case tcpip.TTLOption:
e.mu.Lock()
e.ttl = uint8(v)
e.mu.Unlock()
- case tcpip.MulticastTTLOption:
+ case tcpip.IPv4TOSOption:
e.mu.Lock()
- e.multicastTTL = uint8(v)
+ e.sendTOS = uint8(v)
e.mu.Unlock()
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.ReceiveBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
+
+ }
+
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
case tcpip.MulticastInterfaceOption:
e.mu.Lock()
defer e.mu.Unlock()
@@ -686,16 +715,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
- case tcpip.MulticastLoopOption:
- e.mu.Lock()
- e.multicastLoop = bool(v)
- e.mu.Unlock()
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
-
case tcpip.BindToDeviceOption:
id := tcpip.NICID(v)
if id != 0 && !e.stack.HasNIC(id) {
@@ -704,26 +723,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Lock()
e.bindToDevice = id
e.mu.Unlock()
- return nil
-
- case tcpip.BroadcastOption:
- e.mu.Lock()
- e.broadcast = v != 0
- e.mu.Unlock()
-
- return nil
-
- case tcpip.IPv4TOSOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
- return nil
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
- return nil
}
return nil
}
@@ -731,6 +730,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
+ case tcpip.BroadcastOption:
+ e.mu.RLock()
+ v := e.broadcast
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.MulticastLoopOption:
+ e.mu.RLock()
+ v := e.multicastLoop
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.ReceiveTOSOption:
e.mu.RLock()
v := e.receiveTOS
@@ -748,6 +762,22 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.RLock()
+ v := e.receiveIPPacketInfo
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.ReuseAddressOption:
+ return false, nil
+
+ case tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ return v, nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -760,19 +790,32 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.RLock()
- v := e.receiveIPPacketInfo
- e.mu.RUnlock()
- return v, nil
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
}
-
- return false, tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
+ case tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ v := int(e.multicastTTL)
+ e.mu.Unlock()
+ return v, nil
+
case tcpip.ReceiveQueueSizeOption:
v := 0
e.rcvMu.Lock()
@@ -794,29 +837,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
v := e.rcvBufSizeMax
e.rcvMu.Unlock()
return v, nil
- }
- return -1, tcpip.ErrUnknownProtocolOption
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ v := int(e.ttl)
+ e.mu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
- return nil
-
- case *tcpip.TTLOption:
- e.mu.Lock()
- *o = tcpip.TTLOption(e.ttl)
- e.mu.Unlock()
- return nil
-
- case *tcpip.MulticastTTLOption:
- e.mu.Lock()
- *o = tcpip.MulticastTTLOption(e.multicastTTL)
- e.mu.Unlock()
- return nil
-
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
*o = tcpip.MulticastInterfaceOption{
@@ -824,67 +860,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.multicastAddr,
}
e.mu.Unlock()
- return nil
-
- case *tcpip.MulticastLoopOption:
- e.mu.RLock()
- v := e.multicastLoop
- e.mu.RUnlock()
-
- *o = tcpip.MulticastLoopOption(v)
- return nil
-
- case *tcpip.ReuseAddressOption:
- *o = 0
- return nil
-
- case *tcpip.ReusePortOption:
- e.mu.RLock()
- v := e.reusePort
- e.mu.RUnlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
case *tcpip.BindToDeviceOption:
e.mu.RLock()
*o = tcpip.BindToDeviceOption(e.bindToDevice)
e.mu.RUnlock()
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
- case *tcpip.BroadcastOption:
- e.mu.RLock()
- v := e.broadcast
- e.mu.RUnlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
- case *tcpip.IPv4TOSOption:
- e.mu.RLock()
- *o = tcpip.IPv4TOSOption(e.sendTOS)
- e.mu.RUnlock()
- return nil
-
- case *tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
- *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
- e.mu.RUnlock()
- return nil
default:
return tcpip.ErrUnknownProtocolOption
}
+ return nil
}
// sendUDP sends a UDP segment via the provided network endpoint and under the
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 0905726c1..8acaa607a 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -343,11 +343,11 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ c.t.Fatalf("SetSockOptBool failed: %s", err)
}
} else if flow.isBroadcast() {
- if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
- c.t.Fatal("SetSockOpt failed:", err)
+ if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ c.t.Fatalf("SetSockOptBool failed: %s", err)
}
}
}
@@ -358,7 +358,8 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
c.t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
c.t.Fatalf("Packet wasn't written out")
@@ -607,7 +608,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
// Check the peer address.
h := flow.header4Tuple(incoming)
if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr)
+ c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr)
}
// Check the payload.
@@ -1271,8 +1272,8 @@ func TestTTL(t *testing.T) {
c.createEndpointForFlow(flow)
const multicastTTL = 42
- if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil {
+ c.t.Fatalf("SetSockOptInt failed: %s", err)
}
var wantTTL uint8
@@ -1311,8 +1312,8 @@ func TestSetTTL(t *testing.T) {
c.createEndpointForFlow(flow)
- if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
+ c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
}
var p stack.NetworkProtocol
@@ -1346,25 +1347,26 @@ func TestSetTOS(t *testing.T) {
c.createEndpointForFlow(flow)
const tos = testTOS
- var v tcpip.IPv4TOSOption
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
+ v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
+ c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
}
- if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err)
+ if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
+ c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err)
}
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
+ v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
}
- if want := tcpip.IPv4TOSOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
+ if v != tos {
+ c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos)
}
testWrite(c, flow, checker.TOS(tos, 0))
@@ -1381,25 +1383,26 @@ func TestSetTClass(t *testing.T) {
c.createEndpointForFlow(flow)
const tClass = testTOS
- var v tcpip.IPv6TrafficClassOption
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
+ v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
+ c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0)
}
- if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil {
- c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err)
+ if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil {
+ c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err)
}
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
+ v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
}
- if want := tcpip.IPv6TrafficClassOption(tClass); v != want {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
+ if v != tClass {
+ c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass)
}
// The header getter for TClass is called TOS, so use that checker.
@@ -1430,7 +1433,7 @@ func TestReceiveTosTClass(t *testing.T) {
// Verify that setting and reading the option works.
v, err := c.ep.GetSockOptBool(option)
if err != nil {
- c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
}
// Test for expected default value.
if v != false {
@@ -1444,7 +1447,7 @@ func TestReceiveTosTClass(t *testing.T) {
got, err := c.ep.GetSockOptBool(option)
if err != nil {
- c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
}
if got != want {
@@ -1563,7 +1566,8 @@ func TestV4UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
}
@@ -1571,7 +1575,8 @@ func TestV4UnknownDestination(t *testing.T) {
}
// ICMP required.
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
t.Fatalf("packet wasn't written out")
@@ -1639,7 +1644,8 @@ func TestV6UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
}
@@ -1647,7 +1653,8 @@ func TestV6UnknownDestination(t *testing.T) {
}
// ICMP required.
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
t.Fatalf("packet wasn't written out")
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index d2f4403b0..cd6a0ea6b 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -29,9 +29,6 @@ import (
)
// IO provides access to the contents of a virtual memory space.
-//
-// FIXME(b/38173783): Implementations of IO cannot expect ctx to contain any
-// meaningful data.
type IO interface {
// CopyOut copies len(src) bytes from src to the memory mapped at addr. It
// returns the number of bytes copied. If the number of bytes copied is <
diff --git a/pkg/usermem/usermem_x86.go b/pkg/usermem/usermem_x86.go
index 8059b72d2..d96f829fb 100644
--- a/pkg/usermem/usermem_x86.go
+++ b/pkg/usermem/usermem_x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 i386
+// +build amd64 386
package usermem
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 26f68fe3d..5451f1eba 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -21,6 +21,7 @@ go_library(
"network.go",
"strace.go",
"user.go",
+ "vfs.go",
],
visibility = [
"//runsc:__subpackages__",
@@ -33,6 +34,7 @@ go_library(
"//pkg/control/server",
"//pkg/cpuid",
"//pkg/eventchannel",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/memutil",
"//pkg/rand",
@@ -40,6 +42,7 @@ go_library(
"//pkg/sentry/arch",
"//pkg/sentry/arch:registers_go_proto",
"//pkg/sentry/control",
+ "//pkg/sentry/devices/memdev",
"//pkg/sentry/fs",
"//pkg/sentry/fs/dev",
"//pkg/sentry/fs/gofer",
@@ -49,6 +52,12 @@ go_library(
"//pkg/sentry/fs/sys",
"//pkg/sentry/fs/tmpfs",
"//pkg/sentry/fs/tty",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/fsimpl/gofer",
+ "//pkg/sentry/fsimpl/host",
+ "//pkg/sentry/fsimpl/proc",
+ "//pkg/sentry/fsimpl/sys",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel:uncaught_signal_go_proto",
@@ -71,6 +80,7 @@ go_library(
"//pkg/sentry/time",
"//pkg/sentry/unimpl:unimplemented_syscall_go_proto",
"//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
"//pkg/sync",
"//pkg/syserror",
@@ -114,10 +124,12 @@ go_test(
"//pkg/p9",
"//pkg/sentry/contexttest",
"//pkg/sentry/fs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sync",
"//pkg/unet",
"//runsc/fsgofer",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go
index 8995d678e..b7cfb35bf 100644
--- a/runsc/boot/compat.go
+++ b/runsc/boot/compat.go
@@ -65,7 +65,7 @@ func newCompatEmitter(logFD int) (*compatEmitter, error) {
if logFD > 0 {
f := os.NewFile(uintptr(logFD), "user log file")
- target := &log.MultiEmitter{c.sink, &log.K8sJSONEmitter{log.Writer{Next: f}}}
+ target := &log.MultiEmitter{c.sink, log.K8sJSONEmitter{&log.Writer{Next: f}}}
c.sink = &log.BasicLogger{Level: log.Info, Emitter: target}
}
return c, nil
diff --git a/runsc/boot/config.go b/runsc/boot/config.go
index 7ea5bfade..715a19112 100644
--- a/runsc/boot/config.go
+++ b/runsc/boot/config.go
@@ -305,5 +305,10 @@ func (c *Config) ToFlags() []string {
if len(c.TestOnlyTestNameEnv) != 0 {
f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv)
}
+
+ if c.VFS2 {
+ f = append(f, "--vfs2=true")
+ }
+
return f
}
diff --git a/runsc/boot/fds.go b/runsc/boot/fds.go
index 5314b0f2a..7e49f6f9f 100644
--- a/runsc/boot/fds.go
+++ b/runsc/boot/fds.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
+ vfshost "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
@@ -31,6 +32,10 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F
return nil, fmt.Errorf("stdioFDs should contain exactly 3 FDs (stdin, stdout, and stderr), but %d FDs received", len(stdioFDs))
}
+ if kernel.VFS2Enabled {
+ return createFDTableVFS2(ctx, console, stdioFDs)
+ }
+
k := kernel.KernelFromContext(ctx)
fdTable := k.NewFDTable()
defer fdTable.DecRef()
@@ -78,3 +83,31 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F
fdTable.IncRef()
return fdTable, nil
}
+
+func createFDTableVFS2(ctx context.Context, console bool, stdioFDs []int) (*kernel.FDTable, error) {
+ k := kernel.KernelFromContext(ctx)
+ fdTable := k.NewFDTable()
+ defer fdTable.DecRef()
+
+ hostMount, err := vfshost.NewMount(k.VFS())
+ if err != nil {
+ return nil, fmt.Errorf("creating host mount: %w", err)
+ }
+
+ for appFD, hostFD := range stdioFDs {
+ // TODO(gvisor.dev/issue/1482): Add TTY support.
+ appFile, err := vfshost.ImportFD(hostMount, hostFD, false)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := fdTable.NewFDAtVFS2(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil {
+ appFile.DecRef()
+ return nil, err
+ }
+ appFile.DecRef()
+ }
+
+ fdTable.IncRef()
+ return fdTable, nil
+}
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index 06b9f888a..1828d116a 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -44,7 +44,7 @@ var allowedSyscalls = seccomp.SyscallRules{
{
seccomp.AllowAny{},
seccomp.AllowAny{},
- seccomp.AllowValue(0),
+ seccomp.AllowValue(syscall.O_CLOEXEC),
},
},
syscall.SYS_EPOLL_CREATE1: {},
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 0f62842ea..98cce60af 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -278,6 +278,9 @@ func subtargets(root string, mnts []specs.Mount) []string {
}
func setupContainerFS(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error {
+ if conf.VFS2 {
+ return setupContainerVFS2(ctx, conf, mntr, procArgs)
+ }
mns, err := mntr.setupFS(conf, procArgs)
if err != nil {
return err
@@ -573,6 +576,9 @@ func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hin
// should be mounted (e.g. a volume shared between containers). It must be
// called for the root container only.
func (c *containerMounter) processHints(conf *Config) error {
+ if conf.VFS2 {
+ return nil
+ }
ctx := c.k.SupervisorContext()
for _, hint := range c.hints.mounts {
// TODO(b/142076984): Only support tmpfs for now. Bind mounts require a
@@ -781,9 +787,6 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (
useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
default:
- // TODO(nlacasse): Support all the mount types and make this a fatal error.
- // Most applications will "just work" without them, so this is a warning
- // for now.
log.Warningf("ignoring unknown filesystem type %q", m.Type)
}
return fsName, opts, useOverlay, nil
@@ -824,7 +827,20 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns
inode, err := filesystem.Mount(ctx, mountDevice(m), mf, strings.Join(opts, ","), nil)
if err != nil {
- return fmt.Errorf("creating mount with source %q: %v", m.Source, err)
+ err := fmt.Errorf("creating mount with source %q: %v", m.Source, err)
+ // Check to see if this is a common error due to a Linux bug.
+ // This error is generated here in order to cause it to be
+ // printed to the user using Docker via 'runsc create' etc. rather
+ // than simply printed to the logs for the 'runsc boot' command.
+ //
+ // We check the error message string rather than type because the
+ // actual error types (syscall.EIO, syscall.EPIPE) are lost by file system
+ // implementation (e.g. p9).
+ // TODO(gvisor.dev/issue/1765): Remove message when bug is resolved.
+ if strings.Contains(err.Error(), syscall.EIO.Error()) || strings.Contains(err.Error(), syscall.EPIPE.Error()) {
+ return fmt.Errorf("%v: %s", err, specutils.FaqErrorMsg("memlock", "you may be encountering a Linux kernel bug"))
+ }
+ return err
}
// If there are submounts, we need to overlay the mount on top of a ramfs
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index e7ca98134..cf1f47bc7 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -26,7 +26,6 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
@@ -73,6 +72,8 @@ import (
_ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
)
+var syscallTable *kernel.SyscallTable
+
// Loader keeps state needed to start the kernel and run the container..
type Loader struct {
// k is the kernel.
@@ -156,13 +157,17 @@ type Args struct {
Spec *specs.Spec
// Conf is the system configuration.
Conf *Config
- // ControllerFD is the FD to the URPC controller.
+ // ControllerFD is the FD to the URPC controller. The Loader takes ownership
+ // of this FD and may close it at any time.
ControllerFD int
- // Device is an optional argument that is passed to the platform.
+ // Device is an optional argument that is passed to the platform. The Loader
+ // takes ownership of this file and may close it at any time.
Device *os.File
- // GoferFDs is an array of FDs used to connect with the Gofer.
+ // GoferFDs is an array of FDs used to connect with the Gofer. The Loader
+ // takes ownership of these FDs and may close them at any time.
GoferFDs []int
- // StdioFDs is the stdio for the application.
+ // StdioFDs is the stdio for the application. The Loader takes ownership of
+ // these FDs and may close them at any time.
StdioFDs []int
// Console is set to true if using TTY.
Console bool
@@ -175,6 +180,9 @@ type Args struct {
UserLogFD int
}
+// startingStdioFD is the first FD number that stdioFDs are remapped to, so that stdioFDs are always the same on initial start and on restore.
+const startingStdioFD = 64
+
// New initializes a new kernel loader configured by spec.
// New also handles setting up a kernel for restoring a container.
func New(args Args) (*Loader, error) {
@@ -188,13 +196,14 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("setting up memory usage: %v", err)
}
- if args.Conf.VFS2 {
- st, ok := kernel.LookupSyscallTable(abi.Linux, arch.Host)
- if ok {
- vfs2.Override(st.Table)
- }
+ // Patch the syscall table.
+ kernel.VFS2Enabled = args.Conf.VFS2
+ if kernel.VFS2Enabled {
+ vfs2.Override(syscallTable.Table)
}
+ kernel.RegisterSyscallTable(syscallTable)
+
// Create kernel and platform.
p, err := createPlatform(args.Conf, args.Device)
if err != nil {
@@ -319,6 +328,24 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("creating pod mount hints: %v", err)
}
+ // Make host FDs stable between invocations. Host FDs must map to the exact
+ // same number when the sandbox is restored. Otherwise the wrong FD will be
+ // used.
+ var stdioFDs []int
+ newfd := startingStdioFD
+ for _, fd := range args.StdioFDs {
+ err := syscall.Dup3(fd, newfd, syscall.O_CLOEXEC)
+ if err != nil {
+ return nil, fmt.Errorf("dup3 of stdioFDs failed: %v", err)
+ }
+ stdioFDs = append(stdioFDs, newfd)
+ err = syscall.Close(fd)
+ if err != nil {
+ return nil, fmt.Errorf("close original stdioFDs failed: %v", err)
+ }
+ newfd++
+ }
+
eid := execID{cid: args.ID}
l := &Loader{
k: k,
@@ -327,7 +354,7 @@ func New(args Args) (*Loader, error) {
watchdog: dog,
spec: args.Spec,
goferFDs: args.GoferFDs,
- stdioFDs: args.StdioFDs,
+ stdioFDs: stdioFDs,
rootProcArgs: procArgs,
sandboxID: args.ID,
processes: map[execID]*execProcess{eid: {}},
@@ -367,11 +394,16 @@ func newProcess(id string, spec *specs.Spec, creds *auth.Credentials, k *kernel.
return kernel.CreateProcessArgs{}, fmt.Errorf("creating limits: %v", err)
}
+ wd := spec.Process.Cwd
+ if wd == "" {
+ wd = "/"
+ }
+
// Create the process arguments.
procArgs := kernel.CreateProcessArgs{
Argv: spec.Process.Args,
Envv: spec.Process.Env,
- WorkingDirectory: spec.Process.Cwd, // Defaults to '/' if empty.
+ WorkingDirectory: wd,
Credentials: creds,
Umask: 0022,
Limits: ls,
@@ -516,7 +548,15 @@ func (l *Loader) run() error {
}
// Add the HOME enviroment variable if it is not already set.
- envv, err := maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+ var envv []string
+ if kernel.VFS2Enabled {
+ envv, err = maybeAddExecUserHomeVFS2(ctx, l.rootProcArgs.MountNamespaceVFS2,
+ l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+
+ } else {
+ envv, err = maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace,
+ l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+ }
if err != nil {
return err
}
@@ -569,6 +609,16 @@ func (l *Loader) run() error {
}
})
+ // l.stdioFDs were dup()ed in boot.New() and have now been dup()ed again,
+ // either in createFDTable() during initial start or in descriptor.initAfterLoad()
+ // during restore, so l.stdioFDs can be released now.
+ for _, fd := range l.stdioFDs {
+ err := syscall.Close(fd)
+ if err != nil {
+ return fmt.Errorf("close dup()ed stdioFDs: %v", err)
+ }
+ }
+
log.Infof("Process should have started...")
l.watchdog.Start()
return l.k.Start()
diff --git a/runsc/boot/loader_amd64.go b/runsc/boot/loader_amd64.go
index b9669f2ac..78df86611 100644
--- a/runsc/boot/loader_amd64.go
+++ b/runsc/boot/loader_amd64.go
@@ -17,11 +17,10 @@
package boot
import (
- "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
)
func init() {
- // Register the global syscall table.
- kernel.RegisterSyscallTable(linux.AMD64)
+ // Set the global syscall table.
+ syscallTable = linux.AMD64
}
diff --git a/runsc/boot/loader_arm64.go b/runsc/boot/loader_arm64.go
index cf64d28c8..250785010 100644
--- a/runsc/boot/loader_arm64.go
+++ b/runsc/boot/loader_arm64.go
@@ -17,11 +17,10 @@
package boot
import (
- "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
)
func init() {
- // Register the global syscall table.
- kernel.RegisterSyscallTable(linux.ARM64)
+ // Set the global syscall table.
+ syscallTable = linux.ARM64
}
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index 44aa63196..e7c71734f 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -24,11 +24,13 @@ import (
"time"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/control/server"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/runsc/fsgofer"
@@ -65,6 +67,11 @@ func testSpec() *specs.Spec {
}
}
+func resetSyscallTable() {
+ kernel.VFS2Enabled = false
+ kernel.FlushSyscallTablesTestOnly()
+}
+
// startGofer starts a new gofer routine serving 'root' path. It returns the
// sandbox side of the connection, and a function that when called will stop the
// gofer.
@@ -100,7 +107,7 @@ func startGofer(root string) (int, func(), error) {
return sandboxEnd, cleanup, nil
}
-func createLoader() (*Loader, func(), error) {
+func createLoader(vfsEnabled bool) (*Loader, func(), error) {
fd, err := server.CreateSocket(ControlSocketAddr(fmt.Sprintf("%010d", rand.Int())[:10]))
if err != nil {
return nil, nil, err
@@ -108,12 +115,23 @@ func createLoader() (*Loader, func(), error) {
conf := testConfig()
spec := testSpec()
+ conf.VFS2 = vfsEnabled
+
sandEnd, cleanup, err := startGofer(spec.Root.Path)
if err != nil {
return nil, nil, err
}
- stdio := []int{int(os.Stdin.Fd()), int(os.Stdout.Fd()), int(os.Stderr.Fd())}
+ // Loader takes ownership of stdio.
+ var stdio []int
+ for _, f := range []*os.File{os.Stdin, os.Stdout, os.Stderr} {
+ newFd, err := unix.Dup(int(f.Fd()))
+ if err != nil {
+ return nil, nil, err
+ }
+ stdio = append(stdio, newFd)
+ }
+
args := Args{
ID: "foo",
Spec: spec,
@@ -132,10 +150,22 @@ func createLoader() (*Loader, func(), error) {
// TestRun runs a simple application in a sandbox and checks that it succeeds.
func TestRun(t *testing.T) {
- l, cleanup, err := createLoader()
+ defer resetSyscallTable()
+ doRun(t, false)
+}
+
+// TestRunVFS2 runs TestRun in VFSv2.
+func TestRunVFS2(t *testing.T) {
+ defer resetSyscallTable()
+ doRun(t, true)
+}
+
+func doRun(t *testing.T, vfsEnabled bool) {
+ l, cleanup, err := createLoader(vfsEnabled)
if err != nil {
t.Fatalf("error creating loader: %v", err)
}
+
defer l.Destroy()
defer cleanup()
@@ -169,7 +199,18 @@ func TestRun(t *testing.T) {
// TestStartSignal tests that the controller Start message will cause
// WaitForStartSignal to return.
func TestStartSignal(t *testing.T) {
- l, cleanup, err := createLoader()
+ defer resetSyscallTable()
+ doStartSignal(t, false)
+}
+
+// TestStartSignalVFS2 does TestStartSignal with VFS2.
+func TestStartSignalVFS2(t *testing.T) {
+ defer resetSyscallTable()
+ doStartSignal(t, true)
+}
+
+func doStartSignal(t *testing.T, vfsEnabled bool) {
+ l, cleanup, err := createLoader(vfsEnabled)
if err != nil {
t.Fatalf("error creating loader: %v", err)
}
diff --git a/runsc/boot/user.go b/runsc/boot/user.go
index f0aa52135..332e4fce5 100644
--- a/runsc/boot/user.go
+++ b/runsc/boot/user.go
@@ -23,8 +23,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -84,6 +86,48 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K
File: f,
}
+ return findHomeInPasswd(uint32(uid), r, defaultHome)
+}
+
+type fileReaderVFS2 struct {
+ ctx context.Context
+ fd *vfs.FileDescription
+}
+
+func (r *fileReaderVFS2) Read(buf []byte) (int, error) {
+ n, err := r.fd.Read(r.ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ return int(n), err
+}
+
+func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth.KUID) (string, error) {
+ const defaultHome = "/"
+
+ root := mns.Root()
+ defer root.DecRef()
+
+ creds := auth.CredentialsFromContext(ctx)
+
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse("/etc/passwd"),
+ }
+
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }
+
+ fd, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, target, opts)
+ if err != nil {
+ return defaultHome, nil
+ }
+ defer fd.DecRef()
+
+ r := &fileReaderVFS2{
+ ctx: ctx,
+ fd: fd,
+ }
+
homeDir, err := findHomeInPasswd(uint32(uid), r, defaultHome)
if err != nil {
return "", err
@@ -111,6 +155,26 @@ func maybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.
if err != nil {
return nil, fmt.Errorf("error reading exec user: %v", err)
}
+
+ return append(envv, "HOME="+homeDir), nil
+}
+
+func maybeAddExecUserHomeVFS2(ctx context.Context, vmns *vfs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+ // Check if the envv already contains HOME.
+ for _, env := range envv {
+ if strings.HasPrefix(env, "HOME=") {
+ // We have it. Return the original slice unmodified.
+ return envv, nil
+ }
+ }
+
+ // Read /etc/passwd for the user's HOME directory and set the HOME
+ // environment variable as required by POSIX if it is not overridden by
+ // the user.
+ homeDir, err := getExecUserHomeVFS2(ctx, vmns, uid)
+ if err != nil {
+ return nil, fmt.Errorf("error reading exec user: %v", err)
+ }
return append(envv, "HOME="+homeDir), nil
}
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
new file mode 100644
index 000000000..82083c57d
--- /dev/null
+++ b/runsc/boot/vfs.go
@@ -0,0 +1,310 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+ "path"
+ "strconv"
+ "strings"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/devices/memdev"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ devtmpfsimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ goferimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer"
+ procimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc"
+ sysimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys"
+ tmpfsimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func registerFilesystems(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) error {
+
+ vfsObj.MustRegisterFilesystemType(rootFsName, &goferimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ })
+
+ vfsObj.MustRegisterFilesystemType(bind, &goferimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ })
+
+ vfsObj.MustRegisterFilesystemType(devpts, &devtmpfsimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+
+ vfsObj.MustRegisterFilesystemType(devtmpfs, &devtmpfsimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(proc, &procimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(sysfs, &sysimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(tmpfs, &tmpfsimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(nonefs, &sysimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+
+ // Setup files in devtmpfs.
+ if err := memdev.Register(vfsObj); err != nil {
+ return fmt.Errorf("registering memdev: %w", err)
+ }
+ a, err := devtmpfsimpl.NewAccessor(ctx, vfsObj, creds, devtmpfsimpl.Name)
+ if err != nil {
+ return fmt.Errorf("creating devtmpfs accessor: %w", err)
+ }
+ defer a.Release()
+
+ if err := a.UserspaceInit(ctx); err != nil {
+ return fmt.Errorf("initializing userspace: %w", err)
+ }
+ if err := memdev.CreateDevtmpfsFiles(ctx, a); err != nil {
+ return fmt.Errorf("creating devtmpfs files: %w", err)
+ }
+ return nil
+}
+
+func setupContainerVFS2(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error {
+ if err := mntr.k.VFS().Init(); err != nil {
+ return fmt.Errorf("failed to initialize VFS: %w", err)
+ }
+ mns, err := mntr.setupVFS2(ctx, conf, procArgs)
+ if err != nil {
+ return fmt.Errorf("failed to setupFS: %w", err)
+ }
+ procArgs.MountNamespaceVFS2 = mns
+ return setExecutablePathVFS2(ctx, procArgs)
+}
+
+func setExecutablePathVFS2(ctx context.Context, procArgs *kernel.CreateProcessArgs) error {
+
+ exe := procArgs.Argv[0]
+
+ // Absolute paths can be used directly.
+ if path.IsAbs(exe) {
+ procArgs.Filename = exe
+ return nil
+ }
+
+ // Paths with '/' in them are joined to the working directory, which must
+ // be absolute; a relative or unset working directory is rejected below.
+ if strings.IndexByte(exe, '/') > 0 {
+
+ if !path.IsAbs(procArgs.WorkingDirectory) {
+ return fmt.Errorf("working directory %q must be absolute", procArgs.WorkingDirectory)
+ }
+
+ procArgs.Filename = path.Join(procArgs.WorkingDirectory, exe)
+ return nil
+ }
+
+ // NOTE: this repeats the condition already handled above, so it is never reached.
+ if strings.IndexByte(exe, '/') > 0 {
+ procArgs.Filename = path.Join(procArgs.WorkingDirectory, exe)
+ return nil
+ }
+
+ // Otherwise, we must look up the name in the paths, starting from the
+ // root directory.
+ root := procArgs.MountNamespaceVFS2.Root()
+ defer root.DecRef()
+
+ paths := fs.GetPath(procArgs.Envv)
+ creds := procArgs.Credentials
+
+ for _, p := range paths {
+
+ binPath := path.Join(p, exe)
+
+ pop := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(binPath),
+ FollowFinalSymlink: true,
+ }
+
+ opts := &vfs.OpenOptions{
+ FileExec: true,
+ Flags: linux.O_RDONLY,
+ }
+
+ dentry, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, pop, opts)
+ if err == syserror.ENOENT || err == syserror.EACCES {
+ // Didn't find it here.
+ continue
+ }
+ if err != nil {
+ return err
+ }
+ dentry.DecRef()
+
+ procArgs.Filename = binPath
+ return nil
+ }
+
+ return fmt.Errorf("executable %q not found in $PATH=%q", exe, strings.Join(paths, ":"))
+}
+
+func (c *containerMounter) setupVFS2(ctx context.Context, conf *Config, procArgs *kernel.CreateProcessArgs) (*vfs.MountNamespace, error) {
+ log.Infof("Configuring container's file system with VFS2")
+
+ // Create context with root credentials to mount the filesystem (the current
+ // user may not be privileged enough).
+ rootProcArgs := *procArgs
+ rootProcArgs.WorkingDirectory = "/"
+ rootProcArgs.Credentials = auth.NewRootCredentials(procArgs.Credentials.UserNamespace)
+ rootProcArgs.Umask = 0022
+ rootProcArgs.MaxSymlinkTraversals = linux.MaxSymlinkTraversals
+ rootCtx := procArgs.NewContext(c.k)
+
+ creds := procArgs.Credentials
+ if err := registerFilesystems(rootCtx, c.k.VFS(), creds); err != nil {
+ return nil, fmt.Errorf("register filesystems: %w", err)
+ }
+
+ fd := c.fds.remove()
+
+ opts := strings.Join(p9MountOptionsVFS2(fd, conf.FileAccess), ",")
+
+ log.Infof("Mounting root over 9P, ioFD: %d", fd)
+ mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", rootFsName, &vfs.GetFilesystemOptions{Data: opts})
+ if err != nil {
+ return nil, fmt.Errorf("setting up mountnamespace: %w", err)
+ }
+
+ rootProcArgs.MountNamespaceVFS2 = mns
+
+ // Mount submounts.
+ if err := c.mountSubmountsVFS2(rootCtx, conf, mns, creds); err != nil {
+ return nil, fmt.Errorf("mounting submounts vfs2: %w", err)
+ }
+
+ return mns, nil
+}
+
+func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials) error {
+
+ for _, submount := range c.mounts {
+ log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.Source, submount.Destination, submount.Type, submount.Options)
+ if err := c.mountSubmountVFS2(ctx, conf, mns, creds, &submount); err != nil {
+ return err
+ }
+ }
+
+ // TODO(gvisor.dev/issue/1487): implement mountTmp from fs.go.
+
+ return c.checkDispenser()
+}
+
+// TODO(gvisor.dev/issue/1487): Implement submount options similar to the VFS1 version.
+func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *specs.Mount) error {
+ root := mns.Root()
+ defer root.DecRef()
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(submount.Destination),
+ }
+
+ _, options, useOverlay, err := c.getMountNameAndOptionsVFS2(conf, *submount)
+ if err != nil {
+ return fmt.Errorf("mountOptions failed: %w", err)
+ }
+
+ opts := &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ Data: strings.Join(options, ","),
+ },
+ InternalMount: true,
+ }
+
+ // All writes go to upper, be paranoid and make lower readonly.
+ opts.ReadOnly = useOverlay
+
+ if err := c.k.VFS().MountAt(ctx, creds, "", target, submount.Type, opts); err != nil {
+ return fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.Destination, submount.Type, err, opts)
+ }
+ log.Infof("Mounted %q to %q type: %s, internal-options: %q", submount.Source, submount.Destination, submount.Type, opts)
+ return nil
+}
+
+// getMountNameAndOptionsVFS2 retrieves the fsName, opts, and useOverlay values
+// used for mounts.
+func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m specs.Mount) (string, []string, bool, error) {
+ var (
+ fsName string
+ opts []string
+ useOverlay bool
+ )
+
+ switch m.Type {
+ case devpts, devtmpfs, proc, sysfs:
+ fsName = m.Type
+ case nonefs:
+ fsName = sysfs
+ case tmpfs:
+ fsName = m.Type
+
+ var err error
+ opts, err = parseAndFilterOptions(m.Options, tmpfsAllowedOptions...)
+ if err != nil {
+ return "", nil, false, err
+ }
+
+ case bind:
+ fd := c.fds.remove()
+ fsName = "9p"
+ opts = p9MountOptionsVFS2(fd, c.getMountAccessType(m))
+ // If configured, add overlay to all writable mounts.
+ useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
+
+ default:
+ log.Warningf("ignoring unknown filesystem type %q", m.Type)
+ }
+ return fsName, opts, useOverlay, nil
+}
+
+// p9MountOptionsVFS2 creates a slice of options for a p9 mount.
+// TODO(gvisor.dev/issue/1200): Remove this version in favor of the one in
+// fs.go when privateunixsocket lands.
+func p9MountOptionsVFS2(fd int, fa FileAccessType) []string {
+ opts := []string{
+ "trans=fd",
+ "rfdno=" + strconv.Itoa(fd),
+ "wfdno=" + strconv.Itoa(fd),
+ }
+ if fa == FileAccessShared {
+ opts = append(opts, "cache=remote_revalidating")
+ }
+ return opts
+}
diff --git a/runsc/cmd/capability_test.go b/runsc/cmd/capability_test.go
index 0c27f7313..9360d7442 100644
--- a/runsc/cmd/capability_test.go
+++ b/runsc/cmd/capability_test.go
@@ -85,7 +85,7 @@ func TestCapabilities(t *testing.T) {
Inheritable: caps,
}
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
// Use --network=host to make sandbox use spec's capabilities.
conf.Network = boot.NetworkHost
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 02e5af3d3..28f0d54b9 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -272,9 +272,8 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
root := spec.Root.Path
if !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
- // FIXME: runsc can't be re-executed without
- // /proc, so we create a tmpfs mount, mount ./proc and ./root
- // there, then move this mount to the root and after
+ // runsc can't be re-executed without /proc, so we create a tmpfs mount,
+ // mount ./proc and ./root there, then move this mount to the root and after
// setCapsAndCallSelf, runsc will chroot into /root.
//
// We need a directory to construct a new root and we know that
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index 651615d4c..af245b6d8 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -118,7 +118,7 @@ func receiveConsolePTY(srv *unet.ServerSocket) (*os.File, error) {
// Test that an pty FD is sent over the console socket if one is provided.
func TestConsoleSocket(t *testing.T) {
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
spec := testutil.NewSpecWithArgs("true")
rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
@@ -163,7 +163,7 @@ func TestConsoleSocket(t *testing.T) {
// Test that job control signals work on a console created with "exec -ti".
func TestJobControlSignalExec(t *testing.T) {
spec := testutil.NewSpecWithArgs("/bin/sleep", "10000")
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
if err != nil {
@@ -286,7 +286,7 @@ func TestJobControlSignalExec(t *testing.T) {
// Test that job control signals work on a console created with "run -ti".
func TestJobControlSignalRootContainer(t *testing.T) {
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
// Don't let bash execute from profile or rc files, otherwise our PID
// counts get messed up.
spec := testutil.NewSpecWithArgs("/bin/bash", "--noprofile", "--norc")
diff --git a/runsc/container/container.go b/runsc/container/container.go
index c9839044c..7233659b1 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -1077,9 +1077,9 @@ func (c *Container) adjustGoferOOMScoreAdj() error {
// oom_score_adj is set to the lowest oom_score_adj among the containers
// running in the sandbox.
//
-// TODO(gvisor.dev/issue/512): This call could race with other containers being
+// TODO(gvisor.dev/issue/238): This call could race with other containers being
// created at the same time and end up setting the wrong oom_score_adj to the
-// sandbox.
+// sandbox. Use rpc client to synchronize.
func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) error {
containers, err := loadSandbox(rootDir, s.ID)
if err != nil {
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 442e80ac0..5db6d64aa 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -251,12 +251,12 @@ var noOverlay = []configOption{kvm, nonExclusiveFS}
var all = append(noOverlay, overlay)
// configs generates different configurations to run tests.
-func configs(opts ...configOption) []*boot.Config {
+func configs(t *testing.T, opts ...configOption) []*boot.Config {
// Always load the default config.
- cs := []*boot.Config{testutil.TestConfig()}
+ cs := []*boot.Config{testutil.TestConfig(t)}
for _, o := range opts {
- c := testutil.TestConfig()
+ c := testutil.TestConfig(t)
switch o {
case overlay:
c.Overlay = true
@@ -285,7 +285,7 @@ func TestLifecycle(t *testing.T) {
childReaper.Start()
defer childReaper.Stop()
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
// The container will just sleep for a long time. We will kill it before
// it finishes sleeping.
@@ -457,7 +457,7 @@ func TestExePath(t *testing.T) {
t.Fatal(err)
}
- for _, conf := range configs(overlay) {
+ for _, conf := range configs(t, overlay) {
t.Logf("Running test with conf: %+v", conf)
for _, test := range []struct {
path string
@@ -521,9 +521,19 @@ func TestExePath(t *testing.T) {
// Test the we can retrieve the application exit status from the container.
func TestAppExitStatus(t *testing.T) {
+ doAppExitStatus(t, false)
+}
+
+// TestAppExitStatusVFS2 is TestAppExitStatus for VFSv2.
+func TestAppExitStatusVFS2(t *testing.T) {
+ doAppExitStatus(t, true)
+}
+
+func doAppExitStatus(t *testing.T, vfs2 bool) {
// First container will succeed.
succSpec := testutil.NewSpecWithArgs("true")
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
+ conf.VFS2 = vfs2
rootDir, bundleDir, err := testutil.SetupContainer(succSpec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -573,7 +583,7 @@ func TestAppExitStatus(t *testing.T) {
// TestExec verifies that a container can exec a new program.
func TestExec(t *testing.T) {
- for _, conf := range configs(overlay) {
+ for _, conf := range configs(t, overlay) {
t.Logf("Running test with conf: %+v", conf)
const uid = 343
@@ -667,7 +677,7 @@ func TestExec(t *testing.T) {
// TestKillPid verifies that we can signal individual exec'd processes.
func TestKillPid(t *testing.T) {
- for _, conf := range configs(overlay) {
+ for _, conf := range configs(t, overlay) {
t.Logf("Running test with conf: %+v", conf)
app, err := testutil.FindFile("runsc/container/test_app/test_app")
@@ -743,7 +753,7 @@ func TestKillPid(t *testing.T) {
// be the next consecutive number after the last number from the checkpointed container.
func TestCheckpointRestore(t *testing.T) {
// Skip overlay because test requires writing to host file.
- for _, conf := range configs(noOverlay...) {
+ for _, conf := range configs(t, noOverlay...) {
t.Logf("Running test with conf: %+v", conf)
dir, err := ioutil.TempDir(testutil.TmpDir(), "checkpoint-test")
@@ -904,7 +914,7 @@ func TestCheckpointRestore(t *testing.T) {
// with filesystem Unix Domain Socket use.
func TestUnixDomainSockets(t *testing.T) {
// Skip overlay because test requires writing to host file.
- for _, conf := range configs(noOverlay...) {
+ for _, conf := range configs(t, noOverlay...) {
t.Logf("Running test with conf: %+v", conf)
// UDS path is limited to 108 chars for compatibility with older systems.
@@ -1042,7 +1052,7 @@ func TestUnixDomainSockets(t *testing.T) {
// recreated. Then it resumes the container, verify that the file gets created
// again.
func TestPauseResume(t *testing.T) {
- for _, conf := range configs(noOverlay...) {
+ for _, conf := range configs(t, noOverlay...) {
t.Run(fmt.Sprintf("conf: %+v", conf), func(t *testing.T) {
t.Logf("Running test with conf: %+v", conf)
@@ -1123,7 +1133,7 @@ func TestPauseResume(t *testing.T) {
// occurs given the correct state.
func TestPauseResumeStatus(t *testing.T) {
spec := testutil.NewSpecWithArgs("sleep", "20")
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -1189,7 +1199,7 @@ func TestCapabilities(t *testing.T) {
uid := auth.KUID(os.Getuid() + 1)
gid := auth.KGID(os.Getgid() + 1)
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
spec := testutil.NewSpecWithArgs("sleep", "100")
@@ -1278,7 +1288,7 @@ func TestCapabilities(t *testing.T) {
// TestRunNonRoot checks that sandbox can be configured when running as
// non-privileged user.
func TestRunNonRoot(t *testing.T) {
- for _, conf := range configs(noOverlay...) {
+ for _, conf := range configs(t, noOverlay...) {
t.Logf("Running test with conf: %+v", conf)
spec := testutil.NewSpecWithArgs("/bin/true")
@@ -1322,7 +1332,7 @@ func TestRunNonRoot(t *testing.T) {
// TestMountNewDir checks that runsc will create destination directory if it
// doesn't exit.
func TestMountNewDir(t *testing.T) {
- for _, conf := range configs(overlay) {
+ for _, conf := range configs(t, overlay) {
t.Logf("Running test with conf: %+v", conf)
root, err := ioutil.TempDir(testutil.TmpDir(), "root")
@@ -1351,7 +1361,7 @@ func TestMountNewDir(t *testing.T) {
}
func TestReadonlyRoot(t *testing.T) {
- for _, conf := range configs(overlay) {
+ for _, conf := range configs(t, overlay) {
t.Logf("Running test with conf: %+v", conf)
spec := testutil.NewSpecWithArgs("/bin/touch", "/foo")
@@ -1389,7 +1399,7 @@ func TestReadonlyRoot(t *testing.T) {
}
func TestUIDMap(t *testing.T) {
- for _, conf := range configs(noOverlay...) {
+ for _, conf := range configs(t, noOverlay...) {
t.Logf("Running test with conf: %+v", conf)
testDir, err := ioutil.TempDir(testutil.TmpDir(), "test-mount")
if err != nil {
@@ -1470,7 +1480,7 @@ func TestUIDMap(t *testing.T) {
}
func TestReadonlyMount(t *testing.T) {
- for _, conf := range configs(overlay) {
+ for _, conf := range configs(t, overlay) {
t.Logf("Running test with conf: %+v", conf)
dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount")
@@ -1527,7 +1537,7 @@ func TestAbbreviatedIDs(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
cids := []string{
@@ -1585,7 +1595,7 @@ func TestAbbreviatedIDs(t *testing.T) {
func TestGoferExits(t *testing.T) {
spec := testutil.NewSpecWithArgs("/bin/sleep", "10000")
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -1654,7 +1664,7 @@ func TestRootNotMount(t *testing.T) {
spec.Root.Readonly = true
spec.Mounts = nil
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
if err := run(spec, conf); err != nil {
t.Fatalf("error running sandbox: %v", err)
}
@@ -1668,7 +1678,7 @@ func TestUserLog(t *testing.T) {
// sched_rr_get_interval = 148 - not implemented in gvisor.
spec := testutil.NewSpecWithArgs(app, "syscall", "--syscall=148")
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -1708,7 +1718,7 @@ func TestUserLog(t *testing.T) {
}
func TestWaitOnExitedSandbox(t *testing.T) {
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
// Run a shell that sleeps for 1 second and then exits with a
@@ -1763,7 +1773,7 @@ func TestWaitOnExitedSandbox(t *testing.T) {
func TestDestroyNotStarted(t *testing.T) {
spec := testutil.NewSpecWithArgs("/bin/sleep", "100")
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -1790,7 +1800,7 @@ func TestDestroyNotStarted(t *testing.T) {
func TestDestroyStarting(t *testing.T) {
for i := 0; i < 10; i++ {
spec := testutil.NewSpecWithArgs("/bin/sleep", "100")
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -1835,7 +1845,7 @@ func TestDestroyStarting(t *testing.T) {
}
func TestCreateWorkingDir(t *testing.T) {
- for _, conf := range configs(overlay) {
+ for _, conf := range configs(t, overlay) {
t.Logf("Running test with conf: %+v", conf)
tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "cwd-create")
@@ -1908,7 +1918,7 @@ func TestMountPropagation(t *testing.T) {
},
}
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -1959,7 +1969,7 @@ func TestMountPropagation(t *testing.T) {
}
func TestMountSymlink(t *testing.T) {
- for _, conf := range configs(overlay) {
+ for _, conf := range configs(t, overlay) {
t.Logf("Running test with conf: %+v", conf)
dir, err := ioutil.TempDir(testutil.TmpDir(), "mount-symlink")
@@ -2039,7 +2049,7 @@ func TestNetRaw(t *testing.T) {
}
for _, enableRaw := range []bool{true, false} {
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.EnableRaw = enableRaw
test := "--enabled"
@@ -2056,7 +2066,7 @@ func TestNetRaw(t *testing.T) {
// TestOverlayfsStaleRead most basic test that '--overlayfs-stale-read' works.
func TestOverlayfsStaleRead(t *testing.T) {
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.OverlayfsStaleRead = true
in, err := ioutil.TempFile(testutil.TmpDir(), "stale-read.in")
@@ -2120,7 +2130,7 @@ func TestTTYField(t *testing.T) {
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
// We will run /bin/sleep, possibly with an open TTY.
cmd := []string{"/bin/sleep", "10000"}
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index 2da93ec5b..dc2fb42ce 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -135,7 +135,7 @@ func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) {
// TestMultiContainerSanity checks that it is possible to run 2 dead-simple
// containers in the same sandbox.
func TestMultiContainerSanity(t *testing.T) {
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
rootDir, err := testutil.SetupRootDir()
@@ -173,7 +173,7 @@ func TestMultiContainerSanity(t *testing.T) {
// TestMultiPIDNS checks that it is possible to run 2 dead-simple
// containers in the same sandbox with different pidns.
func TestMultiPIDNS(t *testing.T) {
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
rootDir, err := testutil.SetupRootDir()
@@ -218,7 +218,7 @@ func TestMultiPIDNS(t *testing.T) {
// TestMultiPIDNSPath checks the pidns path.
func TestMultiPIDNSPath(t *testing.T) {
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
rootDir, err := testutil.SetupRootDir()
@@ -289,7 +289,7 @@ func TestMultiContainerWait(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// The first container should run the entire duration of the test.
@@ -367,7 +367,7 @@ func TestExecWait(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// The first container should run the entire duration of the test.
@@ -463,7 +463,7 @@ func TestMultiContainerMount(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
containers, cleanup, err := startContainers(conf, sps, ids)
@@ -484,7 +484,7 @@ func TestMultiContainerMount(t *testing.T) {
// TestMultiContainerSignal checks that it is possible to signal individual
// containers without killing the entire sandbox.
func TestMultiContainerSignal(t *testing.T) {
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
rootDir, err := testutil.SetupRootDir()
@@ -585,7 +585,7 @@ func TestMultiContainerDestroy(t *testing.T) {
t.Fatal("error finding test_app:", err)
}
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
rootDir, err := testutil.SetupRootDir()
@@ -653,7 +653,7 @@ func TestMultiContainerProcesses(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// Note: use curly braces to keep 'sh' process around. Otherwise, shell
@@ -712,7 +712,7 @@ func TestMultiContainerKillAll(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
for _, tc := range []struct {
@@ -804,7 +804,7 @@ func TestMultiContainerDestroyNotStarted(t *testing.T) {
[]string{"/bin/sleep", "100"},
[]string{"/bin/sleep", "100"})
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -858,7 +858,7 @@ func TestMultiContainerDestroyStarting(t *testing.T) {
}
specs, ids := createSpecs(cmds...)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -943,7 +943,7 @@ func TestMultiContainerDifferentFilesystems(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// Make sure overlay is enabled, and none of the root filesystems are
@@ -1006,7 +1006,7 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) {
childrenSpecs := allSpecs[1:]
childrenIDs := allIDs[1:]
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
rootDir, bundleDir, err := testutil.SetupContainer(rootSpec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -1080,7 +1080,7 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) {
// Test that pod shared mounts are properly mounted in 2 containers and that
// changes from one container is reflected in the other.
func TestMultiContainerSharedMount(t *testing.T) {
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
rootDir, err := testutil.SetupRootDir()
@@ -1195,7 +1195,7 @@ func TestMultiContainerSharedMount(t *testing.T) {
// Test that pod mounts are mounted as readonly when requested.
func TestMultiContainerSharedMountReadonly(t *testing.T) {
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
rootDir, err := testutil.SetupRootDir()
@@ -1262,7 +1262,7 @@ func TestMultiContainerSharedMountReadonly(t *testing.T) {
// Test that shared pod mounts continue to work after container is restarted.
func TestMultiContainerSharedMountRestart(t *testing.T) {
- for _, conf := range configs(all...) {
+ for _, conf := range configs(t, all...) {
t.Logf("Running test with conf: %+v", conf)
rootDir, err := testutil.SetupRootDir()
@@ -1381,7 +1381,7 @@ func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// Setup the containers.
@@ -1463,7 +1463,7 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// Create the specs.
@@ -1500,7 +1500,7 @@ func TestMultiContainerGoferKilled(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
sleep := []string{"sleep", "100"}
@@ -1587,7 +1587,7 @@ func TestMultiContainerLoadSandbox(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// Create containers for the sandbox.
@@ -1687,7 +1687,7 @@ func TestMultiContainerRunNonRoot(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
pod, cleanup, err := startContainers(conf, podSpecs, ids)
diff --git a/runsc/container/shared_volume_test.go b/runsc/container/shared_volume_test.go
index dc4194134..f80852414 100644
--- a/runsc/container/shared_volume_test.go
+++ b/runsc/container/shared_volume_test.go
@@ -31,7 +31,7 @@ import (
// TestSharedVolume checks that modifications to a volume mount are propagated
// into and out of the sandbox.
func TestSharedVolume(t *testing.T) {
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.FileAccess = boot.FileAccessShared
t.Logf("Running test with conf: %+v", conf)
@@ -190,7 +190,7 @@ func checkFile(c *Container, filename string, want []byte) error {
// TestSharedVolumeFile tests that changes to file content outside the sandbox
// is reflected inside.
func TestSharedVolumeFile(t *testing.T) {
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.FileAccess = boot.FileAccessShared
t.Logf("Running test with conf: %+v", conf)
diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go
index 01c47c79f..5f1c4b7d6 100644
--- a/runsc/container/test_app/test_app.go
+++ b/runsc/container/test_app/test_app.go
@@ -96,7 +96,7 @@ func (c *uds) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{})
listener, err := net.Listen("unix", c.socketPath)
if err != nil {
- log.Fatal("error listening on socket %q:", c.socketPath, err)
+ log.Fatalf("error listening on socket %q: %v", c.socketPath, err)
}
go server(listener, outputFile)
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index cadd83273..1942f50d7 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -767,22 +767,18 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
return err
}
-// TODO(b/127675828): support getxattr.
func (*localFile) GetXattr(string, uint64) (string, error) {
return "", syscall.EOPNOTSUPP
}
-// TODO(b/127675828): support setxattr.
func (*localFile) SetXattr(string, string, uint32) error {
return syscall.EOPNOTSUPP
}
-// TODO(b/148303075): support listxattr.
func (*localFile) ListXattr(uint64) (map[string]struct{}, error) {
return nil, syscall.EOPNOTSUPP
}
-// TODO(b/148303075): support removexattr.
func (*localFile) RemoveXattr(string) error {
return syscall.EOPNOTSUPP
}
diff --git a/runsc/main.go b/runsc/main.go
index 62e184ec9..2baba90f8 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -84,6 +84,7 @@ var (
rootless = flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.")
referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.")
cpuNumFromQuota = flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)")
+ vfs2Enabled = flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.")
// Test flags, not to be used outside tests, ever.
testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.")
@@ -230,6 +231,7 @@ func main() {
ReferenceLeakMode: refsLeakMode,
OverlayfsStaleRead: *overlayfsStaleRead,
CPUNumFromQuota: *cpuNumFromQuota,
+ VFS2: *vfs2Enabled,
TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot,
TestOnlyTestNameEnv: *testOnlyTestNameEnv,
@@ -294,9 +296,7 @@ func main() {
if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil {
cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err)
}
- }
-
- if *alsoLogToStderr {
+ } else if *alsoLogToStderr {
e = &log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)}
}
@@ -313,6 +313,7 @@ func main() {
log.Infof("\t\tFileAccess: %v, overlay: %t", conf.FileAccess, conf.Overlay)
log.Infof("\t\tNetwork: %v, logging: %t", conf.Network, conf.LogPackets)
log.Infof("\t\tStrace: %t, max size: %d, syscalls: %s", conf.Strace, conf.StraceLogSize, conf.StraceSyscalls)
+ log.Infof("\t\tVFS2 enabled: %v", conf.VFS2)
log.Infof("***************************")
if *testOnlyAllowRunAsCurrentUserWithoutChroot {
@@ -342,11 +343,11 @@ func main() {
func newEmitter(format string, logFile io.Writer) log.Emitter {
switch format {
case "text":
- return &log.GoogleEmitter{log.Writer{Next: logFile}}
+ return log.GoogleEmitter{&log.Writer{Next: logFile}}
case "json":
- return &log.JSONEmitter{log.Writer{Next: logFile}}
+ return log.JSONEmitter{&log.Writer{Next: logFile}}
case "json-k8s":
- return &log.K8sJSONEmitter{log.Writer{Next: logFile}}
+ return log.K8sJSONEmitter{&log.Writer{Next: logFile}}
}
cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format)
panic("unreachable")
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index 6c15727fa..e82bcef6f 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -18,10 +18,12 @@ package sandbox
import (
"context"
"fmt"
+ "io"
"math"
"os"
"os/exec"
"strconv"
+ "strings"
"syscall"
"time"
@@ -142,7 +144,19 @@ func New(conf *boot.Config, args *Args) (*Sandbox, error) {
// Wait until the sandbox has booted.
b := make([]byte, 1)
if l, err := clientSyncFile.Read(b); err != nil || l != 1 {
- return nil, fmt.Errorf("waiting for sandbox to start: %v", err)
+ err := fmt.Errorf("waiting for sandbox to start: %v", err)
+ // If the sandbox failed to start, it may be because the binary
+ // permissions were incorrect. Check the bits and return a more helpful
+ // error message.
+ //
+ // NOTE: The error message is checked because error types are lost over
+ // rpc calls.
+ if strings.Contains(err.Error(), io.EOF.Error()) {
+ if permsErr := checkBinaryPermissions(conf); permsErr != nil {
+ return nil, fmt.Errorf("%v: %v", err, permsErr)
+ }
+ }
+ return nil, err
}
c.Release()
@@ -388,8 +402,6 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
- cmd.Args = append(cmd.Args, "--panic-signal="+strconv.Itoa(int(syscall.SIGTERM)))
-
// Add the "boot" command to the args.
//
// All flags after this must be for the boot command
@@ -444,6 +456,12 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
+ // TODO(b/151157106): syscall tests fail by timeout if asyncpreemptoff
+ // isn't set.
+ if conf.Platform == "kvm" {
+ cmd.Env = append(cmd.Env, "GODEBUG=asyncpreemptoff=1")
+ }
+
// The current process' stdio must be passed to the application via the
// --stdio-fds flag. The stdio of the sandbox process itself must not
// be connected to the same FDs, otherwise we risk leaking sandbox
@@ -582,45 +600,32 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nss = append(nss, specs.LinuxNamespace{Type: specs.UserNamespace})
cmd.Args = append(cmd.Args, "--setup-root")
+ const nobody = 65534
if conf.Rootless {
- log.Infof("Rootless mode: sandbox will run as root inside user namespace, mapped to the current user, uid: %d, gid: %d", os.Getuid(), os.Getgid())
+ log.Infof("Rootless mode: sandbox will run as nobody inside user namespace, mapped to the current user, uid: %d, gid: %d", os.Getuid(), os.Getgid())
cmd.SysProcAttr.UidMappings = []syscall.SysProcIDMap{
{
- ContainerID: 0,
+ ContainerID: nobody,
HostID: os.Getuid(),
Size: 1,
},
}
cmd.SysProcAttr.GidMappings = []syscall.SysProcIDMap{
{
- ContainerID: 0,
+ ContainerID: nobody,
HostID: os.Getgid(),
Size: 1,
},
}
- cmd.SysProcAttr.Credential = &syscall.Credential{Uid: 0, Gid: 0}
} else {
// Map nobody in the new namespace to nobody in the parent namespace.
//
// A sandbox process will construct an empty
- // root for itself, so it has to have the CAP_SYS_ADMIN
- // capability.
- //
- // FIXME(b/122554829): The current implementations of
- // os/exec doesn't allow to set ambient capabilities if
- // a process is started in a new user namespace. As a
- // workaround, we start the sandbox process with the 0
- // UID and then it constructs a chroot and sets UID to
- // nobody. https://github.com/golang/go/issues/2315
- const nobody = 65534
+ // root for itself, so it has to have
+ // CAP_SYS_ADMIN and CAP_SYS_CHROOT capabilities.
cmd.SysProcAttr.UidMappings = []syscall.SysProcIDMap{
{
- ContainerID: 0,
- HostID: nobody - 1,
- Size: 1,
- },
- {
ContainerID: nobody,
HostID: nobody,
Size: 1,
@@ -633,11 +638,11 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
Size: 1,
},
}
-
- // Set credentials to run as user and group nobody.
- cmd.SysProcAttr.Credential = &syscall.Credential{Uid: 0, Gid: nobody}
}
+ // Set credentials to run as user and group nobody.
+ cmd.SysProcAttr.Credential = &syscall.Credential{Uid: nobody, Gid: nobody}
+ cmd.SysProcAttr.AmbientCaps = append(cmd.SysProcAttr.AmbientCaps, uintptr(capability.CAP_SYS_ADMIN), uintptr(capability.CAP_SYS_CHROOT))
} else {
return fmt.Errorf("can't run sandbox process as user nobody since we don't have CAP_SETUID or CAP_SETGID")
}
@@ -713,7 +718,19 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
log.Debugf("Starting sandbox: %s %v", binPath, cmd.Args)
log.Debugf("SysProcAttr: %+v", cmd.SysProcAttr)
if err := specutils.StartInNS(cmd, nss); err != nil {
- return fmt.Errorf("Sandbox: %v", err)
+ err := fmt.Errorf("starting sandbox: %v", err)
+ // If the sandbox failed to start, it may be because the binary
+ // permissions were incorrect. Check the bits and return a more helpful
+ // error message.
+ //
+ // NOTE: The error message is checked because error types are lost over
+ // rpc calls.
+ if strings.Contains(err.Error(), syscall.EACCES.Error()) {
+ if permsErr := checkBinaryPermissions(conf); permsErr != nil {
+ return fmt.Errorf("%v: %v", err, permsErr)
+ }
+ }
+ return err
}
s.child = true
s.Pid = cmd.Process.Pid
@@ -1176,3 +1193,31 @@ func deviceFileForPlatform(name string) (*os.File, error) {
}
return f, nil
}
+
+// checkBinaryPermissions verifies that the required permission bits are set on
+// the runsc executable; it returns nil if so, or an error describing the problem.
+func checkBinaryPermissions(conf *boot.Config) error {
+	// All platforms need the other exe bit
+	neededBits := os.FileMode(0001)
+	if conf.Platform == platforms.Ptrace {
+		// Ptrace needs the other read bit
+		neededBits |= os.FileMode(0004)
+	}
+
+	exePath, err := os.Executable()
+	if err != nil {
+		return fmt.Errorf("getting exe path: %v", err)
+	}
+
+	// Check the permissions of the runsc binary and return an error if they
+	// don't match expectations.
+	info, err := os.Stat(exePath)
+	if err != nil {
+		return fmt.Errorf("stat file: %v", err)
+	}
+	if info.Mode().Perm()&neededBits != neededBits {
+		// Use "%s" so that a '%' in exePath is not treated as a format verb.
+		return fmt.Errorf("%s", specutils.FaqErrorMsg("runsc-perms", fmt.Sprintf("%s does not have the correct permissions", exePath)))
+	}
+	return nil
+}
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index d3c2e4e78..837d5e238 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -92,6 +92,12 @@ func ValidateSpec(spec *specs.Spec) error {
log.Warningf("AppArmor profile %q is being ignored", spec.Process.ApparmorProfile)
}
+ // PR_SET_NO_NEW_PRIVS is assumed to always be set.
+ // See kernel.Task.updateCredsForExecLocked.
+ if !spec.Process.NoNewPrivileges {
+ log.Warningf("noNewPrivileges ignored. PR_SET_NO_NEW_PRIVS is assumed to always be set.")
+ }
+
// TODO(gvisor.dev/issue/510): Apply seccomp to application inside sandbox.
if spec.Linux != nil && spec.Linux.Seccomp != nil {
log.Warningf("Seccomp spec is being ignored")
@@ -528,3 +534,8 @@ func EnvVar(env []string, name string) (string, bool) {
}
return "", false
}
+
+// FaqErrorMsg returns msg followed by a pointer to the gvisor.dev FAQ entry named by anchor.
+func FaqErrorMsg(anchor, msg string) string {
+ return fmt.Sprintf("%s; see https://gvisor.dev/faq#%s for more details", msg, anchor)
+}
diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go
index 51e487715..5e09f8f16 100644
--- a/runsc/testutil/testutil.go
+++ b/runsc/testutil/testutil.go
@@ -31,11 +31,13 @@ import (
"os"
"os/exec"
"os/signal"
+ "path"
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"syscall"
+ "testing"
"time"
"github.com/cenkalti/backoff"
@@ -81,17 +83,16 @@ func ConfigureExePath() error {
// TestConfig returns the default configuration to use in tests. Note that
// 'RootDir' must be set by caller if required.
-func TestConfig() *boot.Config {
+func TestConfig(t *testing.T) *boot.Config {
logDir := ""
if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
logDir = dir + "/"
}
return &boot.Config{
Debug: true,
- DebugLog: logDir,
+ DebugLog: path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%"),
LogFormat: "text",
DebugLogFormat: "text",
- AlsoLogToStderr: true,
LogPackets: true,
Network: boot.NetworkNone,
Strace: true,
diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh
index 3fd80fc2e..e0f6df438 100755
--- a/scripts/benchmark.sh
+++ b/scripts/benchmark.sh
@@ -16,12 +16,19 @@
source $(dirname $0)/common.sh
+# gcloud may be installed as a "snap". If it is, include it in PATH.
+declare -r snap="/snap/bin"
+if [[ -d "${snap}" ]]; then
+ export PATH="${PATH}:${snap}"
+fi
+
+# Make sure we can find gcloud and exit if not.
+which gcloud
+
# Exporting for subprocesses as GCP APIs and tools check this environmental
# variable for authentication.
export GOOGLE_APPLICATION_CREDENTIALS="${KOKORO_KEYSTORE_DIR}/${GCLOUD_CREDENTIALS}"
-which gcloud
-
gcloud auth activate-service-account \
--key-file "${GOOGLE_APPLICATION_CREDENTIALS}"
diff --git a/scripts/common.sh b/scripts/common.sh
index 735a383de..bc6ba71e8 100755
--- a/scripts/common.sh
+++ b/scripts/common.sh
@@ -89,12 +89,20 @@ function install_runsc() {
# be correct, otherwise this may result in a loop that spins until time out.
function apt_install() {
while true; do
- if (sudo apt-get update && sudo apt-get install -y "$@"); then
- break
- fi
- result=$?
- if [[ $result -ne 100 ]]; then
- return $result
- fi
+ sudo apt-get update &&
+ sudo apt-get install -y "$@" &&
+ true
+ result="${?}"
+ case $result in
+ 0)
+ break
+ ;;
+ 100)
+ # apt-get exits with status 100 on error; fall through and retry.
+ ;;
+ *)
+ exit $result
+ ;;
+ esac
done
}
diff --git a/scripts/runtime_tests.sh b/scripts/runtime_tests.sh
new file mode 100755
index 000000000..350a59f7c
--- /dev/null
+++ b/scripts/runtime_tests.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Check that a runtime is provided via the RUNTIME_TEST_NAME variable.
+if [ ! -v RUNTIME_TEST_NAME ]; then
+  echo "Must set \$RUNTIME_TEST_NAME" >&2
+  exit 1
+fi
+
+install_runsc_for_test runtimes
+test_runsc "//test/runtimes:${RUNTIME_TEST_NAME}_test"
diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD
index fb0b2db41..dfcd55f60 100644
--- a/test/packetdrill/BUILD
+++ b/test/packetdrill/BUILD
@@ -1,4 +1,4 @@
-load("defs.bzl", "packetdrill_linux_test", "packetdrill_netstack_test", "packetdrill_test")
+load("defs.bzl", "packetdrill_test")
package(licenses = ["notice"])
@@ -17,16 +17,6 @@ packetdrill_test(
scripts = ["fin_wait2_timeout.pkt"],
)
-packetdrill_linux_test(
- name = "tcp_user_timeout_test_linux_test",
- scripts = ["linux/tcp_user_timeout.pkt"],
-)
-
-packetdrill_netstack_test(
- name = "tcp_user_timeout_test_netstack_test",
- scripts = ["netstack/tcp_user_timeout.pkt"],
-)
-
packetdrill_test(
name = "listen_close_before_handshake_complete_test",
scripts = ["listen_close_before_handshake_complete.pkt"],
diff --git a/test/packetdrill/linux/tcp_user_timeout.pkt b/test/packetdrill/linux/tcp_user_timeout.pkt
deleted file mode 100644
index 38018cb42..000000000
--- a/test/packetdrill/linux/tcp_user_timeout.pkt
+++ /dev/null
@@ -1,39 +0,0 @@
-// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection
-// if there is pending unacked data after the user specified timeout.
-
-0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
-+0 bind(3, ..., ...) = 0
-
-+0 listen(3, 1) = 0
-
-// Establish a connection without timestamps.
-+0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
-+0 > S. 0:0(0) ack 1 <...>
-+0.1 < . 1:1(0) ack 1 win 32792
-
-+0.100 accept(3, ..., ...) = 4
-
-// Okay, we received nothing, and decide to close this idle socket.
-// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth
-// trying hard to cleanly close this flow, at the price of keeping
-// a TCP structure in kernel for about 1 minute!
-+2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0
-
-// The write/ack is required mainly for netstack as netstack does
-// not update its RTO during the handshake.
-+0 write(4, ..., 100) = 100
-+0 > P. 1:101(100) ack 1 <...>
-+0 < . 1:1(0) ack 101 win 32792
-
-+0 close(4) = 0
-
-+0 > F. 101:101(0) ack 1 <...>
-+.3~+.400 > F. 101:101(0) ack 1 <...>
-+.3~+.400 > F. 101:101(0) ack 1 <...>
-+.6~+.800 > F. 101:101(0) ack 1 <...>
-+1.2~+1.300 > F. 101:101(0) ack 1 <...>
-
-// We finally receive something from the peer, but it is way too late
-// Our socket vanished because TCP_USER_TIMEOUT was really small.
-+.1 < . 1:2(1) ack 102 win 32792
-+0 > R 102:102(0) win 0
diff --git a/test/packetdrill/netstack/tcp_user_timeout.pkt b/test/packetdrill/netstack/tcp_user_timeout.pkt
deleted file mode 100644
index 60103adba..000000000
--- a/test/packetdrill/netstack/tcp_user_timeout.pkt
+++ /dev/null
@@ -1,38 +0,0 @@
-// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection
-// if there is pending unacked data after the user specified timeout.
-
-0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
-+0 bind(3, ..., ...) = 0
-
-+0 listen(3, 1) = 0
-
-// Establish a connection without timestamps.
-+0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
-+0 > S. 0:0(0) ack 1 <...>
-+0.1 < . 1:1(0) ack 1 win 32792
-
-+0.100 accept(3, ..., ...) = 4
-
-// Okay, we received nothing, and decide to close this idle socket.
-// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth
-// trying hard to cleanly close this flow, at the price of keeping
-// a TCP structure in kernel for about 1 minute!
-+2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0
-
-// The write/ack is required mainly for netstack as netstack does
-// not update its RTO during the handshake.
-+0 write(4, ..., 100) = 100
-+0 > P. 1:101(100) ack 1 <...>
-+0 < . 1:1(0) ack 101 win 32792
-
-+0 close(4) = 0
-
-+0 > F. 101:101(0) ack 1 <...>
-+.2~+.300 > F. 101:101(0) ack 1 <...>
-+.4~+.500 > F. 101:101(0) ack 1 <...>
-+.8~+.900 > F. 101:101(0) ack 1 <...>
-
-// We finally receive something from the peer, but it is way too late
-// Our socket vanished because TCP_USER_TIMEOUT was really small.
-+1.61 < . 1:2(1) ack 102 win 32792
-+0 > R 102:102(0) win 0
diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc
index 2f10dda40..86e580c6f 100644
--- a/test/packetimpact/dut/posix_server.cc
+++ b/test/packetimpact/dut/posix_server.cc
@@ -61,13 +61,15 @@
}
class PosixImpl final : public posix_server::Posix::Service {
- ::grpc::Status Socket(grpc_impl::ServerContext *context,
- const ::posix_server::SocketRequest *request,
- ::posix_server::SocketResponse *response) override {
- response->set_fd(
- socket(request->domain(), request->type(), request->protocol()));
+ ::grpc::Status Accept(grpc_impl::ServerContext *context,
+ const ::posix_server::AcceptRequest *request,
+ ::posix_server::AcceptResponse *response) override {
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ response->set_fd(accept(request->sockfd(),
+ reinterpret_cast<sockaddr *>(&addr), &addrlen));
response->set_errno_(errno);
- return ::grpc::Status::OK;
+ return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
}
::grpc::Status Bind(grpc_impl::ServerContext *context,
@@ -119,6 +121,14 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
+ ::grpc::Status Close(grpc_impl::ServerContext *context,
+ const ::posix_server::CloseRequest *request,
+ ::posix_server::CloseResponse *response) override {
+ response->set_ret(close(request->fd()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
::grpc::Status GetSockName(
grpc_impl::ServerContext *context,
const ::posix_server::GetSockNameRequest *request,
@@ -139,15 +149,13 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
- ::grpc::Status Accept(grpc_impl::ServerContext *context,
- const ::posix_server::AcceptRequest *request,
- ::posix_server::AcceptResponse *response) override {
- sockaddr_storage addr;
- socklen_t addrlen = sizeof(addr);
- response->set_fd(accept(request->sockfd(),
- reinterpret_cast<sockaddr *>(&addr), &addrlen));
+ ::grpc::Status Send(::grpc::ServerContext *context,
+ const ::posix_server::SendRequest *request,
+ ::posix_server::SendResponse *response) override {
+ response->set_ret(::send(request->sockfd(), request->buf().data(),
+ request->buf().size(), request->flags()));
response->set_errno_(errno);
- return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
+ return ::grpc::Status::OK;
}
::grpc::Status SetSockOpt(
@@ -161,6 +169,17 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
+ ::grpc::Status SetSockOptInt(
+ ::grpc::ServerContext *context,
+ const ::posix_server::SetSockOptIntRequest *request,
+ ::posix_server::SetSockOptIntResponse *response) override {
+ int opt = request->intval();
+ response->set_ret(::setsockopt(request->sockfd(), request->level(),
+ request->optname(), &opt, sizeof(opt)));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
::grpc::Status SetSockOptTimeval(
::grpc::ServerContext *context,
const ::posix_server::SetSockOptTimevalRequest *request,
@@ -174,11 +193,23 @@ class PosixImpl final : public posix_server::Posix::Service {
return ::grpc::Status::OK;
}
- ::grpc::Status Close(grpc_impl::ServerContext *context,
- const ::posix_server::CloseRequest *request,
- ::posix_server::CloseResponse *response) override {
- response->set_ret(close(request->fd()));
+ ::grpc::Status Socket(grpc_impl::ServerContext *context,
+ const ::posix_server::SocketRequest *request,
+ ::posix_server::SocketResponse *response) override {
+ response->set_fd(
+ socket(request->domain(), request->type(), request->protocol()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ // Recv calls recv() on the DUT with a buffer of request->len() bytes and
+ // reports the return value, errno and the bytes that were read.
+ ::grpc::Status Recv(::grpc::ServerContext *context,
+ const ::posix_server::RecvRequest *request,
+ ::posix_server::RecvResponse *response) override {
+ std::vector<char> buf(request->len());
+ response->set_ret(
+ recv(request->sockfd(), buf.data(), buf.size(), request->flags()));
response->set_errno_(errno);
+ // recv() returns -1 on failure. Only copy bytes out when some were read;
+ // a negative return value must not be used as the buffer length.
+ response->set_buf(buf.data(), response->ret() > 0 ? response->ret() : 0);
return ::grpc::Status::OK;
}
};
diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto
index 026876fc2..4035e1ee6 100644
--- a/test/packetimpact/proto/posix_server.proto
+++ b/test/packetimpact/proto/posix_server.proto
@@ -16,17 +16,6 @@ syntax = "proto3";
package posix_server;
-message SocketRequest {
- int32 domain = 1;
- int32 type = 2;
- int32 protocol = 3;
-}
-
-message SocketResponse {
- int32 fd = 1;
- int32 errno_ = 2;
-}
-
message SockaddrIn {
int32 family = 1;
uint32 port = 2;
@@ -48,6 +37,23 @@ message Sockaddr {
}
}
+message Timeval {
+ int64 seconds = 1;
+ int64 microseconds = 2;
+}
+
+// Request and Response pairs for each Posix service RPC call, sorted.
+
+message AcceptRequest {
+ int32 sockfd = 1;
+}
+
+message AcceptResponse {
+ int32 fd = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ Sockaddr addr = 3;
+}
+
message BindRequest {
int32 sockfd = 1;
Sockaddr addr = 2;
@@ -55,7 +61,16 @@ message BindRequest {
message BindResponse {
int32 ret = 1;
- int32 errno_ = 2;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message CloseRequest {
+ int32 fd = 1;
+}
+
+message CloseResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
}
message GetSockNameRequest {
@@ -64,7 +79,7 @@ message GetSockNameRequest {
message GetSockNameResponse {
int32 ret = 1;
- int32 errno_ = 2;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
Sockaddr addr = 3;
}
@@ -75,17 +90,18 @@ message ListenRequest {
message ListenResponse {
int32 ret = 1;
- int32 errno_ = 2;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
}
-message AcceptRequest {
+message SendRequest {
int32 sockfd = 1;
+ bytes buf = 2;
+ int32 flags = 3;
}
-message AcceptResponse {
- int32 fd = 1;
+message SendResponse {
+ int32 ret = 1;
int32 errno_ = 2;
- Sockaddr addr = 3;
}
message SetSockOptRequest {
@@ -97,12 +113,19 @@ message SetSockOptRequest {
message SetSockOptResponse {
int32 ret = 1;
- int32 errno_ = 2;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
}
-message Timeval {
- int64 seconds = 1;
- int64 microseconds = 2;
+message SetSockOptIntRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ int32 intval = 4;
+}
+
+message SetSockOptIntResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
}
message SetSockOptTimevalRequest {
@@ -114,37 +137,57 @@ message SetSockOptTimevalRequest {
message SetSockOptTimevalResponse {
int32 ret = 1;
- int32 errno_ = 2;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
}
-message CloseRequest {
+message SocketRequest {
+ int32 domain = 1;
+ int32 type = 2;
+ int32 protocol = 3;
+}
+
+message SocketResponse {
int32 fd = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
}
-message CloseResponse {
+message RecvRequest {
+ int32 sockfd = 1;
+ int32 len = 2;
+ int32 flags = 3;
+}
+
+message RecvResponse {
int32 ret = 1;
- int32 errno_ = 2;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ bytes buf = 3;
}
service Posix {
- // Call socket() on the DUT.
- rpc Socket(SocketRequest) returns (SocketResponse);
+ // Call accept() on the DUT.
+ rpc Accept(AcceptRequest) returns (AcceptResponse);
// Call bind() on the DUT.
rpc Bind(BindRequest) returns (BindResponse);
+ // Call close() on the DUT.
+ rpc Close(CloseRequest) returns (CloseResponse);
// Call getsockname() on the DUT.
rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse);
// Call listen() on the DUT.
rpc Listen(ListenRequest) returns (ListenResponse);
- // Call accept() on the DUT.
- rpc Accept(AcceptRequest) returns (AcceptResponse);
+ // Call send() on the DUT.
+ rpc Send(SendRequest) returns (SendResponse);
// Call setsockopt() on the DUT. You should prefer one of the other
// SetSockOpt* functions with a more structured optval or else you may get the
// encoding wrong, such as making a bad assumption about the server's word
// sizes or endianness.
rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse);
+ // Call setsockopt() on the DUT with an int optval.
+ rpc SetSockOptInt(SetSockOptIntRequest) returns (SetSockOptIntResponse);
// Call setsockopt() on the DUT with a Timeval optval.
rpc SetSockOptTimeval(SetSockOptTimevalRequest)
returns (SetSockOptTimevalResponse);
- // Call close() on the DUT.
- rpc Close(CloseRequest) returns (CloseResponse);
+ // Call socket() on the DUT.
+ rpc Socket(SocketRequest) returns (SocketResponse);
+ // Call recv() on the DUT.
+ rpc Recv(RecvRequest) returns (RecvResponse);
}
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index a34c81fcc..b6a254882 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(
default_visibility = ["//test/packetimpact:__subpackages__"],
@@ -16,6 +16,7 @@ go_library(
],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
"//pkg/usermem",
@@ -27,5 +28,14 @@ go_library(
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//keepalive:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
+ "@org_uber_go_multierr//:go_default_library",
],
)
+
+go_test(
+ name = "testbench_test",
+ size = "small",
+ srcs = ["layers_test.go"],
+ library = ":testbench",
+ deps = ["//pkg/tcpip"],
+)
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index b7aa63934..f84fd8ba7 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -21,10 +21,12 @@ import (
"fmt"
"math/rand"
"net"
+ "strings"
"testing"
"time"
"github.com/mohae/deepcopy"
+ "go.uber.org/multierr"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -36,19 +38,6 @@ var remoteIPv4 = flag.String("remote_ipv4", "", "remote IPv4 address for test pa
var localMAC = flag.String("local_mac", "", "local mac address for test packets")
var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets")
-// TCPIPv4 maintains state about a TCP/IPv4 connection.
-type TCPIPv4 struct {
- outgoing Layers
- incoming Layers
- LocalSeqNum seqnum.Value
- RemoteSeqNum seqnum.Value
- SynAck *TCP
- sniffer Sniffer
- injector Injector
- portPickerFD int
- t *testing.T
-}
-
// pickPort makes a new socket and returns the socket FD and port. The caller
// must close the FD when done with the port if there is no error.
func pickPort() (int, uint16, error) {
@@ -75,171 +64,607 @@ func pickPort() (int, uint16, error) {
return fd, uint16(newSockAddrInet4.Port), nil
}
-// tcpLayerIndex is the position of the TCP layer in the TCPIPv4 connection. It
-// is the third, after Ethernet and IPv4.
-const tcpLayerIndex int = 2
+// layerState stores the state of a layer of a connection.
+type layerState interface {
+ // outgoing returns an outgoing layer to be sent in a frame.
+ outgoing() Layer
-// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
-func NewTCPIPv4(t *testing.T, dut DUT, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+ // incoming creates an expected Layer for comparing against a received Layer.
+ // Because the expectation can depend on values in the received Layer, it is
+ // an input to incoming. For example, the ACK number needs to be checked in a
+ // TCP packet but only if the ACK flag is set in the received packet.
+ incoming(received Layer) Layer
+
+ // sent updates the layerState based on the Layer that was sent. The input is
+ // a Layer with all prev and next pointers populated so that the entire frame
+ // as it was sent is available.
+ sent(sent Layer) error
+
+ // received updates the layerState based on a Layer that is received. The
+ // input is a Layer with all prev and next pointers populated so that the
+ // entire frame as it was received is available.
+ received(received Layer) error
+
+ // close frees associated resources held by the LayerState.
+ close() error
+}
+
+// etherState maintains state about an Ethernet connection.
+type etherState struct {
+ out, in Ether
+}
+
+var _ layerState = (*etherState)(nil)
+
+// newEtherState creates a new etherState.
+func newEtherState(out, in Ether) (*etherState, error) {
lMAC, err := tcpip.ParseMACAddress(*localMAC)
if err != nil {
- t.Fatalf("can't parse localMAC %q: %s", *localMAC, err)
+ return nil, err
}
rMAC, err := tcpip.ParseMACAddress(*remoteMAC)
if err != nil {
- t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err)
+ return nil, err
}
-
- portPickerFD, localPort, err := pickPort()
- if err != nil {
- t.Fatalf("can't pick a port: %s", err)
+ s := etherState{
+ out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
+ in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
}
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *etherState) outgoing() Layer {
+ return &s.out
+}
+
+func (s *etherState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*etherState) sent(Layer) error {
+ return nil
+}
+
+func (*etherState) received(Layer) error {
+ return nil
+}
+
+func (*etherState) close() error {
+ return nil
+}
+
+// ipv4State maintains state about an IPv4 connection.
+type ipv4State struct {
+ out, in IPv4
+}
+
+var _ layerState = (*ipv4State)(nil)
+
+// newIPv4State creates a new ipv4State.
+func newIPv4State(out, in IPv4) (*ipv4State, error) {
lIP := tcpip.Address(net.ParseIP(*localIPv4).To4())
rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4())
+ s := ipv4State{
+ out: IPv4{SrcAddr: &lIP, DstAddr: &rIP},
+ in: IPv4{SrcAddr: &rIP, DstAddr: &lIP},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
- sniffer, err := NewSniffer(t)
+func (s *ipv4State) outgoing() Layer {
+ return &s.out
+}
+
+func (s *ipv4State) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*ipv4State) sent(Layer) error {
+ return nil
+}
+
+func (*ipv4State) received(Layer) error {
+ return nil
+}
+
+func (*ipv4State) close() error {
+ return nil
+}
+
+// tcpState maintains state about a TCP connection.
+type tcpState struct {
+ out, in TCP
+ localSeqNum, remoteSeqNum *seqnum.Value
+ synAck *TCP
+ portPickerFD int
+ finSent bool
+}
+
+var _ layerState = (*tcpState)(nil)
+
+// SeqNumValue is a helper routine that allocates a new seqnum.Value value to
+// store v and returns a pointer to it.
+func SeqNumValue(v seqnum.Value) *seqnum.Value {
+ return &v
+}
+
+// newTCPState creates a new tcpState.
+func newTCPState(out, in TCP) (*tcpState, error) {
+ portPickerFD, localPort, err := pickPort()
if err != nil {
- t.Fatalf("can't make new sniffer: %s", err)
+ return nil, err
+ }
+ s := tcpState{
+ out: TCP{SrcPort: &localPort},
+ in: TCP{DstPort: &localPort},
+ localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())),
+ portPickerFD: portPickerFD,
+ finSent: false,
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
}
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
- injector, err := NewInjector(t)
- if err != nil {
- t.Fatalf("can't make new injector: %s", err)
+func (s *tcpState) outgoing() Layer {
+ newOutgoing := deepcopy.Copy(s.out).(TCP)
+ if s.localSeqNum != nil {
+ newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum))
+ }
+ if s.remoteSeqNum != nil {
+ newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum))
+ }
+ return &newOutgoing
+}
+
+func (s *tcpState) incoming(received Layer) Layer {
+ tcpReceived, ok := received.(*TCP)
+ if !ok {
+ return nil
+ }
+ newIn := deepcopy.Copy(s.in).(TCP)
+ if s.remoteSeqNum != nil {
+ newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum))
+ }
+ if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 {
+ // The caller didn't specify an AckNum so we'll expect the calculated one,
+ // but only if the ACK flag is set because the AckNum is not valid in a
+ // header if ACK is not set.
+ newIn.AckNum = Uint32(uint32(*s.localSeqNum))
}
+ return &newIn
+}
- newOutgoingTCP := &TCP{
- DataOffset: Uint8(header.TCPMinimumSize),
- WindowSize: Uint16(32768),
- SrcPort: &localPort,
+// sent updates the tcpState for a TCP layer that was just sent. It advances
+// localSeqNum by the length of every layer after the TCP layer (unless a FIN
+// was already sent), advances it by one more for a SYN or FIN, and records
+// that a FIN has been sent. It returns an error if sent is not a *TCP.
+func (s *tcpState) sent(sent Layer) error {
+ tcp, ok := sent.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", sent)
+ }
+ if !s.finSent {
+ // update localSeqNum by the payload only when FIN is not yet sent by us
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.localSeqNum.UpdateForward(seqnum.Size(current.length()))
+ }
}
- if err := newOutgoingTCP.merge(outgoingTCP); err != nil {
- t.Fatalf("can't merge %v into %v: %s", outgoingTCP, newOutgoingTCP, err)
+ if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.localSeqNum.UpdateForward(1)
}
- newIncomingTCP := &TCP{
- DstPort: &localPort,
+ // Guard against nil Flags here too, matching the SYN/FIN check above.
+ if tcp.Flags != nil && *tcp.Flags&header.TCPFlagFin != 0 {
+ s.finSent = true
}
- if err := newIncomingTCP.merge(incomingTCP); err != nil {
- t.Fatalf("can't merge %v into %v: %s", incomingTCP, newIncomingTCP, err)
+ return nil
+}
+
+func (s *tcpState) received(l Layer) error {
+ tcp, ok := l.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", l)
}
- return TCPIPv4{
- outgoing: Layers{
- &Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
- &IPv4{SrcAddr: &lIP, DstAddr: &rIP},
- newOutgoingTCP},
- incoming: Layers{
- &Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
- &IPv4{SrcAddr: &rIP, DstAddr: &lIP},
- newIncomingTCP},
- sniffer: sniffer,
- injector: injector,
+ s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum))
+ if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.remoteSeqNum.UpdateForward(1)
+ }
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.remoteSeqNum.UpdateForward(seqnum.Size(current.length()))
+ }
+ return nil
+}
+
+// close frees the port associated with this connection.
+func (s *tcpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
+ }
+ s.portPickerFD = -1
+ return nil
+}
+
+// udpState maintains state about a UDP connection.
+type udpState struct {
+ out, in UDP
+ portPickerFD int
+}
+
+var _ layerState = (*udpState)(nil)
+
+// newUDPState creates a new udpState.
+func newUDPState(out, in UDP) (*udpState, error) {
+ portPickerFD, localPort, err := pickPort()
+ if err != nil {
+ return nil, err
+ }
+ s := udpState{
+ out: UDP{SrcPort: &localPort},
+ in: UDP{DstPort: &localPort},
portPickerFD: portPickerFD,
- t: t,
- LocalSeqNum: seqnum.Value(rand.Uint32()),
}
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
}
-// Close the injector and sniffer associated with this connection.
-func (conn *TCPIPv4) Close() {
- conn.sniffer.Close()
- conn.injector.Close()
- if err := unix.Close(conn.portPickerFD); err != nil {
- conn.t.Fatalf("can't close portPickerFD: %s", err)
+func (s *udpState) outgoing() Layer {
+ return &s.out
+}
+
+func (s *udpState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*udpState) sent(l Layer) error {
+ return nil
+}
+
+func (*udpState) received(l Layer) error {
+ return nil
+}
+
+// close frees the port associated with this connection.
+func (s *udpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
}
- conn.portPickerFD = -1
+ s.portPickerFD = -1
+ return nil
}
-// Send a packet with reasonable defaults and override some fields by tcp.
-func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
- if tcp.SeqNum == nil {
- tcp.SeqNum = Uint32(uint32(conn.LocalSeqNum))
+// Connection holds a collection of layer states for maintaining a connection
+// along with sockets for sniffing and injecting packets.
+type Connection struct {
+ layerStates []layerState
+ injector Injector
+ sniffer Sniffer
+ t *testing.T
+}
+
+// match tries to match each Layer in received against the incoming filter. If
+// received is longer than layerStates then that may still count as a match. The
+// reverse is never a match. override overrides the default matchers for each
+// Layer.
+func (conn *Connection) match(override, received Layers) bool {
+ if len(received) < len(conn.layerStates) {
+ return false
+ }
+ for i, s := range conn.layerStates {
+ toMatch := s.incoming(received[i])
+ if toMatch == nil {
+ return false
+ }
+ if i < len(override) {
+ toMatch.merge(override[i])
+ }
+ if !toMatch.match(received[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+// Close frees associated resources held by the Connection.
+func (conn *Connection) Close() {
+ errs := multierr.Combine(conn.sniffer.close(), conn.injector.close())
+ for _, s := range conn.layerStates {
+ if err := s.close(); err != nil {
+ errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err))
+ }
}
- if tcp.AckNum == nil {
- tcp.AckNum = Uint32(uint32(conn.RemoteSeqNum))
+ if errs != nil {
+ conn.t.Fatalf("unable to close %+v: %s", conn, errs)
}
- layersToSend := deepcopy.Copy(conn.outgoing).(Layers)
- if err := layersToSend[tcpLayerIndex].(*TCP).merge(tcp); err != nil {
- conn.t.Fatalf("can't merge %v into %v: %s", tcp, layersToSend[tcpLayerIndex], err)
+}
+
+// CreateFrame builds a frame for the connection with layer overriding defaults
+// of the innermost layer and additionalLayers added after it.
+func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
+ var layersToSend Layers
+ for _, s := range conn.layerStates {
+ layersToSend = append(layersToSend, s.outgoing())
+ }
+ if err := layersToSend[len(layersToSend)-1].merge(layer); err != nil {
+ conn.t.Fatalf("can't merge %+v into %+v: %s", layer, layersToSend[len(layersToSend)-1], err)
}
layersToSend = append(layersToSend, additionalLayers...)
- outBytes, err := layersToSend.toBytes()
+ return layersToSend
+}
+
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *Connection) SendFrame(frame Layers) {
+ outBytes, err := frame.toBytes()
if err != nil {
conn.t.Fatalf("can't build outgoing TCP packet: %s", err)
}
conn.injector.Send(outBytes)
- // Compute the next TCP sequence number.
- for i := tcpLayerIndex + 1; i < len(layersToSend); i++ {
- conn.LocalSeqNum.UpdateForward(seqnum.Size(layersToSend[i].length()))
+ // frame might have nil values where the caller wanted to use default values.
+ // sentFrame will have no nil values in it because it comes from parsing the
+ // bytes that were actually sent.
+ sentFrame := parse(parseEther, outBytes)
+ // Update the state of each layer based on what was sent.
+ for i, s := range conn.layerStates {
+ if err := s.sent(sentFrame[i]); err != nil {
+ conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err)
+ }
}
- if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
- conn.LocalSeqNum.UpdateForward(1)
+}
+
+// Send a packet with reasonable defaults. Potentially override the final layer
+// in the connection with the provided layer and add additionalLayers.
+func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) {
+ conn.SendFrame(conn.CreateFrame(layer, additionalLayers...))
+}
+
+// recvFrame gets the next successfully parsed frame (of type Layers) within the
+// timeout provided. If no parsable frame arrives before the timeout, it returns
+// nil.
+func (conn *Connection) recvFrame(timeout time.Duration) Layers {
+ if timeout <= 0 {
+ return nil
}
+ b := conn.sniffer.Recv(timeout)
+ if b == nil {
+ return nil
+ }
+ return parse(parseEther, b)
}
-// Recv gets a packet from the sniffer within the timeout provided. If no packet
-// arrives before the timeout, it returns nil.
-func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP {
- deadline := time.Now().Add(timeout)
- for {
- timeout = deadline.Sub(time.Now())
- if timeout <= 0 {
- break
- }
- b := conn.sniffer.Recv(timeout)
- if b == nil {
- break
- }
- layers, err := ParseEther(b)
- if err != nil {
- continue // Ignore packets that can't be parsed.
- }
- if !conn.incoming.match(layers) {
- continue // Ignore packets that don't match the expected incoming.
- }
- tcpHeader := (layers[tcpLayerIndex]).(*TCP)
- conn.RemoteSeqNum = seqnum.Value(*tcpHeader.SeqNum)
- if *tcpHeader.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
- conn.RemoteSeqNum.UpdateForward(1)
- }
- for i := tcpLayerIndex + 1; i < len(layers); i++ {
- conn.RemoteSeqNum.UpdateForward(seqnum.Size(layers[i].length()))
- }
- return tcpHeader
+// Expect a frame with the final layerStates layer matching the provided Layer
+// within the timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {
+ // Make a frame that will ignore all but the final layer.
+ layers := make([]Layer, len(conn.layerStates))
+ layers[len(layers)-1] = layer
+
+ gotFrame, err := conn.ExpectFrame(layers, timeout)
+ if err != nil {
+ return nil, err
}
- return nil
+ if len(conn.layerStates)-1 < len(gotFrame) {
+ return gotFrame[len(conn.layerStates)-1], nil
+ }
+ conn.t.Fatal("the received frame should be at least as long as the expected layers")
+ return nil, fmt.Errorf("the received frame should be at least as long as the expected layers")
}
-// Expect a packet that matches the provided tcp within the timeout specified.
-// If it doesn't arrive in time, the test fails.
-func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) *TCP {
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) {
deadline := time.Now().Add(timeout)
+ var allLayers []string
for {
- timeout = deadline.Sub(time.Now())
- if timeout <= 0 {
- return nil
+ var gotLayers Layers
+ if timeout = time.Until(deadline); timeout > 0 {
+ gotLayers = conn.recvFrame(timeout)
}
- gotTCP := conn.Recv(timeout)
- if gotTCP == nil {
- return nil
+ if gotLayers == nil {
+ return nil, fmt.Errorf("got %d packets:\n%s", len(allLayers), strings.Join(allLayers, "\n"))
}
- if tcp.match(gotTCP) {
- return gotTCP
+ if conn.match(layers, gotLayers) {
+ for i, s := range conn.layerStates {
+ if err := s.received(gotLayers[i]); err != nil {
+ conn.t.Fatal(err)
+ }
+ }
+ return gotLayers, nil
}
+ allLayers = append(allLayers, fmt.Sprintf("%s", gotLayers))
+ }
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *Connection) Drain() {
+ conn.sniffer.Drain()
+}
+
+// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
+type TCPIPv4 Connection
+
+// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
+func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ if err != nil {
+ t.Fatalf("can't make ipv4State: %s", err)
+ }
+ tcpState, err := newTCPState(outgoingTCP, incomingTCP)
+ if err != nil {
+ t.Fatalf("can't make tcpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return TCPIPv4{
+ layerStates: []layerState{etherState, ipv4State, tcpState},
+ injector: injector,
+ sniffer: sniffer,
+ t: t,
}
}
-// Handshake performs a TCP 3-way handshake.
+// Handshake performs a TCP 3-way handshake. The input Connection should have a
+// final TCP Layer.
func (conn *TCPIPv4) Handshake() {
// Send the SYN.
conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
// Wait for the SYN-ACK.
- conn.SynAck = conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
- if conn.SynAck == nil {
- conn.t.Fatalf("didn't get synack during handshake")
+ synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if synAck == nil {
+ conn.t.Fatalf("didn't get synack during handshake: %s", err)
}
+ conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
// Send an ACK.
conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
}
+
+// ExpectData is a convenience method that expects a Layer and the Layer after
+// it. If it doesn't arrive in time, it returns nil.
+func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ }
+ return (*Connection)(conn).ExpectFrame(expected, timeout)
+}
+
+// Send a packet with reasonable defaults. Potentially override the TCP layer in
+// the connection with the provided layer and add additionalLayers.
+func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
+ (*Connection)(conn).Send(&tcp, additionalLayers...)
+}
+
+// Close frees associated resources held by the TCPIPv4 connection.
+func (conn *TCPIPv4) Close() {
+ (*Connection)(conn).Close()
+}
+
+// Expect a frame with the TCP layer matching the provided TCP within the
+// timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
+ layer, err := (*Connection)(conn).Expect(&tcp, timeout)
+ if layer == nil {
+ return nil, err
+ }
+ gotTCP, ok := layer.(*TCP)
+ if !ok {
+ conn.t.Fatalf("expected %s to be TCP", layer)
+ }
+ return gotTCP, err
+}
+
+func (conn *TCPIPv4) state() *tcpState {
+ state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState)
+ if !ok {
+ conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates)
+ }
+ return state
+}
+
+// RemoteSeqNum returns the next expected sequence number from the DUT.
+func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value {
+ return conn.state().remoteSeqNum
+}
+
+// LocalSeqNum returns the next sequence number to send from the testbench.
+func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value {
+ return conn.state().localSeqNum
+}
+
+// SynAck returns the SynAck that was part of the handshake.
+func (conn *TCPIPv4) SynAck() *TCP {
+ return conn.state().synAck
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *TCPIPv4) Drain() {
+ conn.sniffer.Drain()
+}
+
+// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
+type UDPIPv4 Connection
+
+// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults.
+// outgoingUDP and incomingUDP are merged over the default outgoing and
+// incoming UDP layers. The test fails via t.Fatalf if any layer state, the
+// injector or the sniffer can't be created.
+func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ if err != nil {
+ t.Fatalf("can't make ipv4State: %s", err)
+ }
+ // The UDP layer state is the final layer of the connection.
+ udpState, err := newUDPState(outgoingUDP, incomingUDP)
+ if err != nil {
+ t.Fatalf("can't make udpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return UDPIPv4{
+ layerStates: []layerState{etherState, ipv4State, udpState},
+ injector: injector,
+ sniffer: sniffer,
+ t: t,
+ }
+}
+
+// CreateFrame builds a frame for the connection with layer overriding defaults
+// of the innermost layer and additionalLayers added after it.
+func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {
+ return (*Connection)(conn).CreateFrame(layer, additionalLayers...)
+}
+
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *UDPIPv4) SendFrame(frame Layers) {
+ (*Connection)(conn).SendFrame(frame)
+}
+
+// Close frees associated resources held by the UDPIPv4 connection.
+func (conn *UDPIPv4) Close() {
+ (*Connection)(conn).Close()
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *UDPIPv4) Drain() {
+ conn.sniffer.Drain()
+}
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index 8ea1706d3..9335909c0 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -65,33 +65,6 @@ func (dut *DUT) TearDown() {
dut.conn.Close()
}
-// SocketWithErrno calls socket on the DUT and returns the fd and errno.
-func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
- dut.t.Helper()
- req := pb.SocketRequest{
- Domain: domain,
- Type: typ,
- Protocol: proto,
- }
- ctx := context.Background()
- resp, err := dut.posixServer.Socket(ctx, &req)
- if err != nil {
- dut.t.Fatalf("failed to call Socket: %s", err)
- }
- return resp.GetFd(), syscall.Errno(resp.GetErrno_())
-}
-
-// Socket calls socket on the DUT and returns the file descriptor. If socket
-// fails on the DUT, the test ends.
-func (dut *DUT) Socket(domain, typ, proto int32) int32 {
- dut.t.Helper()
- fd, err := dut.SocketWithErrno(domain, typ, proto)
- if fd < 0 {
- dut.t.Fatalf("failed to create socket: %s", err)
- }
- return fd
-}
-
func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr {
dut.t.Helper()
switch s := sa.(type) {
@@ -142,14 +115,95 @@ func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr {
return nil
}
+// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol
+// proto, and bound to the IP address addr. The address family is chosen from
+// addr: AF_INET when addr has a 4-byte form, otherwise AF_INET6. Returns the
+// new file descriptor and the port that was selected on the DUT, as reported
+// by getsockname.
+func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) {
+	dut.t.Helper()
+	var fd int32
+	if addr.To4() != nil {
+		fd = dut.Socket(unix.AF_INET, typ, proto)
+		sa := unix.SockaddrInet4{}
+		copy(sa.Addr[:], addr.To4())
+		dut.Bind(fd, &sa)
+	} else if addr.To16() != nil {
+		fd = dut.Socket(unix.AF_INET6, typ, proto)
+		sa := unix.SockaddrInet6{}
+		copy(sa.Addr[:], addr.To16())
+		dut.Bind(fd, &sa)
+	} else {
+		dut.t.Fatal("unknown ip addr type for remoteIP")
+	}
+	sa := dut.GetSockName(fd)
+	var port int
+	switch s := sa.(type) {
+	case *unix.SockaddrInet4:
+		port = s.Port
+	case *unix.SockaddrInet6:
+		port = s.Port
+	default:
+		// %T prints the dynamic type of sa. The boolean verb %t would only
+		// produce a bad-verb marker instead of the type name.
+		dut.t.Fatalf("unknown sockaddr type from getsockname: %T", sa)
+	}
+	return fd, uint16(port)
+}
+
+// CreateListener makes a new socket on the DUT bound to the address in
+// *remoteIPv4 and puts it into the listening state with the given backlog.
+// Returns the file descriptor and the port the socket was bound to. If
+// creating, binding or listening fails, the test ends.
+func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
+	// Mark as a helper, like the other DUT methods, so failures are reported
+	// at the caller's line.
+	dut.t.Helper()
+	fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(*remoteIPv4))
+	dut.Listen(fd, backlog)
+	return fd, remotePort
+}
+
+// All the functions that make gRPC calls to the Posix service are below, sorted
+// alphabetically.
+
+// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// AcceptWithErrno.
+func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ fd, sa, err := dut.AcceptWithErrno(ctx, sockfd)
+ if fd < 0 {
+ dut.t.Fatalf("failed to accept: %s", err)
+ }
+ return fd, sa
+}
+
+// AcceptWithErrno calls accept on the DUT.
+func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
+ dut.t.Helper()
+ req := pb.AcceptRequest{
+ Sockfd: sockfd,
+ }
+ resp, err := dut.posixServer.Accept(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Accept: %s", err)
+ }
+ return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is
+// needed, use BindWithErrno.
+func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.BindWithErrno(ctx, fd, sa)
+ if ret != 0 {
+ dut.t.Fatalf("failed to bind socket: %s", err)
+ }
+}
+
// BindWithErrno calls bind on the DUT.
-func (dut *DUT) BindWithErrno(fd int32, sa unix.Sockaddr) (int32, error) {
+func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
dut.t.Helper()
req := pb.BindRequest{
Sockfd: fd,
Addr: dut.sockaddrToProto(sa),
}
- ctx := context.Background()
resp, err := dut.posixServer.Bind(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call Bind: %s", err)
@@ -157,23 +211,52 @@ func (dut *DUT) BindWithErrno(fd int32, sa unix.Sockaddr) (int32, error) {
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
-// succeed.
-func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
+// Close calls close on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// CloseWithErrno.
+func (dut *DUT) Close(fd int32) {
dut.t.Helper()
- ret, err := dut.BindWithErrno(fd, sa)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.CloseWithErrno(ctx, fd)
if ret != 0 {
- dut.t.Fatalf("failed to bind socket: %s", err)
+ dut.t.Fatalf("failed to close: %s", err)
}
}
+// CloseWithErrno calls close on the DUT.
+func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.CloseRequest{
+ Fd: fd,
+ }
+ resp, err := dut.posixServer.Close(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Close: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// GetSockName calls getsockname on the DUT and causes a fatal test failure if
+// it doesn't succeed. If more control over the timeout or error handling is
+// needed, use GetSockNameWithErrno.
+func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd)
+ if ret != 0 {
+ dut.t.Fatalf("failed to getsockname: %s", err)
+ }
+ return sa
+}
+
// GetSockNameWithErrno calls getsockname on the DUT.
-func (dut *DUT) GetSockNameWithErrno(sockfd int32) (int32, unix.Sockaddr, error) {
+func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
dut.t.Helper()
req := pb.GetSockNameRequest{
Sockfd: sockfd,
}
- ctx := context.Background()
resp, err := dut.posixServer.GetSockName(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call Bind: %s", err)
@@ -181,26 +264,26 @@ func (dut *DUT) GetSockNameWithErrno(sockfd int32) (int32, unix.Sockaddr, error)
return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
}
-// GetSockName calls getsockname on the DUT and causes a fatal test failure if
-// it doens't succeed.
-func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr {
+// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// ListenWithErrno.
+func (dut *DUT) Listen(sockfd, backlog int32) {
dut.t.Helper()
- ret, sa, err := dut.GetSockNameWithErrno(sockfd)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.ListenWithErrno(ctx, sockfd, backlog)
if ret != 0 {
- dut.t.Fatalf("failed to getsockname: %s", err)
+ dut.t.Fatalf("failed to listen: %s", err)
}
- return sa
}
// ListenWithErrno calls listen on the DUT.
-func (dut *DUT) ListenWithErrno(sockfd, backlog int32) (int32, error) {
+func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) {
dut.t.Helper()
req := pb.ListenRequest{
Sockfd: sockfd,
Backlog: backlog,
}
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
- defer cancel()
resp, err := dut.posixServer.Listen(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call Listen: %s", err)
@@ -208,44 +291,54 @@ func (dut *DUT) ListenWithErrno(sockfd, backlog int32) (int32, error) {
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
-// succeed.
-func (dut *DUT) Listen(sockfd, backlog int32) {
+// Send calls send on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// SendWithErrno.
+func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 {
dut.t.Helper()
- ret, err := dut.ListenWithErrno(sockfd, backlog)
- if ret != 0 {
- dut.t.Fatalf("failed to listen: %s", err)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags)
+ if ret == -1 {
+ dut.t.Fatalf("failed to send: %s", err)
}
+ return ret
}
-// AcceptWithErrno calls accept on the DUT.
-func (dut *DUT) AcceptWithErrno(sockfd int32) (int32, unix.Sockaddr, error) {
+// SendWithErrno calls send on the DUT.
+func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) {
dut.t.Helper()
- req := pb.AcceptRequest{
+ req := pb.SendRequest{
Sockfd: sockfd,
+ Buf: buf,
+ Flags: flags,
}
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
- defer cancel()
- resp, err := dut.posixServer.Accept(ctx, &req)
+ resp, err := dut.posixServer.Send(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Accept: %s", err)
+ dut.t.Fatalf("failed to call Send: %s", err)
}
- return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
-// succeed.
-func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
+// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use SetSockOptWithErrno. Because endianness and the width of values
+// might differ between the testbench and DUT architectures, prefer to use a
+// more specific SetSockOptXxx function.
+func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
dut.t.Helper()
- fd, sa, err := dut.AcceptWithErrno(sockfd)
- if fd < 0 {
- dut.t.Fatalf("failed to accept: %s", err)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOpt: %s", err)
}
- return fd, sa
}
-// SetSockOptWithErrno calls setsockopt on the DUT.
-func (dut *DUT) SetSockOptWithErrno(sockfd, level, optname int32, optval []byte) (int32, error) {
+// SetSockOptWithErrno calls setsockopt on the DUT. Because endianness and the
+// width of values might differ between the testbench and DUT architectures,
+// prefer to use a more specific SetSockOptXxxWithErrno function.
+func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) {
dut.t.Helper()
req := pb.SetSockOptRequest{
Sockfd: sockfd,
@@ -253,8 +346,6 @@ func (dut *DUT) SetSockOptWithErrno(sockfd, level, optname int32, optval []byte)
Optname: optname,
Optval: optval,
}
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
- defer cancel()
resp, err := dut.posixServer.SetSockOpt(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call SetSockOpt: %s", err)
@@ -262,19 +353,51 @@ func (dut *DUT) SetSockOptWithErrno(sockfd, level, optname int32, optval []byte)
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
-// doesn't succeed.
-func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
+// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the int optval or error handling
+// is needed, use SetSockOptIntWithErrno.
+func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) {
dut.t.Helper()
- ret, err := dut.SetSockOptWithErrno(sockfd, level, optname, optval)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval)
if ret != 0 {
- dut.t.Fatalf("failed to SetSockOpt: %s", err)
+ dut.t.Fatalf("failed to SetSockOptInt: %s", err)
+ }
+}
+
+// SetSockOptIntWithErrno calls setsockopt with an integer optval.
+func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) {
+ dut.t.Helper()
+ req := pb.SetSockOptIntRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Intval: optval,
+ }
+ resp, err := dut.posixServer.SetSockOptInt(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call SetSockOptInt: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the timeout or error handling is
+// needed, use SetSockOptTimevalWithErrno.
+func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) {
+ dut.t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv)
+ if ret != 0 {
+ dut.t.Fatalf("failed to SetSockOptTimeval: %s", err)
}
}
// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to
// bytes.
-func (dut *DUT) SetSockOptTimevalWithErrno(sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
+func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
dut.t.Helper()
timeval := pb.Timeval{
Seconds: int64(tv.Sec),
@@ -286,8 +409,6 @@ func (dut *DUT) SetSockOptTimevalWithErrno(sockfd, level, optname int32, tv *uni
Optname: optname,
Timeval: &timeval,
}
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
- defer cancel()
resp, err := dut.posixServer.SetSockOptTimeval(ctx, &req)
if err != nil {
dut.t.Fatalf("failed to call SetSockOptTimeval: %s", err)
@@ -295,69 +416,58 @@ func (dut *DUT) SetSockOptTimevalWithErrno(sockfd, level, optname int32, tv *uni
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
-// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure
-// if it doesn't succeed.
-func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) {
+// Socket calls socket on the DUT and returns the file descriptor. If socket
+// fails on the DUT, the test ends.
+func (dut *DUT) Socket(domain, typ, proto int32) int32 {
dut.t.Helper()
- ret, err := dut.SetSockOptTimevalWithErrno(sockfd, level, optname, tv)
- if ret != 0 {
- dut.t.Fatalf("failed to SetSockOptTimeval: %s", err)
+ fd, err := dut.SocketWithErrno(domain, typ, proto)
+ if fd < 0 {
+ dut.t.Fatalf("failed to create socket: %s", err)
}
+ return fd
}
-// CloseWithErrno calls close on the DUT.
-func (dut *DUT) CloseWithErrno(fd int32) (int32, error) {
+// SocketWithErrno calls socket on the DUT and returns the fd and errno.
+func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
dut.t.Helper()
- req := pb.CloseRequest{
- Fd: fd,
+ req := pb.SocketRequest{
+ Domain: domain,
+ Type: typ,
+ Protocol: proto,
}
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
- defer cancel()
- resp, err := dut.posixServer.Close(ctx, &req)
+ ctx := context.Background()
+ resp, err := dut.posixServer.Socket(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Close: %s", err)
+ dut.t.Fatalf("failed to call Socket: %s", err)
}
- return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+ return resp.GetFd(), syscall.Errno(resp.GetErrno_())
}
-// Close calls close on the DUT and causes a fatal test failure if it doesn't
-// succeed.
-func (dut *DUT) Close(fd int32) {
+// Recv calls recv on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// RecvWithErrno.
+func (dut *DUT) Recv(sockfd, len, flags int32) []byte {
dut.t.Helper()
- ret, err := dut.CloseWithErrno(fd)
- if ret != 0 {
- dut.t.Fatalf("failed to close: %s", err)
+ ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ defer cancel()
+ ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags)
+ if ret == -1 {
+ dut.t.Fatalf("failed to recv: %s", err)
}
+ return buf
}
-// CreateListener makes a new TCP connection. If it fails, the test ends.
-func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
+// RecvWithErrno calls recv on the DUT.
+func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) {
dut.t.Helper()
- addr := net.ParseIP(*remoteIPv4)
- var fd int32
- if addr.To4() != nil {
- fd = dut.Socket(unix.AF_INET, typ, proto)
- sa := unix.SockaddrInet4{}
- copy(sa.Addr[:], addr.To4())
- dut.Bind(fd, &sa)
- } else if addr.To16() != nil {
- fd = dut.Socket(unix.AF_INET6, typ, proto)
- sa := unix.SockaddrInet6{}
- copy(sa.Addr[:], addr.To16())
- dut.Bind(fd, &sa)
- } else {
- dut.t.Fatal("unknown ip addr type for remoteIP")
+ req := pb.RecvRequest{
+ Sockfd: sockfd,
+ Len: len,
+ Flags: flags,
}
- sa := dut.GetSockName(fd)
- var port int
- switch s := sa.(type) {
- case *unix.SockaddrInet4:
- port = s.Port
- case *unix.SockaddrInet6:
- port = s.Port
- default:
- dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa)
+ resp, err := dut.posixServer.Recv(ctx, &req)
+ if err != nil {
+ dut.t.Fatalf("failed to call Recv: %s", err)
}
- dut.Listen(fd, backlog)
- return fd, uint16(port)
+ return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_())
}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 35fa4dcb6..5ce324f0d 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -15,13 +15,16 @@
package testbench
import (
+ "encoding/hex"
"fmt"
"reflect"
+ "strings"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/imdario/mergo"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -31,6 +34,8 @@ import (
// Layer contains all the fields of the encapsulation. Each field is a pointer
// and may be nil.
type Layer interface {
+ fmt.Stringer
+
// toBytes converts the Layer into bytes. In places where the Layer's field
// isn't nil, the value that is pointed to is used. When the field is nil, a
// reasonable default for the Layer is used. For example, "64" for IPv4 TTL
@@ -42,7 +47,8 @@ type Layer interface {
// match checks if the current Layer matches the provided Layer. If either
// Layer has a nil in a given field, that field is considered matching.
- // Otherwise, the values pointed to by the fields must match.
+ // Otherwise, the values pointed to by the fields must match. The LayerBase is
+ // ignored.
match(Layer) bool
// length in bytes of the current encapsulation
@@ -59,6 +65,9 @@ type Layer interface {
// setPrev sets the pointer to the Layer encapsulating this one.
setPrev(Layer)
+
+ // merge overrides the values in the interface with the provided values.
+ merge(Layer) error
}
// LayerBase is the common elements of all layers.
@@ -83,21 +92,59 @@ func (lb *LayerBase) setPrev(l Layer) {
lb.prevLayer = l
}
+// equalLayer compares that two Layer structs match while ignoring field in
+// which either input has a nil and also ignoring the LayerBase of the inputs.
func equalLayer(x, y Layer) bool {
+ if x == nil || y == nil {
+ return true
+ }
+ // opt ignores comparison pairs where either of the inputs is a nil.
opt := cmp.FilterValues(func(x, y interface{}) bool {
- if reflect.ValueOf(x).Kind() == reflect.Ptr && reflect.ValueOf(x).IsNil() {
- return true
- }
- if reflect.ValueOf(y).Kind() == reflect.Ptr && reflect.ValueOf(y).IsNil() {
- return true
+ for _, l := range []interface{}{x, y} {
+ v := reflect.ValueOf(l)
+ if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() {
+ return true
+ }
}
return false
-
}, cmp.Ignore())
- return cmp.Equal(x, y, opt, cmpopts.IgnoreUnexported(LayerBase{}))
+ return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{}))
+}
+
+// mergeLayer merges other in layer. Any non-nil value in other overrides the
+// corresponding value in layer. If other is nil, no action is performed.
+func mergeLayer(layer, other Layer) error {
+ if other == nil {
+ return nil
+ }
+ return mergo.Merge(layer, other, mergo.WithOverride)
+}
+
+// stringLayer renders l in the form "&Type{Field:value ...}". Only fields
+// that are non-nil are listed; embedded fields (the LayerBase) are left out.
+// Byte slices are printed as a hex dump on the lines following the field name.
+func stringLayer(l Layer) string {
+	structValue := reflect.ValueOf(l).Elem()
+	structType := structValue.Type()
+	var parts []string
+	for i := 0; i < structValue.NumField(); i++ {
+		field := structType.Field(i)
+		// Embedded fields (the LayerBase) are not part of the printed form.
+		if field.Anonymous {
+			continue
+		}
+		value := structValue.Field(i)
+		if value.IsNil() {
+			continue
+		}
+		value = reflect.Indirect(value)
+		isBytes := value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Uint8
+		if isBytes {
+			parts = append(parts, fmt.Sprintf("%s:\n%v", field.Name, hex.Dump(value.Bytes())))
+			continue
+		}
+		parts = append(parts, fmt.Sprintf("%s:%v", field.Name, value))
+	}
+	return fmt.Sprintf("&%s{%s}", structType, strings.Join(parts, " "))
 }
-// Ether can construct and match the ethernet encapsulation.
+// Ether can construct and match an ethernet encapsulation.
type Ether struct {
LayerBase
SrcAddr *tcpip.LinkAddress
@@ -105,6 +152,10 @@ type Ether struct {
Type *tcpip.NetworkProtocolNumber
}
+func (l *Ether) String() string {
+ return stringLayer(l)
+}
+
func (l *Ether) toBytes() ([]byte, error) {
b := make([]byte, header.EthernetMinimumSize)
h := header.Ethernet(b)
@@ -123,7 +174,7 @@ func (l *Ether) toBytes() ([]byte, error) {
fields.Type = header.IPv4ProtocolNumber
default:
// TODO(b/150301488): Support more protocols, like IPv6.
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", n)
+ return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n)
}
}
h.Encode(fields)
@@ -142,27 +193,46 @@ func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocol
return &v
}
-// ParseEther parses the bytes assuming that they start with an ethernet header
+// layerParser parses the input bytes and returns a Layer along with the next
+// layerParser to run. If there is no more parsing to do, the returned
+// layerParser is nil.
+type layerParser func([]byte) (Layer, layerParser)
+
+// parse runs parser on b and then keeps running whichever layerParser the
+// previous step returned, each time on the bytes left after the layer just
+// parsed, until a step returns a nil layerParser. The collected layers are
+// linked to each other before being returned.
+func parse(parser layerParser, b []byte) Layers {
+	var parsed Layers
+	next, remaining := parser, b
+	for {
+		layer, following := next(remaining)
+		parsed = append(parsed, layer)
+		if following == nil {
+			parsed.linkLayers()
+			return parsed
+		}
+		// Skip past the layer just parsed before handing off to the next parser.
+		next, remaining = following, remaining[layer.length():]
+	}
+}
+
+// parseEther parses the bytes assuming that they start with an ethernet header
// and continues parsing further encapsulations.
-func ParseEther(b []byte) (Layers, error) {
+func parseEther(b []byte) (Layer, layerParser) {
h := header.Ethernet(b)
ether := Ether{
SrcAddr: LinkAddress(h.SourceAddress()),
DstAddr: LinkAddress(h.DestinationAddress()),
Type: NetworkProtocolNumber(h.Type()),
}
- layers := Layers{&ether}
+ var nextParser layerParser
switch h.Type() {
case header.IPv4ProtocolNumber:
- moreLayers, err := ParseIPv4(b[ether.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ nextParser = parseIPv4
default:
- // TODO(b/150301488): Support more protocols, like IPv6.
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %v", b)
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
}
+ return &ether, nextParser
}
func (l *Ether) match(other Layer) bool {
@@ -173,7 +243,13 @@ func (l *Ether) length() int {
return header.EthernetMinimumSize
}
-// IPv4 can construct and match the ethernet excapulation.
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *Ether) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// IPv4 can construct and match an IPv4 encapsulation.
type IPv4 struct {
LayerBase
IHL *uint8
@@ -189,6 +265,10 @@ type IPv4 struct {
DstAddr *tcpip.Address
}
+func (l *IPv4) String() string {
+ return stringLayer(l)
+}
+
func (l *IPv4) toBytes() ([]byte, error) {
b := make([]byte, header.IPv4MinimumSize)
h := header.IPv4(b)
@@ -236,9 +316,11 @@ func (l *IPv4) toBytes() ([]byte, error) {
switch n := l.next().(type) {
case *TCP:
fields.Protocol = uint8(header.TCPProtocolNumber)
+ case *UDP:
+ fields.Protocol = uint8(header.UDPProtocolNumber)
default:
- // TODO(b/150301488): Support more protocols, like UDP.
- return nil, fmt.Errorf("can't deduce the ip header's next protocol: %+v", n)
+ // TODO(b/150301488): Support more protocols as needed.
+ return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
}
}
if l.SrcAddr != nil {
@@ -275,9 +357,9 @@ func Address(v tcpip.Address) *tcpip.Address {
return &v
}
-// ParseIPv4 parses the bytes assuming that they start with an ipv4 header and
+// parseIPv4 parses the bytes assuming that they start with an ipv4 header and
// continues parsing further encapsulations.
-func ParseIPv4(b []byte) (Layers, error) {
+func parseIPv4(b []byte) (Layer, layerParser) {
h := header.IPv4(b)
tos, _ := h.TOS()
ipv4 := IPv4{
@@ -293,16 +375,17 @@ func ParseIPv4(b []byte) (Layers, error) {
SrcAddr: Address(h.SourceAddress()),
DstAddr: Address(h.DestinationAddress()),
}
- layers := Layers{&ipv4}
- switch h.Protocol() {
- case uint8(header.TCPProtocolNumber):
- moreLayers, err := ParseTCP(b[ipv4.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ var nextParser layerParser
+ switch h.TransportProtocol() {
+ case header.TCPProtocolNumber:
+ nextParser = parseTCP
+ case header.UDPProtocolNumber:
+ nextParser = parseUDP
+ default:
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
}
- return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", h.Protocol())
+ return &ipv4, nextParser
}
func (l *IPv4) match(other Layer) bool {
@@ -316,7 +399,13 @@ func (l *IPv4) length() int {
return int(*l.IHL)
}
-// TCP can construct and match the TCP excapulation.
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv4) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// TCP can construct and match a TCP encapsulation.
type TCP struct {
LayerBase
SrcPort *uint16
@@ -330,6 +419,10 @@ type TCP struct {
UrgentPointer *uint16
}
+func (l *TCP) String() string {
+ return stringLayer(l)
+}
+
func (l *TCP) toBytes() ([]byte, error) {
b := make([]byte, header.TCPMinimumSize)
h := header.TCP(b)
@@ -347,12 +440,16 @@ func (l *TCP) toBytes() ([]byte, error) {
}
if l.DataOffset != nil {
h.SetDataOffset(*l.DataOffset)
+ } else {
+ h.SetDataOffset(uint8(l.length()))
}
if l.Flags != nil {
h.SetFlags(*l.Flags)
}
if l.WindowSize != nil {
h.SetWindowSize(*l.WindowSize)
+ } else {
+ h.SetWindowSize(32768)
}
if l.UrgentPointer != nil {
h.SetUrgentPoiner(*l.UrgentPointer)
@@ -361,38 +458,52 @@ func (l *TCP) toBytes() ([]byte, error) {
h.SetChecksum(*l.Checksum)
return h, nil
}
- if err := setChecksum(&h, l); err != nil {
+ if err := setTCPChecksum(&h, l); err != nil {
return nil, err
}
return h, nil
}
-// setChecksum calculates the checksum of the TCP header and sets it in h.
-func setChecksum(h *header.TCP, tcp *TCP) error {
- h.SetChecksum(0)
- tcpLength := uint16(tcp.length())
- current := tcp.next()
- for current != nil {
- tcpLength += uint16(current.length())
- current = current.next()
+// totalLength adds up length() for l and for every layer reachable from it
+// through next().
+func totalLength(l Layer) int {
+	sum := 0
+	for current := l; current != nil; current = current.next() {
+		sum += current.length()
+	}
+	return sum
+}
+// layerChecksum calculates the partial checksum for the Layer l: the
+// pseudo-header checksum taken from the layer before it plus all the bytes of the layers after it.
+func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) {
+	totalLength := uint16(totalLength(l))
 	var xsum uint16
-	switch s := tcp.prev().(type) {
+	switch s := l.prev().(type) {
 	case *IPv4:
-		xsum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, *s.SrcAddr, *s.DstAddr, tcpLength)
+		xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength)
 	default:
 		// TODO(b/150301488): Support more protocols, like IPv6.
-		return fmt.Errorf("can't get src and dst addr from previous layer")
+		return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s)
 	}
-	current = tcp.next()
-	for current != nil {
+	var payloadBytes buffer.VectorisedView
+	for current := l.next(); current != nil; current = current.next() {
 		payload, err := current.toBytes()
 		if err != nil {
-			return fmt.Errorf("can't get bytes for next header: %s", payload)
+			return 0, fmt.Errorf("can't get bytes for next header: %s", err)
 		}
-		xsum = header.Checksum(payload, xsum)
-		current = current.next()
+		payloadBytes.AppendView(payload)
+	}
+	xsum = header.ChecksumVV(payloadBytes, xsum)
+	return xsum, nil
+}
+
+// setTCPChecksum calculates the checksum of the TCP header and sets it in h.
+func setTCPChecksum(h *header.TCP, tcp *TCP) error {
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(tcp, header.TCPProtocolNumber)
+ if err != nil {
+ return err
}
h.SetChecksum(^h.CalculateChecksum(xsum))
return nil
@@ -404,9 +515,9 @@ func Uint32(v uint32) *uint32 {
return &v
}
-// ParseTCP parses the bytes assuming that they start with a tcp header and
+// parseTCP parses the bytes assuming that they start with a tcp header and
// continues parsing further encapsulations.
-func ParseTCP(b []byte) (Layers, error) {
+func parseTCP(b []byte) (Layer, layerParser) {
h := header.TCP(b)
tcp := TCP{
SrcPort: Uint16(h.SourcePort()),
@@ -419,12 +530,7 @@ func ParseTCP(b []byte) (Layers, error) {
Checksum: Uint16(h.Checksum()),
UrgentPointer: Uint16(h.UrgentPointer()),
}
- layers := Layers{&tcp}
- moreLayers, err := ParsePayload(b[tcp.length():])
- if err != nil {
- return nil, err
- }
- return append(layers, moreLayers...), nil
+ return &tcp, parsePayload
}
func (l *TCP) match(other Layer) bool {
@@ -440,8 +546,86 @@ func (l *TCP) length() int {
// merge overrides the values in l with the values from other but only in fields
// where the value is not nil.
-func (l *TCP) merge(other TCP) error {
- return mergo.Merge(l, other, mergo.WithOverride)
+func (l *TCP) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// UDP can construct and match a UDP encapsulation.
+type UDP struct {
+ LayerBase
+ SrcPort *uint16
+ DstPort *uint16
+ Length *uint16
+ Checksum *uint16
+}
+
+func (l *UDP) String() string {
+ return stringLayer(l)
+}
+
+func (l *UDP) toBytes() ([]byte, error) {
+ b := make([]byte, header.UDPMinimumSize)
+ h := header.UDP(b)
+ if l.SrcPort != nil {
+ h.SetSourcePort(*l.SrcPort)
+ }
+ if l.DstPort != nil {
+ h.SetDestinationPort(*l.DstPort)
+ }
+ if l.Length != nil {
+ h.SetLength(*l.Length)
+ } else {
+ h.SetLength(uint16(totalLength(l)))
+ }
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ if err := setUDPChecksum(&h, l); err != nil {
+ return nil, err
+ }
+ return h, nil
+}
+
+// setUDPChecksum calculates the checksum of the UDP header and sets it in h.
+func setUDPChecksum(h *header.UDP, udp *UDP) error {
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(udp, header.UDPProtocolNumber)
+ if err != nil {
+ return err
+ }
+ h.SetChecksum(^h.CalculateChecksum(xsum))
+ return nil
+}
+
+// parseUDP parses the bytes assuming that they start with a udp header and
+// returns the parsed layer and the next parser to use.
+func parseUDP(b []byte) (Layer, layerParser) {
+ h := header.UDP(b)
+ udp := UDP{
+ SrcPort: Uint16(h.SourcePort()),
+ DstPort: Uint16(h.DestinationPort()),
+ Length: Uint16(h.Length()),
+ Checksum: Uint16(h.Checksum()),
+ }
+ return &udp, parsePayload
+}
+
+func (l *UDP) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+// length returns the size of the UDP header only. The Length field of a UDP
+// header counts the header plus its payload (toBytes fills it from
+// totalLength), so it must not be returned here: parse slices off length()
+// bytes before parsing the payload, and totalLength adds the payload layers
+// on top of length().
+func (l *UDP) length() int {
+	return header.UDPMinimumSize
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *UDP) merge(other Layer) error {
+ return mergeLayer(l, other)
}
// Payload has bytes beyond OSI layer 4.
@@ -450,13 +634,17 @@ type Payload struct {
Bytes []byte
}
-// ParsePayload parses the bytes assuming that they start with a payload and
+// String returns the textual form of l as produced by stringLayer.
+func (l *Payload) String() string {
+ return stringLayer(l)
+}
+
+// parsePayload parses the bytes assuming that they start with a payload and
// continue to the end. There can be no further encapsulations.
-func ParsePayload(b []byte) (Layers, error) {
+func parsePayload(b []byte) (Layer, layerParser) {
payload := Payload{
Bytes: b,
}
- return Layers{&payload}, nil
+ return &payload, nil
}
func (l *Payload) toBytes() ([]byte, error) {
@@ -471,18 +659,33 @@ func (l *Payload) length() int {
return len(l.Bytes)
}
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *Payload) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
// Layers is an array of Layer and supports similar functions to Layer.
type Layers []Layer
-func (ls *Layers) toBytes() ([]byte, error) {
+// linkLayers sets the linked-list pointers in ls.
+func (ls *Layers) linkLayers() {
for i, l := range *ls {
if i > 0 {
l.setPrev((*ls)[i-1])
+ } else {
+ l.setPrev(nil)
}
if i+1 < len(*ls) {
l.setNext((*ls)[i+1])
+ } else {
+ l.setNext(nil)
}
}
+}
+
+func (ls *Layers) toBytes() ([]byte, error) {
+ ls.linkLayers()
outBytes := []byte{}
for _, l := range *ls {
layerBytes, err := l.toBytes()
@@ -498,8 +701,8 @@ func (ls *Layers) match(other Layers) bool {
if len(*ls) > len(other) {
return false
}
- for i := 0; i < len(*ls); i++ {
- if !equalLayer((*ls)[i], other[i]) {
+ for i, l := range *ls {
+ if !equalLayer(l, other[i]) {
return false
}
}
diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go
new file mode 100644
index 000000000..b32efda93
--- /dev/null
+++ b/test/packetimpact/testbench/layers_test.go
@@ -0,0 +1,156 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// TestLayerMatch checks Layer.match on pairs of layers. Every pair is checked
+// in both directions, so match is expected to be symmetric for these inputs.
+func TestLayerMatch(t *testing.T) {
+ // A nil *Payload, a Payload with nil Bytes, a Payload with empty but
+ // non-nil Bytes, and a Payload with content.
+ var nilPayload *Payload
+ noPayload := &Payload{}
+ emptyPayload := &Payload{Bytes: []byte{}}
+ fullPayload := &Payload{Bytes: []byte{1, 2, 3}}
+ // Two TCP layers that differ only in the layer linked after them.
+ emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}}
+ fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}}
+ for _, tt := range []struct {
+ a, b Layer
+ want bool
+ }{
+ {nilPayload, nilPayload, true},
+ {nilPayload, noPayload, true},
+ {nilPayload, emptyPayload, true},
+ {nilPayload, fullPayload, true},
+ {noPayload, noPayload, true},
+ {noPayload, emptyPayload, true},
+ {noPayload, fullPayload, true},
+ {emptyPayload, emptyPayload, true},
+ // Empty-but-set Bytes is the only payload here that must not match
+ // a payload with content.
+ {emptyPayload, fullPayload, false},
+ {fullPayload, fullPayload, true},
+ // The differing nextLayer links are expected not to affect matching.
+ {emptyTCP, fullTCP, true},
+ } {
+ if got := tt.a.match(tt.b); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want)
+ }
+ if got := tt.b.match(tt.a); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want)
+ }
+ }
+}
+
+// TestLayerStringFormat checks the exact output of String() for one instance
+// of each layer type (TCP, UDP, IPv4, Ether and Payload). Each case runs as a
+// subtest named after the layer type.
+func TestLayerStringFormat(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ l Layer
+ want string
+ }{
+ {
+ name: "TCP",
+ l: &TCP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ SeqNum: Uint32(3452155723),
+ AckNum: Uint32(2596996163),
+ DataOffset: Uint8(5),
+ Flags: Uint8(20),
+ WindowSize: Uint16(64240),
+ Checksum: Uint16(0x2e2b),
+ },
+ // Values are expected in decimal: 0x2e2b prints as 11819.
+ want: "&testbench.TCP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "SeqNum:3452155723 " +
+ "AckNum:2596996163 " +
+ "DataOffset:5 " +
+ "Flags:20 " +
+ "WindowSize:64240 " +
+ "Checksum:11819" +
+ "}",
+ },
+ {
+ name: "UDP",
+ l: &UDP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ Length: Uint16(12),
+ },
+ // Checksum is left unset and is expected to be absent from the output.
+ want: "&testbench.UDP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "Length:12" +
+ "}",
+ },
+ {
+ name: "IPv4",
+ l: &IPv4{
+ IHL: Uint8(5),
+ TOS: Uint8(0),
+ TotalLength: Uint16(44),
+ ID: Uint16(0),
+ Flags: Uint8(2),
+ FragmentOffset: Uint16(0),
+ TTL: Uint8(64),
+ Protocol: Uint8(6),
+ Checksum: Uint16(0x2e2b),
+ SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})),
+ DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})),
+ },
+ // Fields set to zero (TOS, ID, FragmentOffset) are still printed;
+ // addresses are expected in dotted-decimal form.
+ want: "&testbench.IPv4{" +
+ "IHL:5 " +
+ "TOS:0 " +
+ "TotalLength:44 " +
+ "ID:0 " +
+ "Flags:2 " +
+ "FragmentOffset:0 " +
+ "TTL:64 " +
+ "Protocol:6 " +
+ "Checksum:11819 " +
+ "SrcAddr:197.34.63.10 " +
+ "DstAddr:197.34.63.20" +
+ "}",
+ },
+ {
+ name: "Ether",
+ l: &Ether{
+ SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})),
+ DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})),
+ Type: NetworkProtocolNumber(4),
+ },
+ // Link addresses are expected as colon-separated hex octets.
+ want: "&testbench.Ether{" +
+ "SrcAddr:02:42:c5:22:3f:0a " +
+ "DstAddr:02:42:c5:22:3f:14 " +
+ "Type:4" +
+ "}",
+ },
+ {
+ name: "Payload",
+ l: &Payload{
+ Bytes: []byte("Hooray for packetimpact."),
+ },
+ // Payload bytes are expected as a multi-line hex dump.
+ want: "&testbench.Payload{Bytes:\n" +
+ "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" +
+ "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" +
+ "}",
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.l.String(); got != tt.want {
+ t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
index 0c7d0f979..ff722d4a6 100644
--- a/test/packetimpact/testbench/rawsockets.go
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -17,6 +17,7 @@ package testbench
import (
"encoding/binary"
"flag"
+ "fmt"
"math"
"net"
"testing"
@@ -47,6 +48,12 @@ func NewSniffer(t *testing.T) (Sniffer, error) {
if err != nil {
return Sniffer{}, err
}
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil {
+ t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err)
+ }
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1e7); err != nil {
+ t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err)
+ }
return Sniffer{
t: t,
fd: snifferFd,
@@ -91,12 +98,36 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte {
}
}
-// Close the socket that Sniffer is using.
-func (s *Sniffer) Close() {
+// Drain drains the Sniffer's socket receive buffer by receiving until there's
+// nothing else to receive.
+func (s *Sniffer) Drain() {
+ s.t.Helper()
+ flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0)
+ if err != nil {
+ s.t.Fatalf("failed to get sniffer socket fd flags: %s", err)
+ }
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil {
+ s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err)
+ }
+ for {
+ buf := make([]byte, maxReadSize)
+ _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC)
+ if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK {
+ break
+ }
+ }
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil {
+ s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err)
+ }
+}
+
+// close the socket that Sniffer is using.
+func (s *Sniffer) close() error {
if err := unix.Close(s.fd); err != nil {
- s.t.Fatalf("can't close sniffer socket: %s", err)
+ return fmt.Errorf("can't close sniffer socket: %w", err)
}
s.fd = -1
+ return nil
}
// Injector can inject raw frames.
@@ -142,10 +173,11 @@ func (i *Injector) Send(b []byte) {
}
}
-// Close the underlying socket.
-func (i *Injector) Close() {
+// close closes the injector's underlying socket. On success the stored fd is
+// set to -1; on failure the fd is left unchanged and the wrapped error from
+// unix.Close is returned.
+func (i *Injector) close() error {
if err := unix.Close(i.fd); err != nil {
- i.t.Fatalf("can't close sniffer socket: %s", err)
+ return fmt.Errorf("can't close injector socket: %w", err)
}
i.fd = -1
+ return nil
}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 1dff2a4d5..47c722ccd 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -15,6 +15,87 @@ packetimpact_go_test(
],
)
+packetimpact_go_test(
+ name = "udp_recv_multicast",
+ srcs = ["udp_recv_multicast_test.go"],
+ # TODO(b/152813495): Fix netstack then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_window_shrink",
+ srcs = ["tcp_window_shrink_test.go"],
+ # TODO(b/153202472): Fix netstack then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_outside_the_window",
+ srcs = ["tcp_outside_the_window_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_noaccept_close_rst",
+ srcs = ["tcp_noaccept_close_rst_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_should_piggyback",
+ srcs = ["tcp_should_piggyback_test.go"],
+ # TODO(b/153680566): Fix netstack then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_close_wait_ack",
+ srcs = ["tcp_close_wait_ack_test.go"],
+ # TODO(b/153574037): Fix netstack then remove the line below.
+ netstack = False,
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_user_timeout",
+ srcs = ["tcp_user_timeout_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
sh_binary(
name = "test_runner",
srcs = ["test_runner.sh"],
diff --git a/test/packetimpact/tests/Dockerfile b/test/packetimpact/tests/Dockerfile
index 507030cc7..9075bc555 100644
--- a/test/packetimpact/tests/Dockerfile
+++ b/test/packetimpact/tests/Dockerfile
@@ -1,5 +1,17 @@
FROM ubuntu:bionic
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y iptables netcat tcpdump iproute2 tshark
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+ # iptables to disable OS native packet processing.
+ iptables \
+ # nc to check that the posix_server is running.
+ netcat \
+ # tcpdump to log brief packet sniffing.
+ tcpdump \
+ # ip link show to display MAC addresses.
+ iproute2 \
+ # tshark to log verbose packet sniffing.
+ tshark \
+ # killall for cleanup.
+ psmisc
RUN hash -r
CMD /bin/bash
diff --git a/test/packetimpact/tests/defs.bzl b/test/packetimpact/tests/defs.bzl
index 1b4213d9b..8c0d058b2 100644
--- a/test/packetimpact/tests/defs.bzl
+++ b/test/packetimpact/tests/defs.bzl
@@ -93,7 +93,17 @@ def packetimpact_netstack_test(name, testbench_binary, **kwargs):
**kwargs
)
-def packetimpact_go_test(name, size = "small", pure = True, **kwargs):
+def packetimpact_go_test(name, size = "small", pure = True, linux = True, netstack = True, **kwargs):
+ """Add packetimpact tests written in go.
+
+ Args:
+ name: name of the test
+ size: size of the test
+ pure: make a static go binary
+ linux: generate a linux test
+ netstack: generate a netstack test
+ **kwargs: all the other args, forwarded to go_test
+ """
testbench_binary = name + "_test"
go_test(
name = testbench_binary,
@@ -102,5 +112,7 @@ def packetimpact_go_test(name, size = "small", pure = True, **kwargs):
tags = PACKETIMPACT_TAGS,
**kwargs
)
- packetimpact_linux_test(name = name, testbench_binary = testbench_binary)
- packetimpact_netstack_test(name = name, testbench_binary = testbench_binary)
+ if linux:
+ packetimpact_linux_test(name = name, testbench_binary = testbench_binary)
+ if netstack:
+ packetimpact_netstack_test(name = name, testbench_binary = testbench_binary)
diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go
index 5f54e67ed..b98594f94 100644
--- a/test/packetimpact/tests/fin_wait2_timeout_test.go
+++ b/test/packetimpact/tests/fin_wait2_timeout_test.go
@@ -36,7 +36,7 @@ func TestFinWait2Timeout(t *testing.T) {
defer dut.TearDown()
listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
defer dut.Close(listenFd)
- conn := tb.NewTCPIPv4(t, dut, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
defer conn.Close()
conn.Handshake()
@@ -47,20 +47,22 @@ func TestFinWait2Timeout(t *testing.T) {
}
dut.Close(acceptFd)
- if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); gotOne == nil {
- t.Fatal("expected a FIN-ACK within 1 second but got none")
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err)
}
conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
time.Sleep(5 * time.Second)
+ conn.Drain()
+
conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
if tt.linger2 {
- if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); gotOne == nil {
- t.Fatal("expected a RST packet within a second but got none")
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ t.Fatalf("expected a RST packet within a second but got none: %s", err)
}
} else {
- if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); gotOne != nil {
- t.Fatal("expected no RST packets within ten seconds but got one")
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); err == nil {
+ t.Fatalf("expected no RST packets within ten seconds but got one: %s", err)
}
}
})
diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go
new file mode 100644
index 000000000..eb4cc7a65
--- /dev/null
+++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_close_wait_ack_test
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+// TestCloseWaitAck checks how the DUT responds, after it has ACKed a FIN from
+// the testbench, to a segment built by makeTestingTCP with the given
+// seqNumOffset. Each table entry states whether an ACK is expected in reply.
+// Afterwards the test closes the accepted socket, expects a FIN-ACK, and
+// expects an RST in response to further data.
+func TestCloseWaitAck(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ makeTestingTCP func(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP
+ seqNumOffset seqnum.Size
+ expectAck bool
+ }{
+ {"OTW", GenerateOTWSeqSegment, 0, false},
+ {"OTW", GenerateOTWSeqSegment, 1, true},
+ {"OTW", GenerateOTWSeqSegment, 2, true},
+ {"ACK", GenerateUnaccACKSegment, 0, false},
+ {"ACK", GenerateUnaccACKSegment, 1, true},
+ {"ACK", GenerateUnaccACKSegment, 2, true},
+ } {
+ // Subtest names combine the description and the offset, e.g. "OTW1".
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Handshake()
+ acceptFd, _ := dut.Accept(listenFd)
+
+ // Send a FIN to DUT to initiate the active close
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err)
+ }
+
+ // Send a segment with OTW Seq / unacc ACK and expect an ACK back
+ conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset), &tb.Payload{Bytes: []byte("Sample Data")})
+ gotAck, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second)
+ if tt.expectAck && err != nil {
+ t.Fatalf("expected an ack but got none: %s", err)
+ }
+ if !tt.expectAck && gotAck != nil {
+ t.Fatalf("expected no ack but got one: %s", gotAck)
+ }
+
+ // Now let's verify DUT is indeed in CLOSE_WAIT
+ dut.Close(acceptFd)
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
+ t.Fatalf("expected DUT to send a FIN: %s", err)
+ }
+ // Ack the FIN from DUT
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ // Send some extra data to DUT
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, &tb.Payload{Bytes: []byte("Sample Data")})
+ if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ t.Fatalf("expected DUT to send an RST: %s", err)
+ }
+ })
+ }
+}
+
+// GenerateOTWSeqSegment generates a segment with
+// seqnum = RCV.NXT + RCV.WND - 1 + seqNumOffset, where RCV.NXT is taken from
+// conn.LocalSeqNum() and RCV.WND from the window size in the recorded SYN-ACK.
+// The generated segment is only acceptable when seqNumOffset is 0, otherwise
+// an ACK is expected from the receiver.
+func GenerateOTWSeqSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP {
+	rcvWnd := seqnum.Size(*conn.SynAck().WindowSize)
+	// The last in-window sequence number, shifted by the requested offset.
+	seq := conn.LocalSeqNum().Add(rcvWnd - 1).Add(seqNumOffset)
+	return tb.TCP{
+		SeqNum: tb.Uint32(uint32(seq)),
+		Flags:  tb.Uint8(header.TCPFlagAck),
+	}
+}
+
+// GenerateUnaccACKSegment generates a segment with
+// acknum = SND.NXT + seqNumOffset, where SND.NXT is taken from
+// conn.RemoteSeqNum(). The generated segment is only acceptable when
+// seqNumOffset is 0, otherwise an ACK is expected from the receiver.
+func GenerateUnaccACKSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP {
+	ack := conn.RemoteSeqNum().Add(seqNumOffset)
+	return tb.TCP{
+		AckNum: tb.Uint32(uint32(ack)),
+		Flags:  tb.Uint8(header.TCPFlagAck),
+	}
+}
diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
new file mode 100644
index 000000000..7ebdd1950
--- /dev/null
+++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_noaccept_close_rst_test
+
+import (
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+// TestTcpNoAcceptCloseReset completes a handshake with a listening socket on
+// the DUT, closes the listener without ever accepting the connection, and
+// expects the DUT to send a RST-ACK within one second.
+func TestTcpNoAcceptCloseReset(t *testing.T) {
+	dut := tb.NewDUT(t)
+	defer dut.TearDown()
+	listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+	conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+	// Register the cleanup before the handshake, as the other packetimpact
+	// tests do, so the connection is released even if the handshake aborts
+	// the test.
+	defer conn.Close()
+	conn.Handshake()
+	// Closing the listener is the action under test, so it is not deferred.
+	dut.Close(listenFd)
+	if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
+		t.Fatalf("expected a RST-ACK packet but got none: %s", err)
+	}
+}
diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go
new file mode 100644
index 000000000..db3d3273b
--- /dev/null
+++ b/test/packetimpact/tests/tcp_outside_the_window_test.go
@@ -0,0 +1,88 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_outside_the_window_test
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+// TestTCPOutsideTheWindow tests the behavior of the DUT when packets arrive
+// that are inside or outside the TCP window. Packets that are outside the
+// window should force an extra ACK, as described in RFC793 page 69:
+// https://tools.ietf.org/html/rfc793#page-69
+//
+// Each case sends one segment whose sequence number is the local sequence
+// number plus the SYN-ACK window size plus seqNumOffset, then waits up to three
+// seconds for an ACK whose AckNum equals the local sequence number.
+func TestTCPOutsideTheWindow(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ tcpFlags uint8
+ payload []tb.Layer
+ seqNumOffset seqnum.Size
+ expectACK bool
+ }{
+ {"SYN", header.TCPFlagSyn, nil, 0, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true},
+ {"ACK", header.TCPFlagAck, nil, 0, false},
+ {"FIN", header.TCPFlagFin, nil, 0, false},
+ {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 0, true},
+
+ {"SYN", header.TCPFlagSyn, nil, 1, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true},
+ {"ACK", header.TCPFlagAck, nil, 1, true},
+ {"FIN", header.TCPFlagFin, nil, 1, false},
+ {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 1, true},
+
+ {"SYN", header.TCPFlagSyn, nil, 2, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true},
+ {"ACK", header.TCPFlagAck, nil, 2, true},
+ {"FIN", header.TCPFlagFin, nil, 2, false},
+ {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 2, true},
+ } {
+ // Subtest names combine the description and the offset, e.g. "SYN1".
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+ conn.Handshake()
+ acceptFD, _ := dut.Accept(listenFD)
+ defer dut.Close(acceptFD)
+
+ // Distance from the local sequence number to the segment under test.
+ windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset
+ conn.Drain()
+ // Ignore whatever incrementing that this out-of-order packet might cause
+ // to the AckNum.
+ localSeqNum := tb.Uint32(uint32(*conn.LocalSeqNum()))
+ conn.Send(tb.TCP{
+ Flags: tb.Uint8(tt.tcpFlags),
+ SeqNum: tb.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))),
+ }, tt.payload...)
+ timeout := 3 * time.Second
+ gotACK, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
+ if tt.expectACK && err != nil {
+ t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err)
+ }
+ if !tt.expectACK && gotACK != nil {
+ t.Fatalf("expected no ACK packet within %s but got one: %s", timeout, gotACK)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_should_piggyback_test.go b/test/packetimpact/tests/tcp_should_piggyback_test.go
new file mode 100644
index 000000000..b0be6ba23
--- /dev/null
+++ b/test/packetimpact/tests/tcp_should_piggyback_test.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_should_piggyback_test
+
+import (
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+// TestPiggyback checks that the DUT acknowledges data from the testbench by
+// piggybacking the ACK on its own outstanding data segment rather than sending
+// a separate ACK. The testbench advertises a window of 12 and the DUT socket
+// has TCP_NODELAY set.
+func TestPiggyback(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort, WindowSize: tb.Uint16(12)}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Handshake()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+
+ // First data segment from the DUT; it is expected as an ACK|PSH segment
+ // carrying exactly sampleData.
+ dut.Send(acceptFd, sampleData, 0)
+ expectedTCP := tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}
+ expectedPayload := tb.Payload{Bytes: sampleData}
+ if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err)
+ }
+
+ // Cause DUT to send us more data as soon as we ACK their first data segment because we have
+ // a small window.
+ dut.Send(acceptFd, sampleData, 0)
+
+ // DUT should ACK our segment by piggybacking ACK to their outstanding data segment instead of
+ // sending a separate ACK packet.
+ conn.Send(expectedTCP, &expectedPayload)
+ if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go
new file mode 100644
index 000000000..3cf82badb
--- /dev/null
+++ b/test/packetimpact/tests/tcp_user_timeout_test.go
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_user_timeout_test
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+// sendPayload makes the DUT send 100 bytes (values 0 through 99) on fd and
+// waits up to one second for the testbench to receive them in an ACK|PSH
+// segment. It returns a wrapped error if that segment does not arrive.
+func sendPayload(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error {
+ sampleData := make([]byte, 100)
+ for i := range sampleData {
+ sampleData[i] = uint8(i)
+ }
+ // Discard anything already queued so that ExpectData below only sees
+ // packets that arrive after the send.
+ conn.Drain()
+ dut.Send(fd, sampleData, 0)
+ if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &tb.Payload{Bytes: sampleData}, time.Second); err != nil {
+ return fmt.Errorf("expected data but got none: %w", err)
+ }
+ return nil
+}
+
+// sendFIN closes fd on the DUT and always returns nil. conn is unused; the
+// signature matches sendPayload so both can be used as the f field in
+// TestTCPUserTimeout's table.
+func sendFIN(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error {
+ dut.Close(fd)
+ return nil
+}
+
+// TestTCPUserTimeout checks the DUT's handling of TCP_USER_TIMEOUT. For every
+// combination of timeout setting and DUT action (send data or close), the
+// testbench stays silent for sendDelay, then sends an ACK and checks whether
+// the DUT answers with a RST.
+func TestTCPUserTimeout(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ userTimeout time.Duration
+ sendDelay time.Duration
+ }{
+ {"NoUserTimeout", 0, 3 * time.Second},
+ {"ACKBeforeUserTimeout", 5 * time.Second, 4 * time.Second},
+ {"ACKAfterUserTimeout", 5 * time.Second, 7 * time.Second},
+ } {
+ for _, ttf := range []struct {
+ description string
+ f func(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error
+ }{
+ {"AfterPayload", sendPayload},
+ {"AfterFIN", sendFIN},
+ } {
+ t.Run(tt.description+ttf.description, func(t *testing.T) {
+ // Create a socket, listen, TCP handshake, and accept.
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+ conn.Handshake()
+ acceptFD, _ := dut.Accept(listenFD)
+
+ // A zero userTimeout leaves the socket option untouched; otherwise
+ // the value is passed in milliseconds.
+ if tt.userTimeout != 0 {
+ dut.SetSockOptInt(acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds()))
+ }
+
+ if err := ttf.f(&conn, &dut, acceptFD); err != nil {
+ t.Fatal(err)
+ }
+
+ // Stay silent for sendDelay, discard whatever arrived in the
+ // meantime, then send an ACK.
+ time.Sleep(tt.sendDelay)
+ conn.Drain()
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+
+ // If TCP_USER_TIMEOUT was set and the above delay was longer than the
+ // TCP_USER_TIMEOUT then the DUT should send a RST in response to the
+ // testbench's packet.
+ expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout
+ expectTimeout := 5 * time.Second
+ got, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, expectTimeout)
+ if expectRST && err != nil {
+ t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err)
+ }
+ if !expectRST && got != nil {
+ t.Errorf("expected no RST packet within %s but got one: %s", expectTimeout, got)
+ }
+ })
+ }
+ }
+}
diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go
new file mode 100644
index 000000000..c9354074e
--- /dev/null
+++ b/test/packetimpact/tests/tcp_window_shrink_test.go
@@ -0,0 +1,68 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_window_shrink_test
+
+import (
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+// TestWindowShrink checks what the DUT sends after the testbench advertises a
+// zero receive window: with more data queued on the DUT, the testbench expects
+// a segment whose sequence number is one less than conn.RemoteSeqNum().
+func TestWindowShrink(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFd)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ conn.Handshake()
+ acceptFd, _ := dut.Accept(listenFd)
+ defer dut.Close(acceptFd)
+
+ dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &tb.Payload{Bytes: sampleData}
+
+ // First segment: receive it and acknowledge it.
+ dut.Send(acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
+ }
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+
+ // Two more segments, each expected to arrive separately.
+ dut.Send(acceptFd, sampleData, 0)
+ dut.Send(acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
+ }
+ if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
+ }
+ // We close our receiving window here
+ conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)})
+
+ dut.Send(acceptFd, []byte("Sample Data"), 0)
+ // Note: There is another kind of zero-window probing which Windows uses (by sending one
+ // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change
+ // the following lines.
+ expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1
+ if _, err := conn.ExpectData(&tb.TCP{SeqNum: tb.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil {
+ t.Fatalf("expected a packet with sequence number %v: %s", expectedRemoteSeqNum, err)
+ }
+}
diff --git a/test/packetimpact/tests/test_runner.sh b/test/packetimpact/tests/test_runner.sh
index 5281cb53d..e99fc7d09 100755
--- a/test/packetimpact/tests/test_runner.sh
+++ b/test/packetimpact/tests/test_runner.sh
@@ -29,13 +29,15 @@ function failure() {
}
trap 'failure ${LINENO} "$BASH_COMMAND"' ERR
-declare -r LONGOPTS="dut_platform:,posix_server_binary:,testbench_binary:,runtime:,tshark"
+declare -r LONGOPTS="dut_platform:,posix_server_binary:,testbench_binary:,runtime:,tshark,extra_test_arg:"
# Don't use declare below so that the error from getopt will end the script.
PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@")
eval set -- "$PARSED"
+declare -a EXTRA_TEST_ARGS
+
while true; do
case "$1" in
--dut_platform)
@@ -62,6 +64,10 @@ while true; do
declare -r TSHARK="1"
shift 1
;;
+ --extra_test_arg)
+ # `+=` with a plain string on an array name appends to element 0, so all
+ # values accumulate in one string that is later spliced into the
+ # `bash -c` command line. Insert a space between successive values;
+ # otherwise a second --extra_test_arg is glued onto the first.
+ EXTRA_TEST_ARGS+="${EXTRA_TEST_ARGS:+ }$2"
+ shift 2
+ ;;
--)
shift
break
@@ -125,6 +131,19 @@ docker --version
function finish {
local cleanup_success=1
+
+ if [[ -z "${TSHARK-}" ]]; then
+ # Kill tcpdump so that it will flush output.
+ docker exec -t "${TESTBENCH}" \
+ killall tcpdump || \
+ cleanup_success=0
+ else
+ # Kill tshark so that it will flush output.
+ docker exec -t "${TESTBENCH}" \
+ killall tshark || \
+ cleanup_success=0
+ fi
+
for net in "${CTRL_NET}" "${TEST_NET}"; do
# Kill all processes attached to ${net}.
for docker_command in "kill" "rm"; do
@@ -224,6 +243,8 @@ else
# interface with the test packets.
docker exec -t "${TESTBENCH}" \
tshark -V -l -n -i "${TEST_DEVICE}" \
+ -o tcp.check_checksum:TRUE \
+ -o udp.check_checksum:TRUE \
host "${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" &
fi
@@ -235,6 +256,7 @@ sleep 3
# be executed on the DUT.
docker exec -t "${TESTBENCH}" \
/bin/bash -c "${DOCKER_TESTBENCH_BINARY} \
+ ${EXTRA_TEST_ARGS[@]-} \
--posix_server_ip=${CTRL_NET_PREFIX}${DUT_NET_SUFFIX} \
--posix_server_port=${CTRL_PORT} \
--remote_ipv4=${TEST_NET_PREFIX}${DUT_NET_SUFFIX} \
diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go
new file mode 100644
index 000000000..61fd17050
--- /dev/null
+++ b/test/packetimpact/tests/udp_recv_multicast_test.go
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_recv_multicast_test
+
+import (
+ "net"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+// TestUDPRecvMulticast sends a UDP datagram whose IPv4 destination is the
+// multicast address 224.0.0.1 towards a DUT socket bound to 0.0.0.0, and then
+// calls Recv on that socket.
+func TestUDPRecvMulticast(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+
+ // Bind a UDP socket on the DUT to the wildcard address.
+ wildcard := net.ParseIP("0.0.0.0")
+ boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, wildcard)
+ defer dut.Close(boundFD)
+
+ conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ // Build a regular frame for this connection, then rewrite the destination
+ // of its IPv4 layer (frame[1]) to the multicast address before sending.
+ payload := &tb.Payload{Bytes: []byte("hello world")}
+ frame := conn.CreateFrame(&tb.UDP{}, payload)
+ multicast := tcpip.Address(net.ParseIP("224.0.0.1").To4())
+ ipv4 := frame[1].(*tb.IPv4)
+ ipv4.DstAddr = tb.Address(multicast)
+ conn.SendFrame(frame)
+
+ dut.Recv(boundFD, 100, 0)
+}
diff --git a/test/perf/BUILD b/test/perf/BUILD
index 0a0def6a3..471d8c2ab 100644
--- a/test/perf/BUILD
+++ b/test/perf/BUILD
@@ -30,6 +30,7 @@ syscall_test(
syscall_test(
size = "enormous",
+ shard_count = 10,
tags = ["nogotsan"],
test = "//test/perf/linux:getdents_benchmark",
)
diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc
index afc599ad2..d8e81fa8c 100644
--- a/test/perf/linux/getdents_benchmark.cc
+++ b/test/perf/linux/getdents_benchmark.cc
@@ -38,7 +38,7 @@ namespace testing {
namespace {
-constexpr int kBufferSize = 16384;
+constexpr int kBufferSize = 65536;
PosixErrorOr<TempPath> CreateDirectory(int count,
std::vector<std::string>* files) {
diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go
index 4038661cb..679342def 100644
--- a/test/root/cgroup_test.go
+++ b/test/root/cgroup_test.go
@@ -53,7 +53,7 @@ func verifyPid(pid int, path string) error {
if scanner.Err() != nil {
return scanner.Err()
}
- return fmt.Errorf("got: %s, want: %d", gots, pid)
+ return fmt.Errorf("got: %v, want: %d", gots, pid)
}
// TestCgroup sets cgroup options and checks that cgroup was properly configured.
@@ -106,7 +106,7 @@ func TestMemCGroup(t *testing.T) {
time.Sleep(100 * time.Millisecond)
}
- t.Fatalf("%vMB is less than %vMB: %v", memUsage>>20, allocMemSize>>20)
+ t.Fatalf("%vMB is less than %vMB", memUsage>>20, allocMemSize>>20)
}
// TestCgroup sets cgroup options and checks that cgroup was properly configured.
diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go
index 126f0975a..22488b05d 100644
--- a/test/root/oom_score_adj_test.go
+++ b/test/root/oom_score_adj_test.go
@@ -46,7 +46,7 @@ func TestOOMScoreAdjSingle(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
ppid, err := specutils.GetParentPid(os.Getpid())
@@ -137,7 +137,7 @@ func TestOOMScoreAdjMulti(t *testing.T) {
}
defer os.RemoveAll(rootDir)
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
ppid, err := specutils.GetParentPid(os.Getpid())
diff --git a/test/runtimes/blacklist_test.go b/test/runtimes/blacklist_test.go
index 52f49b984..0ff69ab18 100644
--- a/test/runtimes/blacklist_test.go
+++ b/test/runtimes/blacklist_test.go
@@ -32,6 +32,6 @@ func TestBlacklists(t *testing.T) {
t.Fatalf("error parsing blacklist: %v", err)
}
if *blacklistFile != "" && len(bl) == 0 {
- t.Errorf("got empty blacklist for file %q", blacklistFile)
+ t.Errorf("got empty blacklist for file %q", *blacklistFile)
}
}
diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go
index ddb890dbc..3c98f4570 100644
--- a/test/runtimes/runner.go
+++ b/test/runtimes/runner.go
@@ -114,7 +114,7 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int
F: func(t *testing.T) {
// Is the test blacklisted?
if _, ok := blacklist[tc]; ok {
- t.Skip("SKIP: blacklisted test %q", tc)
+ t.Skipf("SKIP: blacklisted test %q", tc)
}
var (
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index d0c431234..d9095c95f 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -138,7 +138,6 @@ cc_library(
hdrs = ["socket_netlink_route_util.h"],
deps = [
":socket_netlink_util",
- "@com_google_absl//absl/types:optional",
],
)
@@ -663,10 +662,7 @@ cc_binary(
cc_binary(
name = "exec_binary_test",
testonly = 1,
- srcs = select_arch(
- amd64 = ["exec_binary.cc"],
- arm64 = [],
- ),
+ srcs = ["exec_binary.cc"],
linkstatic = 1,
deps = [
"//test/util:cleanup",
@@ -2026,6 +2022,8 @@ cc_binary(
"//test/util:file_descriptor",
"@com_google_absl//absl/strings",
gtest,
+ ":ip_socket_test_util",
+ ":unix_domain_socket_test_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
@@ -2802,13 +2800,13 @@ cc_binary(
srcs = ["socket_netlink_route.cc"],
linkstatic = 1,
deps = [
+ ":socket_netlink_route_util",
":socket_netlink_util",
":socket_test_util",
"//test/util:capability_util",
"//test/util:cleanup",
"//test/util:file_descriptor",
"@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:optional",
gtest,
"//test/util:test_main",
"//test/util:test_util",
diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc
index a33daff17..806d5729e 100644
--- a/test/syscalls/linux/aio.cc
+++ b/test/syscalls/linux/aio.cc
@@ -89,6 +89,7 @@ class AIOTest : public FileTest {
FileTest::TearDown();
if (ctx_ != 0) {
ASSERT_THAT(DestroyContext(), SyscallSucceeds());
+ ctx_ = 0;
}
}
@@ -188,14 +189,19 @@ TEST_F(AIOTest, BadWrite) {
}
TEST_F(AIOTest, ExitWithPendingIo) {
- // Setup a context that is 5 entries deep.
- ASSERT_THAT(SetupContext(5), SyscallSucceeds());
+ // Setup a context that is 100 entries deep.
+ ASSERT_THAT(SetupContext(100), SyscallSucceeds());
struct iocb cb = CreateCallback();
struct iocb* cbs[] = {&cb};
// Submit a request but don't complete it to make it pending.
- EXPECT_THAT(Submit(1, cbs), SyscallSucceeds());
+ for (int i = 0; i < 100; ++i) {
+ EXPECT_THAT(Submit(1, cbs), SyscallSucceeds());
+ }
+
+ ASSERT_THAT(DestroyContext(), SyscallSucceeds());
+ ctx_ = 0;
}
int Submitter(void* arg) {
diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc
index a4f8f3cec..f57d38dc7 100644
--- a/test/syscalls/linux/epoll.cc
+++ b/test/syscalls/linux/epoll.cc
@@ -56,10 +56,6 @@ TEST(EpollTest, AllWritable) {
struct epoll_event result[kFDsPerEpoll];
ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
SyscallSucceedsWithValue(kFDsPerEpoll));
- // TODO(edahlgren): Why do some tests check epoll_event::data, and others
- // don't? Does Linux actually guarantee that, in any of these test cases,
- // epoll_wait will necessarily write out the epoll_events in the order that
- // they were registered?
for (int i = 0; i < kFDsPerEpoll; i++) {
ASSERT_EQ(result[i].events, EPOLLOUT);
}
diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc
index 07bd527e6..12c9b05ca 100644
--- a/test/syscalls/linux/exec.cc
+++ b/test/syscalls/linux/exec.cc
@@ -812,26 +812,28 @@ void ExecFromThread() {
bool ValidateProcCmdlineVsArgv(const int argc, const char* const* argv) {
auto contents_or = GetContents("/proc/self/cmdline");
if (!contents_or.ok()) {
- std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error();
+ std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error()
+ << std::endl;
return false;
}
auto contents = contents_or.ValueOrDie();
if (contents.back() != '\0') {
- std::cerr << "Non-null terminated /proc/self/cmdline!";
+ std::cerr << "Non-null terminated /proc/self/cmdline!" << std::endl;
return false;
}
contents.pop_back();
std::vector<std::string> procfs_cmdline = absl::StrSplit(contents, '\0');
if (static_cast<int>(procfs_cmdline.size()) != argc) {
- std::cerr << "argc = " << argc << " != " << procfs_cmdline.size();
+ std::cerr << "argc = " << argc << " != " << procfs_cmdline.size()
+ << std::endl;
return false;
}
for (int i = 0; i < argc; ++i) {
if (procfs_cmdline[i] != argv[i]) {
std::cerr << "Procfs command line argument " << i << " mismatch "
- << procfs_cmdline[i] << " != " << argv[i];
+ << procfs_cmdline[i] << " != " << argv[i] << std::endl;
return false;
}
}
diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc
index 736452b0c..1a9f203b9 100644
--- a/test/syscalls/linux/exec_binary.cc
+++ b/test/syscalls/linux/exec_binary.cc
@@ -48,10 +48,17 @@ namespace {
using ::testing::AnyOf;
using ::testing::Eq;
-#ifndef __x86_64__
+#if !defined(__x86_64__) && !defined(__aarch64__)
// The assembly stub and ELF internal details must be ported to other arches.
-#error "Test only supported on x86-64"
-#endif // __x86_64__
+#error "Test only supported on x86-64/arm64"
+#endif // __x86_64__ || __aarch64__
+
+#if defined(__x86_64__)
+#define EM_TYPE EM_X86_64
+#define IP_REG(p) ((p).rip)
+#define RAX_REG(p) ((p).rax)
+#define RDI_REG(p) ((p).rdi)
+#define RETURN_REG(p) ((p).rax)
// amd64 stub that calls PTRACE_TRACEME and sends itself SIGSTOP.
const char kPtraceCode[] = {
@@ -139,6 +146,76 @@ const char kPtraceCode[] = {
// Size of a syscall instruction.
constexpr int kSyscallSize = 2;
+#elif defined(__aarch64__)
+#define EM_TYPE EM_AARCH64
+#define IP_REG(p) ((p).pc)
+#define RAX_REG(p) ((p).regs[8])
+#define RDI_REG(p) ((p).regs[0])
+#define RETURN_REG(p) ((p).regs[0])
+
+// arm64 stub that calls PTRACE_TRACEME and sends itself SIGSTOP.
+//
+// Each instruction below is spelled out as four bytes. The sequence is:
+// ptrace(PTRACE_TRACEME, 0, 0, 0), getpid(), kill(<getpid result>, SIGSTOP).
+const char kPtraceCode[] = {
+ // MOVD $117, R8 /* ptrace */
+ '\xa8',
+ '\x0e',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R0 /* PTRACE_TRACEME */
+ '\x00',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R1 /* pid */
+ '\x01',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R2 /* addr */
+ '\x02',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R3 /* data */
+ '\x03',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+ // MOVD $172, R8 /* getpid */
+ '\x88',
+ '\x15',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+ // MOVD $129, R8 /* kill, R0=pid */
+ // R0 is not reloaded here: it still holds the value left by getpid above.
+ '\x28',
+ '\x10',
+ '\x80',
+ '\xd2',
+ // MOVD $19, R1 /* SIGSTOP */
+ '\x61',
+ '\x02',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+};
+// Size of a syscall instruction.
+constexpr int kSyscallSize = 4;
+#else
+#error "Unknown architecture"
+#endif
+
// This test suite tests executable loading in the kernel (ELF and interpreter
// scripts).
@@ -281,7 +358,7 @@ ElfBinary<64> StandardElf() {
elf.header.e_ident[EI_DATA] = ELFDATA2LSB;
elf.header.e_ident[EI_VERSION] = EV_CURRENT;
elf.header.e_type = ET_EXEC;
- elf.header.e_machine = EM_X86_64;
+ elf.header.e_machine = EM_TYPE;
elf.header.e_version = EV_CURRENT;
elf.header.e_phoff = sizeof(elf.header);
elf.header.e_phentsize = sizeof(decltype(elf)::ElfPhdr);
@@ -327,9 +404,15 @@ TEST(ElfTest, Execute) {
ASSERT_NO_ERRNO(WaitStopped(child));
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
- // RIP is just beyond the final syscall instruction.
- EXPECT_EQ(regs.rip, elf.header.e_entry + sizeof(kPtraceCode));
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
+ // RIP/PC is just beyond the final syscall instruction.
+ EXPECT_EQ(IP_REG(regs), elf.header.e_entry + sizeof(kPtraceCode));
EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
{0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
@@ -718,9 +801,16 @@ TEST(ElfTest, PIE) {
// RIP tells us which page the first segment was loaded into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1);
EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
// text page.
@@ -787,9 +877,15 @@ TEST(ElfTest, PIENonZeroStart) {
// RIP tells us which page the first segment was loaded into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1);
// The ELF is loaded at an arbitrary address, not the first PT_LOAD vaddr.
//
@@ -910,9 +1006,15 @@ TEST(ElfTest, ELFInterpreter) {
// RIP tells us which page the first segment of the interpreter was loaded
// into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1);
EXPECT_THAT(
child, ContainsMappings(std::vector<ProcMapsEntry>({
@@ -1084,9 +1186,15 @@ TEST(ElfTest, ELFInterpreterRelative) {
// RIP tells us which page the first segment of the interpreter was loaded
// into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1);
EXPECT_THAT(
child, ContainsMappings(std::vector<ProcMapsEntry>({
@@ -1480,14 +1588,21 @@ TEST(ExecveTest, BrkAfterBinary) {
ASSERT_NO_ERRNO(WaitStopped(child));
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
// RIP is just beyond the final syscall instruction. Rewind to execute a brk
// syscall.
- regs.rip -= kSyscallSize;
- regs.rax = __NR_brk;
- regs.rdi = 0;
- ASSERT_THAT(ptrace(PTRACE_SETREGS, child, 0, &regs), SyscallSucceeds());
+ IP_REG(regs) -= kSyscallSize;
+ RAX_REG(regs) = __NR_brk;
+ RDI_REG(regs) = 0;
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
// Resume the child, waiting for syscall entry.
ASSERT_THAT(ptrace(PTRACE_SYSCALL, child, 0, 0), SyscallSucceeds());
@@ -1504,7 +1619,12 @@ TEST(ExecveTest, BrkAfterBinary) {
ASSERT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP)
<< "status = " << status;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
// brk is after the text page.
//
@@ -1512,7 +1632,7 @@ TEST(ExecveTest, BrkAfterBinary) {
// address will be, but it is always beyond the final page in the binary.
// i.e., it does not start immediately after memsz in the middle of a page.
// Userspace may expect to use that space.
- EXPECT_GE(regs.rax, 0x41000);
+ EXPECT_GE(RETURN_REG(regs), 0x41000);
}
} // namespace
diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h
index 6f80bc97c..fb418e052 100644
--- a/test/syscalls/linux/file_base.h
+++ b/test/syscalls/linux/file_base.h
@@ -52,17 +52,6 @@ class FileTest : public ::testing::Test {
test_file_fd_ = ASSERT_NO_ERRNO_AND_VALUE(
Open(test_file_name_, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR));
- // FIXME(edahlgren): enable when mknod syscall is supported.
- // test_fifo_name_ = NewTempAbsPath();
- // ASSERT_THAT(mknod(test_fifo_name_.c_str()), S_IFIFO|0644, 0,
- // SyscallSucceeds());
- // ASSERT_THAT(test_fifo_[1] = open(test_fifo_name_.c_str(),
- // O_WRONLY),
- // SyscallSucceeds());
- // ASSERT_THAT(test_fifo_[0] = open(test_fifo_name_.c_str(),
- // O_RDONLY),
- // SyscallSucceeds());
-
ASSERT_THAT(pipe(test_pipe_), SyscallSucceeds());
ASSERT_THAT(fcntl(test_pipe_[0], F_SETFL, O_NONBLOCK), SyscallSucceeds());
}
@@ -96,18 +85,12 @@ class FileTest : public ::testing::Test {
CloseFile();
UnlinkFile();
ClosePipes();
-
- // FIXME(edahlgren): enable when mknod syscall is supported.
- // close(test_fifo_[0]);
- // close(test_fifo_[1]);
- // unlink(test_fifo_name_.c_str());
}
+ protected:
std::string test_file_name_;
- std::string test_fifo_name_;
FileDescriptor test_file_fd_;
- int test_fifo_[2];
int test_pipe_[2];
};
diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc
index ff8bdfeb0..853f6231a 100644
--- a/test/syscalls/linux/fork.cc
+++ b/test/syscalls/linux/fork.cc
@@ -431,7 +431,6 @@ TEST(CloneTest, NewUserNamespacePermitsAllOtherNamespaces) {
<< "status = " << status;
}
-#ifdef __x86_64__
// Clone with CLONE_SETTLS and a non-canonical TLS address is rejected.
TEST(CloneTest, NonCanonicalTLS) {
constexpr uintptr_t kNonCanonical = 1ull << 48;
@@ -440,11 +439,25 @@ TEST(CloneTest, NonCanonicalTLS) {
// on this.
char stack;
+ // The raw system call interface on x86-64 is:
+ // long clone(unsigned long flags, void *stack,
+ // int *parent_tid, int *child_tid,
+ // unsigned long tls);
+ //
+ // While on arm64, the order of the last two arguments is reversed:
+ // long clone(unsigned long flags, void *stack,
+ // int *parent_tid, unsigned long tls,
+ // int *child_tid);
+#if defined(__x86_64__)
EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr,
nullptr, kNonCanonical),
SyscallFailsWithErrno(EPERM));
-}
+#elif defined(__aarch64__)
+ EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr,
+ kNonCanonical, nullptr),
+ SyscallFailsWithErrno(EPERM));
#endif
+}
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/getrandom.cc b/test/syscalls/linux/getrandom.cc
index f97f60029..f87cdd7a1 100644
--- a/test/syscalls/linux/getrandom.cc
+++ b/test/syscalls/linux/getrandom.cc
@@ -29,6 +29,8 @@ namespace {
#define SYS_getrandom 318
#elif defined(__i386__)
#define SYS_getrandom 355
+#elif defined(__aarch64__)
+#define SYS_getrandom 278
#else
#error "Unknown architecture"
#endif
diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc
index bba022a41..98d07ae85 100644
--- a/test/syscalls/linux/ip_socket_test_util.cc
+++ b/test/syscalls/linux/ip_socket_test_util.cc
@@ -16,7 +16,6 @@
#include <net/if.h>
#include <netinet/in.h>
-#include <sys/ioctl.h>
#include <sys/socket.h>
#include <cstring>
@@ -35,12 +34,11 @@ uint16_t PortFromInetSockaddr(const struct sockaddr* addr) {
}
PosixErrorOr<int> InterfaceIndex(std::string name) {
- // TODO(igudger): Consider using netlink.
- ifreq req = {};
- memcpy(req.ifr_name, name.c_str(), name.size());
- ASSIGN_OR_RETURN_ERRNO(auto sock, Socket(AF_INET, SOCK_DGRAM, 0));
- RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(sock.get(), SIOCGIFINDEX, &req));
- return req.ifr_ifindex;
+ int index = if_nametoindex(name.c_str());
+ if (index) {
+ return index;
+ }
+ return PosixError(errno);
}
namespace {
@@ -177,17 +175,17 @@ SocketKind IPv6TCPUnboundSocket(int type) {
PosixError IfAddrHelper::Load() {
Release();
RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_));
- return PosixError(0);
+ return NoError();
}
void IfAddrHelper::Release() {
if (ifaddr_) {
freeifaddrs(ifaddr_);
+ ifaddr_ = nullptr;
}
- ifaddr_ = nullptr;
}
-std::vector<std::string> IfAddrHelper::InterfaceList(int family) {
+std::vector<std::string> IfAddrHelper::InterfaceList(int family) const {
std::vector<std::string> names;
for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) {
@@ -198,7 +196,7 @@ std::vector<std::string> IfAddrHelper::InterfaceList(int family) {
return names;
}
-sockaddr* IfAddrHelper::GetAddr(int family, std::string name) {
+const sockaddr* IfAddrHelper::GetAddr(int family, std::string name) const {
for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) {
continue;
@@ -210,7 +208,7 @@ sockaddr* IfAddrHelper::GetAddr(int family, std::string name) {
return nullptr;
}
-PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) {
+PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) const {
return InterfaceIndex(name);
}
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
index 39fd6709d..9c3859fcd 100644
--- a/test/syscalls/linux/ip_socket_test_util.h
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -110,10 +110,10 @@ class IfAddrHelper {
PosixError Load();
void Release();
- std::vector<std::string> InterfaceList(int family);
+ std::vector<std::string> InterfaceList(int family) const;
- struct sockaddr* GetAddr(int family, std::string name);
- PosixErrorOr<int> GetIndex(std::string name);
+ const sockaddr* GetAddr(int family, std::string name) const;
+ PosixErrorOr<int> GetIndex(std::string name) const;
private:
struct ifaddrs* ifaddr_;
diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc
index 8b48f0804..dd981a278 100644
--- a/test/syscalls/linux/itimer.cc
+++ b/test/syscalls/linux/itimer.cc
@@ -246,7 +246,7 @@ int TestSIGPROFFairness(absl::Duration sleep) {
// The number of samples on the main thread should be very low as it did
// nothing.
- TEST_CHECK(result.main_thread_samples < 60);
+ TEST_CHECK(result.main_thread_samples < 80);
// Both workers should get roughly equal number of samples.
TEST_CHECK(result.worker_samples.size() == 2);
diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc
index a8af8e545..6ce1e6cc3 100644
--- a/test/syscalls/linux/lseek.cc
+++ b/test/syscalls/linux/lseek.cc
@@ -53,7 +53,7 @@ TEST(LseekTest, NegativeOffset) {
// A 32-bit off_t is not large enough to represent an offset larger than
// maximum file size on standard file systems, so it isn't possible to cause
// overflow.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(__aarch64__)
TEST(LseekTest, Overflow) {
// HA! Classic Linux. We really should have an EOVERFLOW
// here, since we're seeking to something that cannot be
diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc
index e57b49a4a..f8b7f7938 100644
--- a/test/syscalls/linux/memfd.cc
+++ b/test/syscalls/linux/memfd.cc
@@ -16,6 +16,7 @@
#include <fcntl.h>
#include <linux/magic.h>
#include <linux/memfd.h>
+#include <linux/unistd.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/statfs.h>
diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc
index def4c50a4..4036a9275 100644
--- a/test/syscalls/linux/mkdir.cc
+++ b/test/syscalls/linux/mkdir.cc
@@ -36,21 +36,12 @@ class MkdirTest : public ::testing::Test {
// TearDown unlinks created files.
void TearDown() override {
- // FIXME(edahlgren): We don't currently implement rmdir.
- // We do this unconditionally because there's no harm in trying.
- rmdir(dirname_.c_str());
+ EXPECT_THAT(rmdir(dirname_.c_str()), SyscallSucceeds());
}
std::string dirname_;
};
-TEST_F(MkdirTest, DISABLED_CanCreateReadbleDir) {
- ASSERT_THAT(mkdir(dirname_.c_str(), 0444), SyscallSucceeds());
- ASSERT_THAT(
- open(JoinPath(dirname_, "anything").c_str(), O_RDWR | O_CREAT, 0666),
- SyscallFailsWithErrno(EACCES));
-}
-
TEST_F(MkdirTest, CanCreateWritableDir) {
ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds());
std::string filename = JoinPath(dirname_, "anything");
@@ -84,10 +75,11 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) {
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
- auto parent = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555));
- auto dir = JoinPath(parent.path(), "foo");
- ASSERT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES));
+ ASSERT_THAT(mkdir(dirname_.c_str(), 0555), SyscallSucceeds());
+ auto dir = JoinPath(dirname_.c_str(), "foo");
+ EXPECT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(open(JoinPath(dirname_, "file").c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallFailsWithErrno(EACCES));
}
} // namespace
diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc
index 367a90fe1..78ac96bed 100644
--- a/test/syscalls/linux/mlock.cc
+++ b/test/syscalls/linux/mlock.cc
@@ -199,8 +199,10 @@ TEST(MunlockallTest, Basic) {
}
#ifndef SYS_mlock2
-#ifdef __x86_64__
+#if defined(__x86_64__)
#define SYS_mlock2 325
+#elif defined(__aarch64__)
+#define SYS_mlock2 284
#endif
#endif
diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc
index 11fb1b457..6d3227ab6 100644
--- a/test/syscalls/linux/mmap.cc
+++ b/test/syscalls/linux/mmap.cc
@@ -361,7 +361,7 @@ TEST_F(MMapTest, MapFixed) {
}
// 64-bit addresses work too
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(__aarch64__)
TEST_F(MMapTest, MapFixed64) {
EXPECT_THAT(Map(0x300000000000, kPageSize, PROT_NONE,
MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0),
@@ -571,6 +571,12 @@ const uint8_t machine_code[] = {
0xb8, 0x2a, 0x00, 0x00, 0x00, // movl $42, %eax
0xc3, // retq
};
+#elif defined(__aarch64__)
+// arm64 machine code for a function body that loads 42 into w0 and returns.
+const uint8_t machine_code[] = {
+ 0x40, 0x05, 0x80, 0x52, // mov w0, #42
+ 0xc0, 0x03, 0x5f, 0xd6, // ret
+};
+#endif
// PROT_EXEC allows code execution
TEST_F(MMapTest, ProtExec) {
@@ -605,7 +611,6 @@ TEST_F(MMapTest, NoProtExecDeath) {
EXPECT_EXIT(func(), ::testing::KilledBySignal(SIGSEGV), "");
}
-#endif
TEST_F(MMapTest, NoExceedLimitData) {
void* prevbrk;
@@ -1644,6 +1649,7 @@ TEST(MMapNoFixtureTest, MapReadOnlyAfterCreateWriteOnly) {
}
// Conditional on MAP_32BIT.
+// This flag is supported only on x86-64, for 64-bit programs.
#ifdef __x86_64__
TEST(MMapNoFixtureTest, Map32Bit) {
diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc
index 267ae19f6..640fe6bfc 100644
--- a/test/syscalls/linux/open.cc
+++ b/test/syscalls/linux/open.cc
@@ -186,6 +186,28 @@ TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) {
ASSERT_NO_ERRNO_AND_VALUE(Open(path_via_symlink, O_RDONLY | O_NOFOLLOW));
}
+// Test that open(2) can follow symlinks that point back to the same tree.
+// Test sets up files as follows:
+// root/child/symlink => redirects to ../..
+// root/child/target => regular file containing "abc"
+//
+// open("root/child/symlink/root/child/target")
+TEST_F(OpenTest, SymlinkRecurse) {
+ auto root =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir()));
+ auto child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path()));
+ auto symlink = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(child.path(), "../.."));
+ auto target = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(child.path(), "abc", 0644));
+ // Re-enter the tree through the symlink: <symlink>/<root>/<child>/<target>.
+ auto path_via_symlink =
+ JoinPath(symlink.path(), Basename(root.path()), Basename(child.path()),
+ Basename(target.path()));
+ const auto contents =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents(path_via_symlink));
+ ASSERT_EQ(contents, "abc");
+}
+
TEST_F(OpenTest, Fault) {
char* totally_not_null = nullptr;
ASSERT_THAT(open(totally_not_null, O_RDONLY), SyscallFailsWithErrno(EFAULT));
diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc
index d8e19e910..67228b66b 100644
--- a/test/syscalls/linux/pipe.cc
+++ b/test/syscalls/linux/pipe.cc
@@ -265,6 +265,8 @@ TEST_P(PipeTest, OffsetCalls) {
SyscallFailsWithErrno(ESPIPE));
struct iovec iov;
+ iov.iov_base = &buf;
+ iov.iov_len = sizeof(buf);
EXPECT_THAT(preadv(wfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE));
EXPECT_THAT(pwritev(rfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE));
}
diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc
index c42472474..1e35a4a8b 100644
--- a/test/syscalls/linux/poll.cc
+++ b/test/syscalls/linux/poll.cc
@@ -266,7 +266,7 @@ TEST_F(PollTest, Nfds) {
}
rlim_t max_fds = rlim.rlim_cur;
- std::cout << "Using limit: " << max_fds;
+ std::cout << "Using limit: " << max_fds << std::endl;
// Create an eventfd. Since its value is initially zero, it is writable.
FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc
index 2cecf2e5f..bcdbbb044 100644
--- a/test/syscalls/linux/pread64.cc
+++ b/test/syscalls/linux/pread64.cc
@@ -14,6 +14,7 @@
#include <errno.h>
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/types.h>
@@ -118,6 +119,21 @@ TEST_F(Pread64Test, EndOfFile) {
EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallSucceedsWithValue(0));
}
+// Invokes the memfd_create system call directly through syscall(2) with the
+// given name and flags, and returns syscall's result.
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
+// pread64 at offset 0x7fffffffffffffff (2^63 - 1) on a file of that same size
+// must fail with EINVAL. Presumably this is because offset + count no longer
+// fits in a signed 64-bit offset -- TODO confirm against the kernel behavior.
+TEST_F(Pread64Test, Overflow) {
+ // NOTE(review): the memfd_create result is not checked before it is wrapped.
+ int f = memfd_create("negative", 0);
+ const FileDescriptor fd(f);
+
+ // Extend the file to 2^63 - 1 bytes.
+ EXPECT_THAT(ftruncate(fd.get(), 0x7fffffffffffffffull), SyscallSucceeds());
+
+ char buf[10];
+ EXPECT_THAT(pread64(fd.get(), buf, sizeof(buf), 0x7fffffffffffffffull),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST(Pread64TestNoTempFile, CantReadSocketPair_NoRandomSave) {
int sock_fds[2];
EXPECT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds), SyscallSucceeds());
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index 5a70f6c3b..79a625ebc 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -994,7 +994,7 @@ constexpr uint64_t kMappingSize = 100 << 20;
// Tolerance on RSS comparisons to account for background thread mappings,
// reclaimed pages, newly faulted pages, etc.
-constexpr uint64_t kRSSTolerance = 5 << 20;
+constexpr uint64_t kRSSTolerance = 10 << 20;
// Capture RSS before and after an anonymous mapping with passed prot.
void MapPopulateRSS(int prot, uint64_t* before, uint64_t* after) {
@@ -1326,8 +1326,6 @@ TEST(ProcPidSymlink, SubprocessRunning) {
SyscallSucceedsWithValue(sizeof(buf)));
}
-// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
-// on proc files.
TEST(ProcPidSymlink, SubprocessZombied) {
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
@@ -1337,7 +1335,7 @@ TEST(ProcPidSymlink, SubprocessZombied) {
int want = EACCES;
if (!IsRunningOnGvisor()) {
auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion());
- if (version.major == 4 && version.minor > 3) {
+ if (version.major > 4 || (version.major == 4 && version.minor > 3)) {
want = ENOENT;
}
}
@@ -1350,30 +1348,25 @@ TEST(ProcPidSymlink, SubprocessZombied) {
SyscallFailsWithErrno(want));
}
- // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
- // on proc files.
+ // FIXME(gvisor.dev/issue/164): Inconsistent behavior between Linux kernel
+ // versions on proc files.
//
// ~4.3: Syscall fails with EACCES.
- // 4.17 & gVisor: Syscall succeeds and returns 1.
+ // 4.17: Syscall succeeds and returns 1.
//
- // EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)),
- // SyscallFailsWithErrno(EACCES));
+ if (!IsRunningOnGvisor()) {
+ return;
+ }
- // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
- // on proc files.
- //
- // ~4.3: Syscall fails with EACCES.
- // 4.17 & gVisor: Syscall succeeds and returns 1.
- //
- // EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)),
- // SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)),
+ SyscallFailsWithErrno(want));
+
+ EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)),
+ SyscallFailsWithErrno(want));
}
// Test whether /proc/PID/ symlinks can be read for an exited process.
TEST(ProcPidSymlink, SubprocessExited) {
- // FIXME(gvisor.dev/issue/164): These all succeed on gVisor.
- SKIP_IF(IsRunningOnGvisor());
-
char buf[1];
EXPECT_THAT(ReadlinkWhileExited("exe", buf, sizeof(buf)),
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index 4e23d1e78..cac394910 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -353,7 +353,7 @@ TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) {
EXPECT_EQ(oldNoPorts, newNoPorts - 1);
}
-TEST(ProcNetSnmp, UdpIn) {
+TEST(ProcNetSnmp, UdpIn_NoRandomSave) {
// TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
const DisableSave ds;
diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc
index 66db0acaa..a63067586 100644
--- a/test/syscalls/linux/proc_net_unix.cc
+++ b/test/syscalls/linux/proc_net_unix.cc
@@ -106,7 +106,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() {
std::vector<UnixEntry> entries;
std::vector<std::string> lines = absl::StrSplit(content, '\n');
std::cerr << "<contents of /proc/net/unix>" << std::endl;
- for (std::string line : lines) {
+ for (const std::string& line : lines) {
// Emit the proc entry to the test output to provide context for the test
// results.
std::cerr << line << std::endl;
@@ -374,7 +374,7 @@ TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) {
// corresponding entries, as they don't have an address yet.
if (IsRunningOnGvisor()) {
ASSERT_EQ(entries.size(), 2);
- for (auto e : entries) {
+ for (const auto& e : entries) {
ASSERT_EQ(e.state, SS_DISCONNECTING);
}
}
@@ -403,7 +403,7 @@ TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) {
// corresponding entries, as they don't have an address yet.
if (IsRunningOnGvisor()) {
ASSERT_EQ(entries.size(), 2);
- for (auto e : entries) {
+ for (const auto& e : entries) {
ASSERT_EQ(e.state, SS_DISCONNECTING);
}
}
diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc
index 7f2e8f203..9fb1b3a2c 100644
--- a/test/syscalls/linux/proc_pid_smaps.cc
+++ b/test/syscalls/linux/proc_pid_smaps.cc
@@ -173,7 +173,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps(
return;
}
unknown_fields.insert(std::string(key));
- std::cerr << "skipping unknown smaps field " << key;
+ std::cerr << "skipping unknown smaps field " << key << std::endl;
};
auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty());
@@ -191,7 +191,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps(
// amount of whitespace).
if (!entry) {
std::cerr << "smaps line not considered a maps line: "
- << maybe_maps_entry.error_message();
+ << maybe_maps_entry.error_message() << std::endl;
return PosixError(
EINVAL,
absl::StrCat("smaps field line without preceding maps line: ", l));
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc
index bfe3e2603..926690eb8 100644
--- a/test/syscalls/linux/ptrace.cc
+++ b/test/syscalls/linux/ptrace.cc
@@ -400,9 +400,11 @@ TEST(PtraceTest, GetRegSet) {
// Read exactly the full register set.
EXPECT_EQ(iov.iov_len, sizeof(regs));
-#ifdef __x86_64__
+#if defined(__x86_64__)
// Child called kill(2), with SIGSTOP as arg 2.
EXPECT_EQ(regs.rsi, SIGSTOP);
+#elif defined(__aarch64__)
+ EXPECT_EQ(regs.regs[1], SIGSTOP);
#endif
// Suppress SIGSTOP and resume the child.
@@ -752,15 +754,23 @@ TEST(PtraceTest,
SyscallSucceeds());
EXPECT_TRUE(siginfo.si_code == SIGTRAP || siginfo.si_code == (SIGTRAP | 0x80))
<< "si_code = " << siginfo.si_code;
-#ifdef __x86_64__
+
{
struct user_regs_struct regs = {};
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+#if defined(__x86_64__)
EXPECT_TRUE(regs.orig_rax == SYS_vfork || regs.orig_rax == SYS_clone)
<< "orig_rax = " << regs.orig_rax;
EXPECT_EQ(grandchild_pid, regs.rax);
- }
+#elif defined(__aarch64__)
+ EXPECT_TRUE(regs.regs[8] == SYS_clone) << "regs[8] = " << regs.regs[8];
+ EXPECT_EQ(grandchild_pid, regs.regs[0]);
#endif // defined(__x86_64__)
+ }
// After this point, the child will be making wait4 syscalls that will be
// interrupted by saving, so saving is not permitted. Note that this is
@@ -805,14 +815,21 @@ TEST(PtraceTest,
SyscallSucceedsWithValue(child_pid));
EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80))
<< " status " << status;
-#ifdef __x86_64__
{
struct user_regs_struct regs = {};
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+#if defined(__x86_64__)
EXPECT_EQ(SYS_wait4, regs.orig_rax);
EXPECT_EQ(grandchild_pid, regs.rax);
- }
+#elif defined(__aarch64__)
+ EXPECT_EQ(SYS_wait4, regs.regs[8]);
+ EXPECT_EQ(grandchild_pid, regs.regs[0]);
#endif // defined(__x86_64__)
+ }
// Detach from the child and wait for it to exit.
ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds());
@@ -1188,7 +1205,7 @@ TEST(PtraceTest, SeizeSetOptions) {
// gVisor is not susceptible to this race because
// kernel.Task.waitCollectTraceeStopLocked() checks specifically for an
// active ptraceStop, which is not initiated if SIGKILL is pending.
- std::cout << "Observed syscall-exit after SIGKILL";
+ std::cout << "Observed syscall-exit after SIGKILL" << std::endl;
ASSERT_THAT(waitpid(child_pid, &status, 0),
SyscallSucceedsWithValue(child_pid));
}
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
index dafe64d20..b8a0159ba 100644
--- a/test/syscalls/linux/pty.cc
+++ b/test/syscalls/linux/pty.cc
@@ -1126,7 +1126,7 @@ TEST_F(PtyTest, SwitchTwiceMultiline) {
std::string kExpected = "GO\nBLUE\n!";
// Write each line.
- for (std::string input : kInputs) {
+ for (const std::string& input : kInputs) {
ASSERT_THAT(WriteFd(master_.get(), input.c_str(), input.size()),
SyscallSucceedsWithValue(input.size()));
}
diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc
index b48fe540d..e69794910 100644
--- a/test/syscalls/linux/pwrite64.cc
+++ b/test/syscalls/linux/pwrite64.cc
@@ -14,6 +14,7 @@
#include <errno.h>
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
@@ -27,14 +28,7 @@ namespace testing {
namespace {
-// This test is currently very rudimentary.
-//
-// TODO(edahlgren):
-// * bad buffer states (EFAULT).
-// * bad fds (wrong permission, wrong type of file, EBADF).
-// * check offset is not incremented.
-// * check for EOF.
-// * writing to pipes, symlinks, special files.
+// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary.
class Pwrite64 : public ::testing::Test {
void SetUp() override {
name_ = NewTempAbsPath();
@@ -72,6 +66,17 @@ TEST_F(Pwrite64, InvalidArgs) {
EXPECT_THAT(close(fd), SyscallSucceeds());
}
+TEST_F(Pwrite64, Overflow) {
+ int fd;
+ ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds());
+ constexpr int64_t kBufSize = 1024;
+ std::vector<char> buf(kBufSize);
+ std::fill(buf.begin(), buf.end(), 'a');
+ EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), 0x7fffffffffffffffull),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD
index ee5b0a11b..853258b04 100644
--- a/test/syscalls/linux/rseq/BUILD
+++ b/test/syscalls/linux/rseq/BUILD
@@ -21,23 +21,23 @@ genrule(
],
outs = ["rseq"],
cmd = "$(CC) " +
- "$(CC_FLAGS) " +
- "-I. " +
- "-Wall " +
- "-Werror " +
- "-O2 " +
- "-std=c++17 " +
- "-static " +
- "-nostdlib " +
- "-ffreestanding " +
- "-o " +
- "$(location rseq) " +
- select_arch(
- amd64 = "$(location critical_amd64.S) $(location start_amd64.S) ",
- arm64 = "$(location critical_arm64.S) $(location start_arm64.S) ",
- no_match_error = "unsupported architecture",
- ) +
- "$(location rseq.cc)",
+ "$(CC_FLAGS) " +
+ "-I. " +
+ "-Wall " +
+ "-Werror " +
+ "-O2 " +
+ "-std=c++17 " +
+ "-static " +
+ "-nostdlib " +
+ "-ffreestanding " +
+ "-o " +
+ "$(location rseq) " +
+ select_arch(
+ amd64 = "$(location critical_amd64.S) $(location start_amd64.S) ",
+ arm64 = "$(location critical_arm64.S) $(location start_arm64.S) ",
+ no_match_error = "unsupported architecture",
+ ) +
+ "$(location rseq.cc)",
toolchains = [
cc_toolchain,
":no_pie_cc_flags",
diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc
index 580ab5193..64123e904 100644
--- a/test/syscalls/linux/sendfile.cc
+++ b/test/syscalls/linux/sendfile.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/eventfd.h>
#include <sys/sendfile.h>
#include <unistd.h>
@@ -70,6 +71,28 @@ TEST(SendFileTest, InvalidOffset) {
SyscallFailsWithErrno(EINVAL));
}
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
+TEST(SendFileTest, Overflow) {
+ // Create input file.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file.
+ int fd;
+ EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds());
+ const FileDescriptor outf(fd);
+
+ // out_offset + kSize overflows INT64_MAX.
+ loff_t out_offset = 0x7ffffffffffffffeull;
+ constexpr int kSize = 3;
+ EXPECT_THAT(sendfile(outf.get(), inf.get(), &out_offset, kSize),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST(SendFileTest, SendTrivially) {
// Create temp files.
constexpr char kData[] = "To be, or not to be, that is the question:";
@@ -530,6 +553,34 @@ TEST(SendFileTest, SendToSpecialFile) {
SyscallSucceedsWithValue(kSize & (~7)));
}
+TEST(SendFileTest, SendFileToPipe) {
+ // Create temp file.
+ constexpr char kData[] = "<insert-quote-here>";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Create a pipe for sending to a pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Expect to read up to the given size.
+ std::vector<char> buf(kDataSize);
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kDataSize));
+ });
+
+ // Send with twice the size of the file, which should hit EOF.
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize * 2),
+ SyscallSucceedsWithValue(kDataSize));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc
index 8f7ee4163..c101fe9d2 100644
--- a/test/syscalls/linux/sendfile_socket.cc
+++ b/test/syscalls/linux/sendfile_socket.cc
@@ -23,6 +23,7 @@
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
@@ -35,61 +36,39 @@ namespace {
class SendFileTest : public ::testing::TestWithParam<int> {
protected:
- PosixErrorOr<std::tuple<int, int>> Sockets() {
+ PosixErrorOr<std::unique_ptr<SocketPair>> Sockets(int type) {
// Bind a server socket.
int family = GetParam();
- struct sockaddr server_addr = {};
switch (family) {
case AF_INET: {
- struct sockaddr_in* server_addr_in =
- reinterpret_cast<struct sockaddr_in*>(&server_addr);
- server_addr_in->sin_family = family;
- server_addr_in->sin_addr.s_addr = INADDR_ANY;
- break;
+ if (type == SOCK_STREAM) {
+ return SocketPairKind{
+ "TCP", AF_INET, type, 0,
+ TCPAcceptBindSocketPairCreator(AF_INET, type, 0, false)}
+ .Create();
+ } else {
+ return SocketPairKind{
+ "UDP", AF_INET, type, 0,
+ UDPBidirectionalBindSocketPairCreator(AF_INET, type, 0, false)}
+ .Create();
+ }
}
case AF_UNIX: {
- struct sockaddr_un* server_addr_un =
- reinterpret_cast<struct sockaddr_un*>(&server_addr);
- server_addr_un->sun_family = family;
- server_addr_un->sun_path[0] = '\0';
- break;
+ if (type == SOCK_STREAM) {
+ return SocketPairKind{
+ "UNIX", AF_UNIX, type, 0,
+ FilesystemAcceptBindSocketPairCreator(AF_UNIX, type, 0)}
+ .Create();
+ } else {
+ return SocketPairKind{
+ "UNIX", AF_UNIX, type, 0,
+ FilesystemBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)}
+ .Create();
+ }
}
default:
return PosixError(EINVAL);
}
- int server = socket(family, SOCK_STREAM, 0);
- if (bind(server, &server_addr, sizeof(server_addr)) < 0) {
- return PosixError(errno);
- }
- if (listen(server, 1) < 0) {
- close(server);
- return PosixError(errno);
- }
-
- // Fetch the address; both are anonymous.
- socklen_t length = sizeof(server_addr);
- if (getsockname(server, &server_addr, &length) < 0) {
- close(server);
- return PosixError(errno);
- }
-
- // Connect the client.
- int client = socket(family, SOCK_STREAM, 0);
- if (connect(client, &server_addr, length) < 0) {
- close(server);
- close(client);
- return PosixError(errno);
- }
-
- // Accept on the server.
- int server_client = accept(server, nullptr, 0);
- if (server_client < 0) {
- close(server);
- close(client);
- return PosixError(errno);
- }
- close(server);
- return std::make_tuple(client, server_client);
}
};
@@ -106,9 +85,7 @@ TEST_P(SendFileTest, SendMultiple) {
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
// Create sockets.
- std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets());
- const FileDescriptor server(std::get<0>(fds));
- FileDescriptor client(std::get<1>(fds)); // non-const, reset is used.
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM));
// Thread that reads data from socket and dumps to a file.
ScopedThread th([&] {
@@ -118,7 +95,7 @@ TEST_P(SendFileTest, SendMultiple) {
// Read until socket is closed.
char buf[10240];
for (int cnt = 0;; cnt++) {
- int r = RetryEINTR(read)(server.get(), buf, sizeof(buf));
+ int r = RetryEINTR(read)(socks->first_fd(), buf, sizeof(buf));
// We cannot afford to save on every read() call.
if (cnt % 1000 == 0) {
ASSERT_THAT(r, SyscallSucceeds());
@@ -149,10 +126,10 @@ TEST_P(SendFileTest, SendMultiple) {
for (size_t sent = 0; sent < data.size(); cnt++) {
const size_t remain = data.size() - sent;
std::cout << "sendfile, size=" << data.size() << ", sent=" << sent
- << ", remain=" << remain;
+ << ", remain=" << remain << std::endl;
// Send data and verify that sendfile returns the correct value.
- int res = sendfile(client.get(), inf.get(), nullptr, remain);
+ int res = sendfile(socks->second_fd(), inf.get(), nullptr, remain);
// We cannot afford to save on every sendfile() call.
if (cnt % 120 == 0) {
MaybeSave();
@@ -169,7 +146,7 @@ TEST_P(SendFileTest, SendMultiple) {
}
// Close socket to stop thread.
- client.reset();
+ close(socks->release_second_fd());
th.Join();
// Verify that the output file has the correct data.
@@ -183,9 +160,7 @@ TEST_P(SendFileTest, SendMultiple) {
TEST_P(SendFileTest, Shutdown) {
// Create a socket.
- std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets());
- const FileDescriptor client(std::get<0>(fds));
- FileDescriptor server(std::get<1>(fds)); // non-const, reset below.
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM));
// If this is a TCP socket, then turn off linger.
if (GetParam() == AF_INET) {
@@ -193,7 +168,7 @@ TEST_P(SendFileTest, Shutdown) {
sl.l_onoff = 1;
sl.l_linger = 0;
ASSERT_THAT(
- setsockopt(server.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)),
+ setsockopt(socks->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)),
SyscallSucceeds());
}
@@ -212,12 +187,12 @@ TEST_P(SendFileTest, Shutdown) {
ScopedThread t([&]() {
size_t done = 0;
while (done < data.size()) {
- int n = RetryEINTR(read)(server.get(), data.data(), data.size());
+ int n = RetryEINTR(read)(socks->first_fd(), data.data(), data.size());
ASSERT_THAT(n, SyscallSucceeds());
done += n;
}
// Close the server side socket.
- server.reset();
+ close(socks->release_first_fd());
});
// Continuously stream from the file to the socket. Note we do not assert
@@ -225,7 +200,7 @@ TEST_P(SendFileTest, Shutdown) {
// data is written. Eventually, we should get a connection reset error.
while (1) {
off_t offset = 0; // Always read from the start.
- int n = sendfile(client.get(), inf.get(), &offset, data.size());
+ int n = sendfile(socks->second_fd(), inf.get(), &offset, data.size());
EXPECT_THAT(n, AnyOf(SyscallFailsWithErrno(ECONNRESET),
SyscallFailsWithErrno(EPIPE), SyscallSucceeds()));
if (n <= 0) {
@@ -234,6 +209,20 @@ TEST_P(SendFileTest, Shutdown) {
}
}
+TEST_P(SendFileTest, SendpageFromEmptyFileToUDP) {
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_DGRAM));
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ // The count argument must be so large that it is impossible to allocate a
+ // buffer of this size. In Linux, sendfile transfers at most 0x7ffff000
+ // (MAX_RW_COUNT) bytes.
+ EXPECT_THAT(sendfile(socks->first_fd(), fd.get(), 0x0, 0x8000000000004),
+ SyscallSucceedsWithValue(0));
+}
+
INSTANTIATE_TEST_SUITE_P(AddressFamily, SendFileTest,
::testing::Values(AF_UNIX, AF_INET));
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index b24618a88..d3000dbc6 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -234,7 +234,7 @@ TEST_P(DualStackSocketTest, AddressOperations) {
}
}
-// TODO(gvisor.dev/issues/1556): uncomment V4MappedAny.
+// TODO(gvisor.dev/issue/1556): uncomment V4MappedAny.
INSTANTIATE_TEST_SUITE_P(
All, DualStackSocketTest,
::testing::Combine(
@@ -319,17 +319,14 @@ TEST_P(SocketInetLoopbackTest, TCPListenUnbound) {
tcpSimpleConnectTest(listener, connector, false);
}
-TEST_P(SocketInetLoopbackTest, TCPListenClose) {
+TEST_P(SocketInetLoopbackTest, TCPListenShutdown) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
- constexpr int kAcceptCount = 32;
- constexpr int kBacklog = kAcceptCount * 2;
- constexpr int kFDs = 128;
- constexpr int kThreadCount = 4;
- constexpr int kFDsPerThread = kFDs / kThreadCount;
+ constexpr int kBacklog = 2;
+ constexpr int kFDs = kBacklog + 1;
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
@@ -348,39 +345,167 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
- DisableSave ds; // Too many system calls.
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- FileDescriptor clients[kFDs];
- std::unique_ptr<ScopedThread> threads[kThreadCount];
+
+ // Shutdown the write of the listener, expect to not have any effect.
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_WR), SyscallSucceeds());
+
for (int i = 0; i < kFDs; i++) {
- clients[i] = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds());
}
- for (int i = 0; i < kThreadCount; i++) {
- threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr,
- &clients, i]() {
- for (int j = 0; j < kFDsPerThread; j++) {
- int k = i * kFDsPerThread + j;
- int ret =
- connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len);
- if (ret != 0) {
- EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
- }
- }
- });
+
+ // Shutdown the read of the listener, expect to fail subsequent
+ // server accepts, binds and client connects.
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds());
+
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Check that shutdown did not release the port.
+ FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ bind(new_listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Check that subsequent connection attempts receive a RST.
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ for (int i = 0; i < kFDs; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallFailsWithErrno(ECONNREFUSED));
}
- for (int i = 0; i < kThreadCount; i++) {
- threads[i]->Join();
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenClose) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kAcceptCount = 2;
+ constexpr int kBacklog = kAcceptCount + 2;
+ constexpr int kFDs = kBacklog * 3;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ for (int i = 0; i < kFDs; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ }
+ clients.push_back(std::move(client));
}
for (int i = 0; i < kAcceptCount; i++) {
auto accepted =
ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
}
- // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked
- // before function end.
- // ds.reset();
+}
+
+void TestListenWhileConnect(const TestParam& param,
+ void (*stopListen)(FileDescriptor&)) {
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 2;
+ constexpr int kClients = kBacklog + 1;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ clients.push_back(std::move(client));
+ }
+ }
+
+ stopListen(listen_fd);
+
+ for (auto& client : clients) {
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = client.get(),
+ .events = POLLIN,
+ };
+ // When the listening socket is closed, then we expect the remote to reset
+ // the connection.
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR);
+ char c;
+ // Subsequent read can fail with:
+ // ECONNRESET: If the client connection was established and was reset by the
+ // remote.
+ // ECONNREFUSED: If the client connection failed to be established.
+ ASSERT_THAT(read(client.get(), &c, sizeof(c)),
+ AnyOf(SyscallFailsWithErrno(ECONNRESET),
+ SyscallFailsWithErrno(ECONNREFUSED)));
+ }
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
+ TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(close(f.release()), SyscallSucceeds());
+ });
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) {
+ TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds());
+ });
}
TEST_P(SocketInetLoopbackTest, TCPbacklog) {
@@ -605,15 +730,23 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) {
&conn_addrlen),
SyscallSucceeds());
- constexpr int kTCPLingerTimeout = 5;
- EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2,
- &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)),
- SyscallSucceedsWithValue(0));
+ // Disable cooperative saves after this point as TCP timers are not restored
+ // across a S/R.
+ {
+ DisableSave ds;
+ constexpr int kTCPLingerTimeout = 5;
+ EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2,
+ &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)),
+ SyscallSucceedsWithValue(0));
- // close the connecting FD to trigger FIN_WAIT2 on the connected fd.
- conn_fd.reset();
+ // close the connecting FD to trigger FIN_WAIT2 on the connected fd.
+ conn_fd.reset();
+
+ absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1));
- absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1));
+ // ds going out of scope will re-enable S/R since at this point the timer
+ // must have fired and cleaned up the endpoint.
+ }
// Now bind and connect a new socket and verify that we can immediately
// rebind the address bound by the conn_fd as it never entered TIME_WAIT.
@@ -1082,6 +1215,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
if (connects_received >= kConnectAttempts) {
// Another thread have shutdown our read side causing the
// accept to fail.
+ ASSERT_EQ(errno, EINVAL);
break;
}
ASSERT_NO_ERRNO(fd);
@@ -1149,7 +1283,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
}
-TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) {
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -1262,7 +1396,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) {
EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
}
-TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) {
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -2138,8 +2272,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) {
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(connect(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&bound_addr), bound_addr_len),
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
SyscallSucceeds());
// Get the ephemeral port.
@@ -2204,7 +2339,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) {
setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)),
SyscallSucceeds());
- std::cout << portreuse1 << " " << portreuse2;
+ std::cout << portreuse1 << " " << portreuse2 << std::endl;
int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen);
// Verify that two sockets can be bound to the same port only if
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
index 40e673625..d690d9564 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
@@ -45,37 +45,31 @@ void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() {
got_if_infos_ = false;
// Get interface list.
- std::vector<std::string> if_names;
ASSERT_NO_ERRNO(if_helper_.Load());
- if_names = if_helper_.InterfaceList(AF_INET);
+ std::vector<std::string> if_names = if_helper_.InterfaceList(AF_INET);
if (if_names.size() != 2) {
return;
}
// Figure out which interface is where.
- int lo = 0, eth = 1;
- if (if_names[lo] != "lo") {
- lo = 1;
- eth = 0;
- }
-
- if (if_names[lo] != "lo") {
- return;
- }
-
- lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[lo]));
- lo_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[lo]);
- if (lo_if_addr_ == nullptr) {
+ std::string lo = if_names[0];
+ std::string eth = if_names[1];
+ if (lo != "lo") std::swap(lo, eth);
+ if (lo != "lo") return;
+
+ lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(lo));
+ auto lo_if_addr = if_helper_.GetAddr(AF_INET, lo);
+ if (lo_if_addr == nullptr) {
return;
}
- lo_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(lo_if_addr_)->sin_addr;
+ lo_if_addr_ = *reinterpret_cast<const sockaddr_in*>(lo_if_addr);
- eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[eth]));
- eth_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[eth]);
- if (eth_if_addr_ == nullptr) {
+ eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(eth));
+ auto eth_if_addr = if_helper_.GetAddr(AF_INET, eth);
+ if (eth_if_addr == nullptr) {
return;
}
- eth_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(eth_if_addr_)->sin_addr;
+ eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr);
got_if_infos_ = true;
}
@@ -242,7 +236,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Bind the non-receiving socket to the unicast ethernet address.
auto norecv_addr = rcv1_addr;
reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr =
- eth_if_sin_addr_;
+ eth_if_addr_.sin_addr;
ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr),
norecv_addr.addr_len),
SyscallSucceedsWithValue(0));
@@ -1028,7 +1022,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreqn iface = {};
iface.imr_ifindex = lo_if_idx_;
- iface.imr_address = eth_if_sin_addr_;
+ iface.imr_address = eth_if_addr_.sin_addr;
ASSERT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
sizeof(iface)),
SyscallSucceeds());
@@ -1058,7 +1052,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
SKIP_IF(IsRunningOnGvisor());
// Verify the received source address.
- EXPECT_EQ(eth_if_sin_addr_.s_addr, src_addr_in->sin_addr.s_addr);
+ EXPECT_EQ(eth_if_addr_.sin_addr.s_addr, src_addr_in->sin_addr.s_addr);
}
// Check that when we are bound to one interface we can set IP_MULTICAST_IF to
@@ -1075,7 +1069,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Create sender and bind to eth interface.
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- ASSERT_THAT(bind(sender->get(), eth_if_addr_, sizeof(sockaddr_in)),
+ ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&eth_if_addr_),
+ sizeof(eth_if_addr_)),
SyscallSucceeds());
// Run through all possible combinations of index and address for
@@ -1085,9 +1080,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
struct in_addr imr_address;
} test_data[] = {
{lo_if_idx_, {}},
- {0, lo_if_sin_addr_},
- {lo_if_idx_, lo_if_sin_addr_},
- {lo_if_idx_, eth_if_sin_addr_},
+ {0, lo_if_addr_.sin_addr},
+ {lo_if_idx_, lo_if_addr_.sin_addr},
+ {lo_if_idx_, eth_if_addr_.sin_addr},
};
for (auto t : test_data) {
ip_mreqn iface = {};
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
index bec2e96ee..10b90b1e0 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
@@ -36,10 +36,8 @@ class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest {
// Interface infos.
int lo_if_idx_;
int eth_if_idx_;
- sockaddr* lo_if_addr_;
- sockaddr* eth_if_addr_;
- in_addr lo_if_sin_addr_;
- in_addr eth_if_sin_addr_;
+ sockaddr_in lo_if_addr_;
+ sockaddr_in eth_if_addr_;
};
} // namespace testing
diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc
index e5aed1eec..fbe61c5a0 100644
--- a/test/syscalls/linux/socket_netlink_route.cc
+++ b/test/syscalls/linux/socket_netlink_route.cc
@@ -26,7 +26,7 @@
#include "gtest/gtest.h"
#include "absl/strings/str_format.h"
-#include "absl/types/optional.h"
+#include "test/syscalls/linux/socket_netlink_route_util.h"
#include "test/syscalls/linux/socket_netlink_util.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/capability_util.h"
@@ -118,24 +118,6 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) {
// TODO(mpratt): Check ifinfomsg contents and following attrs.
}
-PosixError DumpLinks(
- const FileDescriptor& fd, uint32_t seq,
- const std::function<void(const struct nlmsghdr* hdr)>& fn) {
- struct request {
- struct nlmsghdr hdr;
- struct ifinfomsg ifm;
- };
-
- struct request req = {};
- req.hdr.nlmsg_len = sizeof(req);
- req.hdr.nlmsg_type = RTM_GETLINK;
- req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
- req.hdr.nlmsg_seq = seq;
- req.ifm.ifi_family = AF_UNSPEC;
-
- return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false);
-}
-
TEST(NetlinkRouteTest, GetLinkDump) {
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
@@ -152,7 +134,7 @@ TEST(NetlinkRouteTest, GetLinkDump) {
const struct ifinfomsg* msg =
reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
std::cout << "Found interface idx=" << msg->ifi_index
- << ", type=" << std::hex << msg->ifi_type;
+ << ", type=" << std::hex << msg->ifi_type << std::endl;
if (msg->ifi_type == ARPHRD_LOOPBACK) {
loopbackFound = true;
EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0);
@@ -161,37 +143,6 @@ TEST(NetlinkRouteTest, GetLinkDump) {
EXPECT_TRUE(loopbackFound);
}
-struct Link {
- int index;
- std::string name;
-};
-
-PosixErrorOr<absl::optional<Link>> FindLoopbackLink() {
- ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
-
- absl::optional<Link> link;
- RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) {
- if (hdr->nlmsg_type != RTM_NEWLINK ||
- hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) {
- return;
- }
- const struct ifinfomsg* msg =
- reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
- if (msg->ifi_type == ARPHRD_LOOPBACK) {
- const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME);
- if (rta == nullptr) {
- // Ignore links that do not have a name.
- return;
- }
-
- link = Link();
- link->index = msg->ifi_index;
- link->name = std::string(reinterpret_cast<const char*>(RTA_DATA(rta)));
- }
- }));
- return link;
-}
-
// CheckLinkMsg checks a netlink message against an expected link.
void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) {
ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK));
@@ -209,9 +160,7 @@ void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) {
}
TEST(NetlinkRouteTest, GetLinkByIndex) {
- absl::optional<Link> loopback_link =
- ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink());
- ASSERT_TRUE(loopback_link.has_value());
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
@@ -227,13 +176,13 @@ TEST(NetlinkRouteTest, GetLinkByIndex) {
req.hdr.nlmsg_flags = NLM_F_REQUEST;
req.hdr.nlmsg_seq = kSeq;
req.ifm.ifi_family = AF_UNSPEC;
- req.ifm.ifi_index = loopback_link->index;
+ req.ifm.ifi_index = loopback_link.index;
bool found = false;
ASSERT_NO_ERRNO(NetlinkRequestResponse(
fd, &req, sizeof(req),
[&](const struct nlmsghdr* hdr) {
- CheckLinkMsg(hdr, *loopback_link);
+ CheckLinkMsg(hdr, loopback_link);
found = true;
},
false));
@@ -241,9 +190,7 @@ TEST(NetlinkRouteTest, GetLinkByIndex) {
}
TEST(NetlinkRouteTest, GetLinkByName) {
- absl::optional<Link> loopback_link =
- ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink());
- ASSERT_TRUE(loopback_link.has_value());
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
@@ -262,8 +209,8 @@ TEST(NetlinkRouteTest, GetLinkByName) {
req.hdr.nlmsg_seq = kSeq;
req.ifm.ifi_family = AF_UNSPEC;
req.rtattr.rta_type = IFLA_IFNAME;
- req.rtattr.rta_len = RTA_LENGTH(loopback_link->name.size() + 1);
- strncpy(req.ifname, loopback_link->name.c_str(), sizeof(req.ifname));
+ req.rtattr.rta_len = RTA_LENGTH(loopback_link.name.size() + 1);
+ strncpy(req.ifname, loopback_link.name.c_str(), sizeof(req.ifname));
req.hdr.nlmsg_len =
NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len);
@@ -271,7 +218,7 @@ TEST(NetlinkRouteTest, GetLinkByName) {
ASSERT_NO_ERRNO(NetlinkRequestResponse(
fd, &req, sizeof(req),
[&](const struct nlmsghdr* hdr) {
- CheckLinkMsg(hdr, *loopback_link);
+ CheckLinkMsg(hdr, loopback_link);
found = true;
},
false));
@@ -523,9 +470,7 @@ TEST(NetlinkRouteTest, LookupAll) {
TEST(NetlinkRouteTest, AddAddr) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
- absl::optional<Link> loopback_link =
- ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink());
- ASSERT_TRUE(loopback_link.has_value());
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
@@ -545,7 +490,7 @@ TEST(NetlinkRouteTest, AddAddr) {
req.ifa.ifa_prefixlen = 24;
req.ifa.ifa_flags = 0;
req.ifa.ifa_scope = 0;
- req.ifa.ifa_index = loopback_link->index;
+ req.ifa.ifa_index = loopback_link.index;
req.rtattr.rta_type = IFA_LOCAL;
req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr));
inet_pton(AF_INET, "10.0.0.1", &req.addr);
diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc
index 53eb3b6b2..bde1dbb4d 100644
--- a/test/syscalls/linux/socket_netlink_route_util.cc
+++ b/test/syscalls/linux/socket_netlink_route_util.cc
@@ -18,7 +18,6 @@
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
-#include "absl/types/optional.h"
#include "test/syscalls/linux/socket_netlink_util.h"
namespace gvisor {
@@ -73,14 +72,14 @@ PosixErrorOr<std::vector<Link>> DumpLinks() {
return links;
}
-PosixErrorOr<absl::optional<Link>> FindLoopbackLink() {
+PosixErrorOr<Link> LoopbackLink() {
ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
for (const auto& link : links) {
if (link.type == ARPHRD_LOOPBACK) {
- return absl::optional<Link>(link);
+ return link;
}
}
- return absl::optional<Link>();
+ return PosixError(ENOENT, "loopback link not found");
}
PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h
index 2c018e487..149c4a7f6 100644
--- a/test/syscalls/linux/socket_netlink_route_util.h
+++ b/test/syscalls/linux/socket_netlink_route_util.h
@@ -20,7 +20,6 @@
#include <vector>
-#include "absl/types/optional.h"
#include "test/syscalls/linux/socket_netlink_util.h"
namespace gvisor {
@@ -37,7 +36,8 @@ PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq,
PosixErrorOr<std::vector<Link>> DumpLinks();
-PosixErrorOr<absl::optional<Link>> FindLoopbackLink();
+// Returns the loopback link on the system. ENOENT if not found.
+PosixErrorOr<Link> LoopbackLink();
// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface.
PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index 5d3a39868..53b678e94 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -364,11 +364,6 @@ CreateTCPConnectAcceptSocketPair(int bound, int connected, int type,
}
MaybeSave(); // Successful accept.
- // FIXME(b/110484944)
- if (connect_result == -1) {
- absl::SleepFor(absl::Seconds(1));
- }
-
T extra_addr = {};
LocalhostAddr(&extra_addr, dual_stack);
return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr,
diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc
index 4cf1f76f1..8bf663e8b 100644
--- a/test/syscalls/linux/socket_unix.cc
+++ b/test/syscalls/linux/socket_unix.cc
@@ -257,6 +257,8 @@ TEST_P(UnixSocketPairTest, ShutdownWrite) {
TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) {
// TODO(b/122310852): We should be returning ENXIO and NOT EIO.
+ // TODO(gvisor.dev/issue/1624): This should be resolved in VFS2. Verify
+ // that this is the case and delete the SKIP_IF once we delete VFS1.
SKIP_IF(IsRunningOnGvisor());
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index faa1247f6..f103e2e56 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/eventfd.h>
#include <sys/resource.h>
#include <sys/sendfile.h>
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc
index 53ad2dda3..6195b11e1 100644
--- a/test/syscalls/linux/tuntap.cc
+++ b/test/syscalls/linux/tuntap.cc
@@ -56,14 +56,14 @@ PosixErrorOr<std::set<std::string>> DumpLinkNames() {
return names;
}
-PosixErrorOr<absl::optional<Link>> GetLinkByName(const std::string& name) {
+PosixErrorOr<Link> GetLinkByName(const std::string& name) {
ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
for (const auto& link : links) {
if (link.name == name) {
- return absl::optional<Link>(link);
+ return link;
}
}
- return absl::optional<Link>();
+ return PosixError(ENOENT, "interface not found");
}
struct pihdr {
@@ -242,7 +242,7 @@ TEST_F(TuntapTest, InvalidReadWrite) {
TEST_F(TuntapTest, WriteToDownDevice) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
- // FIXME: gVisor always creates enabled/up'd interfaces.
+ // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces.
SKIP_IF(IsRunningOnGvisor());
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
@@ -268,23 +268,21 @@ PosixErrorOr<FileDescriptor> OpenAndAttachTap(
return PosixError(errno);
}
- ASSIGN_OR_RETURN_ERRNO(absl::optional<Link> link, GetLinkByName(dev_name));
- if (!link.has_value()) {
- return PosixError(ENOENT, "no link");
- }
+ ASSIGN_OR_RETURN_ERRNO(auto link, GetLinkByName(dev_name));
// Interface setup.
struct in_addr addr;
inet_pton(AF_INET, dev_ipv4_addr.c_str(), &addr);
- EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24,
- &addr, sizeof(addr)));
+ EXPECT_NO_ERRNO(LinkAddLocalAddr(link.index, AF_INET, /*prefixlen=*/24, &addr,
+ sizeof(addr)));
if (!IsRunningOnGvisor()) {
- // FIXME: gVisor doesn't support setting MAC address on interfaces yet.
- RETURN_IF_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA)));
+ // FIXME(b/110961832): gVisor doesn't support setting MAC address on
+ // interfaces yet.
+ RETURN_IF_ERRNO(LinkSetMacAddr(link.index, kMacA, sizeof(kMacA)));
- // FIXME: gVisor always creates enabled/up'd interfaces.
- RETURN_IF_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP));
+ // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces.
+ RETURN_IF_ERRNO(LinkChangeFlags(link.index, IFF_UP, IFF_UP));
}
return fd;
diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc
index 6218fbce1..ff66a79f4 100644
--- a/test/syscalls/linux/uidgid.cc
+++ b/test/syscalls/linux/uidgid.cc
@@ -14,6 +14,7 @@
#include <errno.h>
#include <grp.h>
+#include <sys/resource.h>
#include <sys/types.h>
#include <unistd.h>
@@ -249,6 +250,17 @@ TEST(UidGidRootTest, Setgroups) {
SyscallFailsWithErrno(EFAULT));
}
+TEST(UidGidRootTest, Setuid_prlimit) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
+
+ // Change our UID.
+ EXPECT_THAT(seteuid(65534), SyscallSucceeds());
+
+ // Despite the UID change, we should be able to get our own limits.
+ struct rlimit rl = {};
+ ASSERT_THAT(prlimit(0, RLIMIT_NOFILE, NULL, &rl), SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc
index 3a927a430..22e6d1a85 100644
--- a/test/syscalls/linux/utimes.cc
+++ b/test/syscalls/linux/utimes.cc
@@ -34,17 +34,10 @@ namespace testing {
namespace {
-// TODO(b/36516566): utimes(nullptr) does not pick the "now" time in the
-// application's time domain, so when asserting that times are within a window,
-// we expand the window to allow for differences between the time domains.
-constexpr absl::Duration kClockSlack = absl::Milliseconds(100);
-
// TimeBoxed runs fn, setting before and after to (coarse realtime) times
// guaranteed* to come before and after fn started and completed, respectively.
//
// fn may be called more than once if the clock is adjusted.
-//
-// * See the comment on kClockSlack. gVisor breaks this guarantee.
void TimeBoxed(absl::Time* before, absl::Time* after,
std::function<void()> const& fn) {
do {
@@ -69,12 +62,6 @@ void TimeBoxed(absl::Time* before, absl::Time* after,
// which could lead to test failures, but that is very unlikely to happen.
continue;
}
-
- if (IsRunningOnGvisor()) {
- // See comment on kClockSlack.
- *before -= kClockSlack;
- *after += kClockSlack;
- }
} while (*after < *before);
}
@@ -235,10 +222,7 @@ void TestUtimensat(int dirFd, std::string const& path) {
EXPECT_GE(mtime3, before);
EXPECT_LE(mtime3, after);
- if (!IsRunningOnGvisor()) {
- // FIXME(b/36516566): Gofers set atime and mtime to different "now" times.
- EXPECT_EQ(atime3, mtime3);
- }
+ EXPECT_EQ(atime3, mtime3);
}
TEST(UtimensatTest, OnAbsPath) {
diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc
index 9b219cfd6..39b5b2f56 100644
--- a/test/syscalls/linux/write.cc
+++ b/test/syscalls/linux/write.cc
@@ -31,14 +31,8 @@ namespace gvisor {
namespace testing {
namespace {
-// This test is currently very rudimentary.
-//
-// TODO(edahlgren):
-// * bad buffer states (EFAULT).
-// * bad fds (wrong permission, wrong type of file, EBADF).
-// * check offset is incremented.
-// * check for EOF.
-// * writing to pipes, symlinks, special files.
+
+// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary.
class WriteTest : public ::testing::Test {
public:
ssize_t WriteBytes(int fd, int bytes) {
diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc
index 8b00ef44c..3231732ec 100644
--- a/test/syscalls/linux/xattr.cc
+++ b/test/syscalls/linux/xattr.cc
@@ -41,12 +41,12 @@ class XattrTest : public FileTest {};
TEST_F(XattrTest, XattrNonexistentFile) {
const char* path = "/does/not/exist";
- EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0),
- SyscallFailsWithErrno(ENOENT));
- EXPECT_THAT(getxattr(path, nullptr, nullptr, 0),
+ const char* name = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0),
SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENOENT));
EXPECT_THAT(listxattr(path, nullptr, 0), SyscallFailsWithErrno(ENOENT));
- EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENOENT));
}
TEST_F(XattrTest, XattrNullName) {
diff --git a/test/util/capability_util.cc b/test/util/capability_util.cc
index 9fee52fbb..a1b994c45 100644
--- a/test/util/capability_util.cc
+++ b/test/util/capability_util.cc
@@ -63,13 +63,13 @@ PosixErrorOr<bool> CanCreateUserNamespace() {
// is in a chroot environment (i.e., the caller's root directory does
// not match the root directory of the mount namespace in which it
// resides)."
- std::cerr << "clone(CLONE_NEWUSER) failed with EPERM";
+ std::cerr << "clone(CLONE_NEWUSER) failed with EPERM" << std::endl;
return false;
} else if (errno == EUSERS) {
// "(since Linux 3.11) CLONE_NEWUSER was specified in flags, and the call
// would cause the limit on the number of nested user namespaces to be
// exceeded. See user_namespaces(7)."
- std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS";
+ std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS" << std::endl;
return false;
} else {
// Unexpected error code; indicate an actual error.
diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl
index 92b0b5fc0..132040c20 100644
--- a/tools/bazeldefs/platforms.bzl
+++ b/tools/bazeldefs/platforms.bzl
@@ -2,15 +2,10 @@
# Platform to associated tags.
platforms = {
- "ptrace": [
- # TODO(b/120560048): Make the tests run without this tag.
- "no-sandbox",
- ],
+ "ptrace": [],
"kvm": [
"manual",
"local",
- # TODO(b/120560048): Make the tests run without this tag.
- "no-sandbox",
],
}
diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl
index c5be52ecd..8c9995fd4 100644
--- a/tools/go_generics/defs.bzl
+++ b/tools/go_generics/defs.bzl
@@ -105,7 +105,6 @@ def _go_template_instance_impl(ctx):
executable = ctx.executable._tool,
)
- # TODO: How can we get the dependencies out?
return struct(
files = depset([output]),
)
diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go
index 9a9a4f298..cd55cf5cb 100644
--- a/tools/go_marshal/analysis/analysis_unsafe.go
+++ b/tools/go_marshal/analysis/analysis_unsafe.go
@@ -161,6 +161,10 @@ func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) {
if typ.NumField() > 0 && nextXOff != int(typ.Size()) {
implicitPad := int(typ.Size()) - nextXOff
f := typ.Field(typ.NumField() - 1) // Final field
+ if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" {
+ // Final field explicitly marked unaligned.
+ break
+ }
t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.",
typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name)
}
diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl
index d79786a68..323e33882 100644
--- a/tools/go_marshal/defs.bzl
+++ b/tools/go_marshal/defs.bzl
@@ -53,9 +53,10 @@ go_marshal = rule(
# marshal_deps are the dependencies requied by generated code.
marshal_deps = [
- "//tools/go_marshal/marshal",
+ "//pkg/gohacks",
"//pkg/safecopy",
"//pkg/usermem",
+ "//tools/go_marshal/marshal",
]
# marshal_test_deps are required by test targets.
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 82983804c..177013dbb 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -28,12 +28,6 @@ import (
"gvisor.dev/gvisor/tools/tags"
)
-const (
- marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal"
- safecopyImport = "gvisor.dev/gvisor/pkg/safecopy"
- usermemImport = "gvisor.dev/gvisor/pkg/usermem"
-)
-
// List of identifiers we use in generated code that may conflict with a
// similarly-named source identifier. Abort gracefully when we see these to
// avoid potentially confusing compilation failures in generated code.
@@ -44,8 +38,8 @@ const (
// All recievers are single letters, so we don't allow import aliases to be a
// single letter.
var badIdents = []string{
- "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len",
- "ptr", "src", "srcs", "task", "val",
+ "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner",
+ "length", "limit", "ptr", "size", "src", "srcs", "task", "val",
// All single-letter identifiers.
}
@@ -110,9 +104,10 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G
g.imports.add("reflect")
g.imports.add("runtime")
g.imports.add("unsafe")
- g.imports.add(marshalImport)
- g.imports.add(safecopyImport)
- g.imports.add(usermemImport)
+ g.imports.add("gvisor.dev/gvisor/pkg/gohacks")
+ g.imports.add("gvisor.dev/gvisor/pkg/safecopy")
+ g.imports.add("gvisor.dev/gvisor/pkg/usermem")
+ g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal")
return &g, nil
}
@@ -194,10 +189,73 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
return files, fsets, nil
}
+// sliceAPI carries information about the '+marshal slice' directive.
+type sliceAPI struct {
+ // Comment node in the AST containing the +marshal tag.
+ comment *ast.Comment
+ // Identifier fragment to use when naming generated functions for the slice
+ // API.
+ ident string
+ // Whether the generated functions should reference the newtype name, or the
+ // inner type name. Only meaningful on newtype declarations on primitives.
+ inner bool
+}
+
+// marshallableType carries information about a type marked with the '+marshal'
+// directive.
+type marshallableType struct {
+ spec *ast.TypeSpec
+ slice *sliceAPI
+}
+
+func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) marshallableType {
+ mt := marshallableType{
+ spec: spec,
+ slice: nil,
+ }
+
+ var unhandledTags []string
+
+ for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) {
+ if strings.HasPrefix(tag, "slice:") {
+ tokens := strings.Split(tag, ":")
+ if len(tokens) < 2 || len(tokens) > 3 {
+ abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag))
+ }
+ if len(tokens[1]) == 0 {
+ abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'")
+ }
+
+ sa := &sliceAPI{
+ comment: tagLine,
+ ident: tokens[1],
+ }
+ mt.slice = sa
+
+ if len(tokens) == 3 {
+ if tokens[2] != "inner" {
+ abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'")
+ }
+ sa.inner = true
+ }
+
+ continue
+ }
+
+ unhandledTags = append(unhandledTags, tag)
+ }
+
+ if len(unhandledTags) > 0 {
+ abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " ")))
+ }
+
+ return mt
+}
+
// collectMarshallableTypes walks the parsed AST and collects a list of type
// declarations for which we need to generate the Marshallable interface.
-func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
- var types []*ast.TypeSpec
+func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType {
+ var types []marshallableType
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
// Type declaration?
@@ -212,9 +270,11 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a
}
// Does the comment contain a "+marshal" line?
marked := false
+ var tagLine *ast.Comment
for _, c := range gdecl.Doc.List {
- if c.Text == "// +marshal" {
+ if strings.HasPrefix(c.Text, "// +marshal") {
marked = true
+ tagLine = c
break
}
}
@@ -229,20 +289,17 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a
switch t.Type.(type) {
case *ast.StructType:
debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name)
- types = append(types, t)
- continue
case *ast.Ident: // Newtype on primitive.
debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
- types = append(types, t)
- continue
case *ast.ArrayType: // Newtype on array.
debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name)
- types = append(types, t)
- continue
+ default:
+ // A user specifically requested marshalling on this type, but we
+ // don't support it.
+ abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
}
- // A user specifically requested marshalling on this type, but we
- // don't support it.
- abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
+ types = append(types, newMarshallableType(f, tagLine, t))
+
}
}
return types
@@ -269,7 +326,7 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
// Make sure we have an import that doesn't use any local names that
// would conflict with identifiers in the generated code.
- if len(i.name) == 1 {
+ if len(i.name) == 1 && i.name != "_" {
abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
}
if _, ok := badIdentsMap[i.name]; ok {
@@ -281,19 +338,28 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
}
-func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- i := newInterfaceGenerator(t, fset)
- switch ty := t.Type.(type) {
+func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator {
+ i := newInterfaceGenerator(t.spec, fset)
+ switch ty := t.spec.Type.(type) {
case *ast.StructType:
- i.validateStruct(t, ty)
+ i.validateStruct(t.spec, ty)
i.emitMarshallableForStruct(ty)
+ if t.slice != nil {
+ i.emitMarshallableSliceForStruct(ty, t.slice)
+ }
case *ast.Ident:
i.validatePrimitiveNewtype(ty)
i.emitMarshallableForPrimitiveNewtype(ty)
+ if t.slice != nil {
+ i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice)
+ }
case *ast.ArrayType:
- i.validateArrayNewtype(t.Name, ty)
+ i.validateArrayNewtype(t.spec.Name, ty)
// After validate, we can safely call arrayLen.
- i.emitMarshallableForArrayNewtype(t.Name, ty.Elt.(*ast.Ident), arrayLen(ty))
+ i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
+ if t.slice != nil {
+ abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?"))
+ }
default:
// This should've been filtered out by collectMarshallabeTypes.
panic(fmt.Sprintf("Unexpected type %+v", ty))
@@ -303,9 +369,9 @@ func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interface
// generateOneTestSuite generates a test suite for the automatically generated
// implementations type t.
-func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator {
- i := newTestGenerator(t)
- i.emitTests()
+func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator {
+ i := newTestGenerator(t.spec)
+ i.emitTests(t.slice)
return i
}
@@ -355,7 +421,7 @@ func (g *Generator) Run() error {
// the list of imports we need to copy to the generated code.
for name, _ := range impl.is {
if !g.imports.markUsed(name) {
- panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name))
+ panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name))
}
}
ts = append(ts, g.generateOneTestSuite(t))
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index 8babf61d2..e3c3dac63 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -15,8 +15,10 @@
package gomarshal
import (
+ "fmt"
"go/ast"
"go/token"
+ "strings"
)
// interfaceGenerator generates marshalling interfaces for a single type.
@@ -72,7 +74,6 @@ func (g *interfaceGenerator) recordUsedMarshallable(m string) {
func (g *interfaceGenerator) recordUsedImport(i string) {
g.is[i] = struct{}{}
-
}
func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
@@ -163,3 +164,113 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) {
g.recordPotentiallyNonPackedField(accessor)
}
}
+
+// emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a
+// byte slice, bypassing escape analysis. The caller is responsible for ensuring
+// srcPtr lives until they're done with dstVar, as the runtime does not consider
+// dstVar dependent on srcPtr due to the escape analysis bypass.
+//
+// srcPtr must be a pointer.
+//
+// This function internally uses the identifier "hdr", and cannot be used
+// in a context where it is already bound.
+func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) {
+ g.recordUsedImport("gohacks")
+ g.emit("// Construct a slice backed by dst's underlying memory.\n")
+ g.emit("var %s []byte\n", dstVar)
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
+ g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr)
+ g.emit("hdr.Len = %s\n", lenExpr)
+ g.emit("hdr.Cap = %s\n\n", lenExpr)
+}
+
+// emitCastSliceToByteSlice unsafely casts a slice with elements of an arbitrary
+// type to a byte slice. As part of the cast, the byte slice is made to look
+// independent of the src slice by bypassing escape analysis. This means the
+// byte slice can be used without causing the source to escape. The caller is
+// responsible for ensuring srcPtr lives until they're done with dstVar, as the
+// runtime no longer considers dstVar dependent on srcPtr and is free to GC it.
+//
+// srcPtr must be a pointer.
+//
+// This function internally uses the identifiers "ptr", "val" and "hdr",
+// and cannot be used in a context where these identifiers are already bound.
+func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) {
+ g.emitNoEscapeSliceDataPointer(srcPtr, "val")
+
+ g.emit("// Construct a slice backed by dst's underlying memory.\n")
+ g.emit("var %s []byte\n", dstVar)
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
+ g.emit("hdr.Data = uintptr(val)\n")
+ g.emit("hdr.Len = %s\n", lenExpr)
+ g.emit("hdr.Cap = %s\n\n", lenExpr)
+}
+
+// emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an
+// unsafe.Pointer, bypassing escape analysis. The caller is responsible for
+// ensuring srcPtr lives until they're done with dstVar, as the runtime no
+// longer considers dstVar dependent on srcPtr and is free to GC it.
+//
+// srcPtr must be a pointer.
+//
+// This function internally uses the identifier "ptr", and cannot be used in a
+// context where this identifier is already bound.
+func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) {
+ g.recordUsedImport("gohacks")
+ g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr)
+ g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar)
+}
+
+func (g *interfaceGenerator) emitKeepAlive(ptrVar string) {
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar)
+ g.emit("// must live until the use above.\n")
+ g.emit("runtime.KeepAlive(%s)\n", ptrVar)
+}
+
+func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) {
+ switch x := e.X.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, x)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", x.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", x.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+
+ fmt.Fprintf(b, "%s", e.Op)
+
+ switch y := e.Y.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, y)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", y.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", y.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+}
+
+// arrayLenExpr returns a string containing a valid golang expression
+// representing the length of array a. The returned expression should be treated
+// as a single value, and will be already parenthesized as required.
+func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string {
+ var b strings.Builder
+
+ switch l := a.Len.(type) {
+ case *ast.Ident:
+ fmt.Fprintf(&b, "%s", l.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(&b, "%s", l.Value)
+ case *ast.BinaryExpr:
+ g.expandBinaryExpr(&b, l)
+ return fmt.Sprintf("(%s)", b.String())
+ default:
+ g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+ return b.String()
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
index da36d9305..8d6f102d5 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
@@ -27,20 +27,12 @@ func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType
g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
}
- if _, ok := a.Len.(*ast.BasicLit); !ok {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions"))
- }
-
if _, ok := a.Elt.(*ast.Ident); !ok {
g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
}
-
- if arrayLen(a) <= 0 {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
- }
}
-func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) {
+func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) {
g.recordUsedImport("io")
g.recordUsedImport("marshal")
g.recordUsedImport("reflect")
@@ -49,13 +41,15 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
g.recordUsedImport("unsafe")
g.recordUsedImport("usermem")
+ lenExpr := g.arrayLenExpr(a)
+
g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
g.inIndent(func() {
if size, dynamic := g.scalarSize(elt); !dynamic {
- g.emit("return %d\n", size*len)
+ g.emit("return %d * %s\n", size, lenExpr)
} else {
- g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len)
+ g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr)
}
})
g.emit("}\n\n")
@@ -63,7 +57,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("for idx := 0; idx < %d; idx++ {\n", len)
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst")
})
@@ -74,7 +68,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("for idx := 0; idx < %d; idx++ {\n", len)
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src")
})
@@ -104,79 +98,43 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
})
g.emit("}\n\n")
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit])\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
- // Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyOutBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
})
g.emit("}\n\n")
g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyInBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
})
g.emit("}\n\n")
g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("len, err := w.Write(buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the Write.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return int64(len), err\n")
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
})
g.emit("}\n\n")
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
index 159397825..ef9bb903d 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
@@ -150,80 +150,133 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident)
})
g.emit("}\n\n")
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit])\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
- // Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyOutBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
})
g.emit("}\n\n")
g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyInBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
})
g.emit("}\n\n")
g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("len, err := w.Write(buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the Write.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return int64(len), err\n")
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+
+ })
+ g.emit("}\n\n")
+}
+
+func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) {
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+
+ eltType := g.typeName()
+ if slice.inner {
+ eltType = nt.Name
+ }
+
+ g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType)
+ g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType)
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType)
+ g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType)
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitCastSliceToByteSlice("&src", "buf", "size * count")
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitNoEscapeSliceDataPointer("&src", "val")
+
+ g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitNoEscapeSliceDataPointer("&dst", "val")
+ g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
})
g.emit("}\n\n")
}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
index e66a38b2e..4236e978e 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -62,8 +62,8 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType
// No validation to perform on selector fields. However this
// callback must still be provided.
},
- array: func(n, _ *ast.Ident, len int) {
- g.validateArrayNewtype(n, f.Type.(*ast.ArrayType))
+ array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) {
+ g.validateArrayNewtype(n, a)
},
unhandled: func(_ *ast.Ident) {
g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
@@ -72,20 +72,24 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType
})
}
-func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
- // Is g.t a packed struct without consideing field types?
- thisPacked := true
+func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool {
+ packed := true
forEachStructField(st, func(f *ast.Field) {
if f.Tag != nil {
if f.Tag.Value == "`marshal:\"unaligned\"`" {
- if thisPacked {
+ if packed {
debugfAt(g.f.Position(g.t.Pos()),
fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
- thisPacked = false
+ packed = false
}
}
}
})
+ return packed
+}
+
+func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
+ thisPacked := g.isStructPacked(st)
g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
@@ -108,16 +112,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.recordUsedMarshallable(tName)
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
},
- array: func(n, t *ast.Ident, len int) {
- if len < 1 {
- // Zero-length arrays should've been rejected by validate().
- panic("unreachable")
- }
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size * len
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
} else {
g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
}
},
}.dispatch)
@@ -148,7 +149,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.shift("dst", len)
} else {
// We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
+ // an instance of the dynamic type we can reference here
// (since the version in this struct is anonymous). Use
// a typed nil pointer to call SizeBytes() instead.
g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
@@ -158,24 +159,30 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
},
selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name)
+ g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ return
+ }
g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
},
- array: func(n, t *ast.Ident, size int) {
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len*size)
+ g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
} else {
// We can't use shiftDynamic here because we don't have
// an instance of the dynamic type we can reference here
// (since the version in this struct is anonymous). Use
// a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
}
return
}
- g.emit("for idx := 0; idx < %d; idx++ {\n", size)
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
})
@@ -195,11 +202,11 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
if len, dynamic := g.scalarSize(t); !dynamic {
g.shift("src", len)
} else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
+ // We don't have an instance of the dynamic type we can
+ // reference here (since the version in this struct is
+ // anonymous). Use a typed nil pointer to call
+ // SizeBytes() instead.
+ g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name))
g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
}
return
@@ -207,24 +214,31 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
},
selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name)
+ g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name))
+ return
+ }
g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
},
- array: func(n, t *ast.Ident, size int) {
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
if n.Name == "_" {
- g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len*size)
+ g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("src = src[%d*(%s):]\n", size, lenExpr)
} else {
// We can't use shiftDynamic here because we don't have
// an instance of the dynamic type we can referece here
// (since the version in this struct is anonymous). Use
// a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
}
return
}
- g.emit("for idx := 0; idx < %d; idx++ {\n", size)
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
})
@@ -302,17 +316,16 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
})
g.emit("}\n\n")
- g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
g.recordUsedImport("marshal")
g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("return err\n")
+ g.emit("return task.CopyOutBytes(addr, buf[:limit])\n")
}
if thisPacked {
g.recordUsedImport("reflect")
@@ -324,48 +337,39 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("}\n\n")
}
// Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyOutBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit])\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
} else {
fallback()
}
})
g.emit("}\n\n")
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
g.recordUsedImport("marshal")
g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("if err != nil {\n")
- g.inIndent(func() {
- g.emit("return err\n")
- })
- g.emit("}\n")
-
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n")
+ g.emit("// partially unmarshalled struct.\n")
g.emit("%s.UnmarshalBytes(buf)\n", g.r)
- g.emit("return nil\n")
+ g.emit("return length, err\n")
}
if thisPacked {
g.recordUsedImport("reflect")
@@ -377,25 +381,11 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("}\n\n")
}
// Fast deserialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyInBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
} else {
fallback()
}
@@ -410,8 +400,8 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("n, err := w.Write(buf)\n")
- g.emit("return int64(n), err\n")
+ g.emit("length, err := w.Write(buf)\n")
+ g.emit("return int64(length), err\n")
}
if thisPacked {
g.recordUsedImport("reflect")
@@ -423,25 +413,199 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("}\n\n")
}
// Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("len, err := w.Write(buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the Write.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return int64(len), err\n")
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
+
+func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) {
+ thisPacked := g.isStructPacked(st)
+
+ if slice.inner {
+ abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident))
+ }
+
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+
+ g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName())
+ g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(size * count)\n")
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n\n")
+
+ g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n")
+ g.emit("limit := length/size\n")
+ g.emit("for idx := 0; idx < limit; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Handle any final partial object.\n")
+ g.emit("if length < size*count && length%size != 0 {\n")
+ g.inIndent(func() {
+ g.emit("idx := limit\n")
+ g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("return length, err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !dst[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName())
+ g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(size * count)\n")
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return task.CopyOutBytes(addr, buf)\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !src[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastSliceToByteSlice("&src", "buf", "size * count")
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return size * count, nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !src[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ g.emitNoEscapeSliceDataPointer("&src", "val")
+
+ g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return size * count, nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !dst[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ g.emitNoEscapeSliceDataPointer("&dst", "val")
+
+ g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
} else {
fallback()
}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index fd992e44a..631295373 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -30,6 +30,11 @@ var standardImports = []string{
"gvisor.dev/gvisor/tools/go_marshal/analysis",
}
+var sliceAPIImports = []string{
+ "encoding/binary",
+ "gvisor.dev/gvisor/pkg/usermem",
+}
+
type testGenerator struct {
sourceBuffer
@@ -58,6 +63,11 @@ func newTestGenerator(t *ast.TypeSpec) *testGenerator {
for _, i := range standardImports {
g.imports.add(i).markUsed()
}
+ // These imports are used if a type requests the slice API. Don't
+ // mark them as used by default.
+ for _, i := range sliceAPIImports {
+ g.imports.add(i)
+ }
return g
}
@@ -132,6 +142,42 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
})
}
+func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) {
+ for _, name := range []string{"binary", "usermem"} {
+ if !g.imports.markUsed(name) {
+ panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name))
+ }
+ }
+
+ g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() {
+ g.emit("var x, y, yUnsafe [8]%s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+ g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName())
+ g.emit("buf := bytes.NewBuffer(make([]byte, size))\n")
+ g.emit("buf.Reset()\n")
+ g.emit("if err := binary.Write(buf, usermem.ByteOrder, x[:]); err != nil {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n")
+ })
+ g.emit("}\n")
+ g.emit("bufUnsafe := make([]byte, size)\n")
+ g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident)
+
+ g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident)
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident)
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n\n")
+ })
+}
+
func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() {
g.inTestFunction("TestWriteToUnmarshalPreservesData", func() {
g.emit("var x, y, yUnsafe %s\n", g.typeName())
@@ -170,12 +216,16 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
})
}
-func (g *testGenerator) emitTests() {
+func (g *testGenerator) emitTests(slice *sliceAPI) {
g.emitTestNonZeroSize()
g.emitTestSuspectAlignment()
g.emitTestMarshalUnmarshalPreservesData()
g.emitTestWriteToUnmarshalPreservesData()
g.emitTestSizeBytesOnTypedNilPtr()
+
+ if slice != nil {
+ g.emitTestMarshalUnmarshalSlicePreservesData(slice)
+ }
}
func (g *testGenerator) write(out io.Writer) error {
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
index a0936e013..d94314302 100644
--- a/tools/go_marshal/gomarshal/util.go
+++ b/tools/go_marshal/gomarshal/util.go
@@ -25,7 +25,6 @@ import (
"path"
"reflect"
"sort"
- "strconv"
"strings"
)
@@ -75,29 +74,10 @@ func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
type fieldDispatcher struct {
primitive func(n, t *ast.Ident)
selector func(n, tX, tSel *ast.Ident)
- array func(n, t *ast.Ident, size int)
+ array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident)
unhandled func(n *ast.Ident)
}
-// Precondition: a must have a literal for the array length. Consts and
-// expressions are not allowed as array lengths, and should be rejected by the
-// caller.
-func arrayLen(a *ast.ArrayType) int {
- if a.Len == nil {
- // Probably a slice? Must be handled by caller.
- panic("Nil array length in array type")
- }
- lenLit, ok := a.Len.(*ast.BasicLit)
- if !ok {
- panic("Array has non-literal for length")
- }
- len, err := strconv.Atoi(lenLit.Value)
- if err != nil {
- panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err))
- }
- return len
-}
-
// Precondition: All dispatch callbacks that will be invoked must be
// provided. Embedded fields are not allowed, len(f.Names) >= 1.
func (fd fieldDispatcher) dispatch(f *ast.Field) {
@@ -123,7 +103,7 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) {
case *ast.ArrayType:
switch t := v.Elt.(type) {
case *ast.Ident:
- fd.array(name, t, arrayLen(v))
+ fd.array(name, v, t)
default:
// Should be handled with a better error message during validate.
panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
@@ -285,6 +265,11 @@ type importStmt struct {
aliased bool
// Indicates whether this import was referenced by generated code.
used bool
+ // AST node and file set representing the import statement, if any. These
+ // are only non-nil if the import statement originates from an input source
+ // file.
+ spec *ast.ImportSpec
+ fset *token.FileSet
}
func newImport(p string) *importStmt {
@@ -310,14 +295,27 @@ func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
name: name,
path: p,
aliased: spec.Name != nil,
+ spec: spec,
+ fset: f,
}
}
+// String implements fmt.Stringer.String. This generates a string for the import
+// statement appropriate for writing directly to generated code.
func (i *importStmt) String() string {
if i.aliased {
- return fmt.Sprintf("%s \"%s\"", i.name, i.path)
+ return fmt.Sprintf("%s %q", i.name, i.path)
+ }
+ return fmt.Sprintf("%q", i.path)
+}
+
+// debugString returns a debug string representing an import statement. This
+// representation is not valid golang code and is used for debugging output.
+func (i *importStmt) debugString() string {
+ if i.spec != nil && i.fset != nil {
+ return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i)
}
- return fmt.Sprintf("\"%s\"", i.path)
+ return fmt.Sprintf("(go-marshal import): %s", i)
}
func (i *importStmt) markUsed() {
@@ -329,40 +327,78 @@ func (i *importStmt) equivalent(other *importStmt) bool {
}
// importTable represents a collection of importStmts.
+//
+// An importTable may contain multiple import statements referencing the same
+// local name. All import statements aliasing to the same local name are
+// technically ambiguous, as if such an import name is used in the generated
+// code, it's not clear which import statement it refers to. We ignore any
+// potential collisions until actually writing the import table to the generated
+// source file. See importTable.write.
+//
+// Given the following import statements across all the files comprising a
+// package marshalled:
+//
+// "sync"
+// "pkg/sync"
+// "pkg/sentry/kernel"
+// ktime "pkg/sentry/kernel/time"
+//
+// An importTable representing them would look like this:
+//
+// importTable {
+// is: map[string][]*importStmt {
+// "sync": []*importStmt{
+// importStmt{name:"sync", path:"sync", aliased:false}
+// importStmt{name:"sync", path:"pkg/sync", aliased:false}
+// },
+// "kernel": []*importStmt{importStmt{
+// name: "kernel",
+// path: "pkg/sentry/kernel",
+// aliased: false
+// }},
+// "ktime": []*importStmt{importStmt{
+// name: "ktime",
+// path: "pkg/sentry/kernel/time",
+// aliased: true,
+// }},
+// }
+// }
+//
+// Note that the local name "sync" is assigned to two different import
+// statements. This is possible if the import statements are from different
+// source files in the same package.
+//
+// Since go-marshal generates a single output file per package regardless of the
+// number of input files, if "sync" is referenced by any generated code, it's
+// unclear which import statement "sync" refers to. While it's theoretically
+// possible to resolve this by assigning a unique local alias to each instance
+// of the sync package, go-marshal currently aborts when it encounters such an
+// ambiguity.
+//
+// TODO(b/151478251): importTable considers the final component of an import
+// path to be the package name, but this is only a convention. The actual
+// package name is determined by the package statement in the source files for
+// the package.
type importTable struct {
// Map of imports and whether they should be copied to the output.
- is map[string]*importStmt
+ is map[string][]*importStmt
}
func newImportTable() *importTable {
return &importTable{
- is: make(map[string]*importStmt),
+ is: make(map[string][]*importStmt),
}
}
-// Merges import statements from other into i. Collisions in import statements
-// result in a panic.
+// Merges import statements from other into i.
func (i *importTable) merge(other *importTable) {
- for name, im := range other.is {
- if dup, ok := i.is[name]; ok && !dup.equivalent(im) {
- panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im))
- }
-
- i.is[name] = im
+ for name, ims := range other.is {
+ i.is[name] = append(i.is[name], ims...)
}
}
func (i *importTable) addStmt(s *importStmt) *importStmt {
- if old, ok := i.is[s.name]; ok && !old.equivalent(s) {
- // A collision should always be between an import inserted by the
- // go-marshal tool and an import from the original source file (assuming
- // the original source file was valid). We could theoretically handle
- // the collision by assigning a local name to our import. However, this
- // would need to be plumbed throughout the generator. Given that
- // collisions should be rare, simply panic on collision.
- panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name))
- }
- i.is[s.name] = s
+ i.is[s.name] = append(i.is[s.name], s)
return s
}
@@ -378,16 +414,20 @@ func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *impor
// Marks the import named n as used. If no such import is in the table, returns
// false.
func (i *importTable) markUsed(n string) bool {
- if n, ok := i.is[n]; ok {
- n.markUsed()
+ if ns, ok := i.is[n]; ok {
+ for _, n := range ns {
+ n.markUsed()
+ }
return true
}
return false
}
func (i *importTable) clear() {
- for _, i := range i.is {
- i.used = false
+ for _, is := range i.is {
+ for _, i := range is {
+ i.used = false
+ }
}
}
@@ -398,9 +438,42 @@ func (i *importTable) write(out io.Writer) error {
}
imports := make([]string, 0, len(i.is))
- for _, i := range i.is {
- if i.used {
- imports = append(imports, i.String())
+ for name, is := range i.is {
+ var lastUsed *importStmt
+ var ambiguous bool
+
+ for _, i := range is {
+ if i.used {
+ if lastUsed != nil {
+ if !i.equivalent(lastUsed) {
+ ambiguous = true
+ }
+ }
+ lastUsed = i
+ }
+ }
+
+ if ambiguous {
+ // We have two or more import statements across the different source
+ // files that share a local name, and at least one of these imports
+ // is used by the generated code. This ambiguity can't be resolved
+ // by go-marshal and requires user intervention. Dump a list of
+ // the colliding import statements and let the user modify the input
+ // files as appropriate.
+ var b strings.Builder
+ fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name)
+ fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name)
+ // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't
+ // be true. Therefore the slicing below is safe.
+ for _, i := range is[:len(is)-1] {
+ fmt.Fprintf(&b, " %v\n", i.debugString())
+ }
+ fmt.Fprintf(&b, " %v", is[len(is)-1].debugString())
+ panic(b.String())
+ }
+
+ if lastUsed != nil {
+ imports = append(imports, lastUsed.String())
}
}
sort.Strings(imports)
diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go
index f129788e0..cb2166252 100644
--- a/tools/go_marshal/marshal/marshal.go
+++ b/tools/go_marshal/marshal/marshal.go
@@ -42,7 +42,11 @@ type Task interface {
CopyInBytes(addr usermem.Addr, b []byte) (int, error)
}
-// Marshallable represents a type that can be marshalled to and from memory.
+// Marshallable represents operations on a type that can be marshalled to and
+// from memory.
+//
+// go-marshal automatically generates implementations for this interface for
+// types marked as '+marshal'.
type Marshallable interface {
io.WriterTo
@@ -54,12 +58,18 @@ type Marshallable interface {
// likely make use of the type of these fields).
SizeBytes() int
- // MarshalBytes serializes a copy of a type to dst. dst must be at least
- // SizeBytes() long.
+ // MarshalBytes serializes a copy of a type to dst. dst may be smaller than
+ // SizeBytes(), which results in a part of the struct being marshalled. Note
+ // that this may have unexpected results for non-packed types, as implicit
+ // padding needs to be taken into account when reasoning about how much of
+ // the type is serialized.
MarshalBytes(dst []byte)
- // UnmarshalBytes deserializes a type from src. src must be at least
- // SizeBytes() long.
+ // UnmarshalBytes deserializes a type from src. src may be smaller than
+ // SizeBytes(), which results in a partially deserialized struct. Note that
+ // this may have unexpected results for non-packed types, as implicit
+ // padding needs to be taken into account when reasoning about how much of
+ // the type is deserialized.
UnmarshalBytes(src []byte)
// Packed returns true if the marshalled size of the type is the same as the
@@ -67,13 +77,20 @@ type Marshallable interface {
// starting at unaligned addresses (should always be true by default for ABI
// structs, verified by automatically generated tests when using
// go_marshal), and has no fields marked `marshal:"unaligned"`.
+ //
+ // Packed must return the same result for all possible values of the type
+ // implementing it. Violating this constraint implies the type doesn't have
+ // a static memory layout, and will lead to memory corruption.
+ // Go-marshal-generated code reuses the result of Packed for multiple values
+ // of the same type.
Packed() bool
// MarshalUnsafe serializes a type by bulk copying its in-memory
// representation to the dst buffer. This is only safe to do when the type
// has no implicit padding, see Marshallable.Packed. When Packed would
// return false, MarshalUnsafe should fall back to the safer but slower
- // MarshalBytes.
+ // MarshalBytes. dst may be smaller than SizeBytes(), see comment for
+ // MarshalBytes for implications.
MarshalUnsafe(dst []byte)
// UnmarshalUnsafe deserializes a type by directly copying to the underlying
@@ -82,7 +99,8 @@ type Marshallable interface {
// This allows much faster unmarshalling of types which have no implicit
// padding, see Marshallable.Packed. When Packed would return false,
// UnmarshalUnsafe should fall back to the safer but slower unmarshal
- // mechanism implemented in UnmarshalBytes.
+ // mechanism implemented in UnmarshalBytes. src may be smaller than
+ // SizeBytes(), see comment for UnmarshalBytes for implications.
UnmarshalUnsafe(src []byte)
// CopyIn deserializes a Marshallable type from a task's memory. This may
@@ -91,12 +109,79 @@ type Marshallable interface {
// marshalled does not escape. The implementation should avoid creating
// extra copies in memory by directly deserializing to the object's
// underlying memory.
- CopyIn(task Task, addr usermem.Addr) error
+ //
+ // If the copy-in from the task memory is only partially successful, CopyIn
+ // should still attempt to deserialize as much data as possible. See comment
+ // for UnmarshalBytes.
+ CopyIn(task Task, addr usermem.Addr) (int, error)
// CopyOut serializes a Marshallable type to a task's memory. This may only
// be called from a task goroutine. This is more efficient than calling
// MarshalUnsafe on Marshallable.Packed types, as the type being serialized
// does not escape. The implementation should avoid creating extra copies in
// memory by directly serializing from the object's underlying memory.
- CopyOut(task Task, addr usermem.Addr) error
+ //
+ // The copy-out to the task memory may be partially successful, in which
+ // case CopyOut returns how much data was serialized. See comment for
+ // MarshalBytes for implications.
+ CopyOut(task Task, addr usermem.Addr) (int, error)
+
+ // CopyOutN is like CopyOut, but explicitly requests a partial
+ // copy-out. Note that this may yield unexpected results for non-packed
+ // types and the caller may only want to allow this for packed types. See
+ // comment on MarshalBytes.
+ //
+ // The limit must be less than or equal to SizeBytes().
+ CopyOutN(task Task, addr usermem.Addr, limit int) (int, error)
}
+
+// go-marshal generates additional functions for a type based on additional
+// clauses to the +marshal directive. They are documented below.
+//
+// Slice API
+// =========
+//
+// Adding a "slice" clause to the +marshal directive for structs or newtypes on
+// primitives like this:
+//
+// // +marshal slice:FooSlice
+// type Foo struct { ... }
+//
+// Generates four additional functions for marshalling slices of Foos like this:
+//
+// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, but for a []Foo. It's
+// // more efficient than repeatedly calling Foo.MarshalUnsafe over a []Foo
+// // in a loop.
+// func MarshalUnsafeFooSlice(src []Foo, dst []byte) (int, error) { ... }
+//
+// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, but for a []Foo. It's
+// // more efficient than repeatedly calling Foo.UnmarshalUnsafe over a
+// // []Foo in a loop.
+// func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... }
+//
+// // CopyFooSliceIn copies in a slice of Foo objects from the task's memory.
+// func CopyFooSliceIn(task marshal.Task, addr usermem.Addr, dst []Foo) (int, error) { ... }
+//
+// // CopyFooSliceOut copies out a slice of Foo objects to the task's memory.
+// func CopyFooSliceOut(task marshal.Task, addr usermem.Addr, src []Foo) (int, error) { ... }
+//
+// The generated functions are named "MarshalUnsafe%s", "UnmarshalUnsafe%s",
+// "Copy%sIn" and "Copy%sOut", where %s is the first argument to the slice
+// clause. This directive is not supported for newtypes on arrays.
+//
+// The slice clause also takes an optional second argument, which must be the
+// value "inner":
+//
+// // +marshal slice:Int32Slice:inner
+// type Int32 int32
+//
+// This is only valid on newtypes on primitives, and causes the generated
+// functions to accept slices of the inner type instead:
+//
+// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []int32) (int, error) { ... }
+//
+// Without "inner", they would instead be:
+//
+// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []Int32) (int, error) { ... }
+//
+// This may help avoid a cast depending on how the generated functions are used.
diff --git a/tools/go_marshal/primitive/BUILD b/tools/go_marshal/primitive/BUILD
new file mode 100644
index 000000000..cc08ba63a
--- /dev/null
+++ b/tools/go_marshal/primitive/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "primitive",
+ srcs = [
+ "primitive.go",
+ ],
+ marshal = True,
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ ],
+)
diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go
new file mode 100644
index 000000000..ebcf130ae
--- /dev/null
+++ b/tools/go_marshal/primitive/primitive.go
@@ -0,0 +1,175 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package primitive defines marshal.Marshallable implementations for primitive
+// types.
+package primitive
+
+import (
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// Int16 is a marshal.Marshallable implementation for int16.
+//
+// +marshal slice:Int16Slice:inner
+type Int16 int16
+
+// Uint16 is a marshal.Marshallable implementation for uint16.
+//
+// +marshal slice:Uint16Slice:inner
+type Uint16 uint16
+
+// Int32 is a marshal.Marshallable implementation for int32.
+//
+// +marshal slice:Int32Slice:inner
+type Int32 int32
+
+// Uint32 is a marshal.Marshallable implementation for uint32.
+//
+// +marshal slice:Uint32Slice:inner
+type Uint32 uint32
+
+// Int64 is a marshal.Marshallable implementation for int64.
+//
+// +marshal slice:Int64Slice:inner
+type Int64 int64
+
+// Uint64 is a marshal.Marshallable implementation for uint64.
+//
+// +marshal slice:Uint64Slice:inner
+type Uint64 uint64
+
+// Below, we define some convenience functions for marshalling primitive types
+// using the newtypes above, without requiring superfluous casts.
+
+// 16-bit integers
+
+// CopyInt16In is a convenient wrapper for copying in an int16 from the task's
+// memory.
+func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error) {
+ var buf Int16
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int16(buf)
+ return n, nil
+}
+
+// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's
+// memory.
+func CopyInt16Out(task marshal.Task, addr usermem.Addr, src int16) (int, error) {
+ srcP := Int16(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's
+// memory.
+func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error) {
+ var buf Uint16
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint16(buf)
+ return n, nil
+}
+
+// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's
+// memory.
+func CopyUint16Out(task marshal.Task, addr usermem.Addr, src uint16) (int, error) {
+ srcP := Uint16(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// 32-bit integers
+
+// CopyInt32In is a convenient wrapper for copying in an int32 from the task's
+// memory.
+func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error) {
+ var buf Int32
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int32(buf)
+ return n, nil
+}
+
+// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's
+// memory.
+func CopyInt32Out(task marshal.Task, addr usermem.Addr, src int32) (int, error) {
+ srcP := Int32(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's
+// memory.
+func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error) {
+ var buf Uint32
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint32(buf)
+ return n, nil
+}
+
+// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's
+// memory.
+func CopyUint32Out(task marshal.Task, addr usermem.Addr, src uint32) (int, error) {
+ srcP := Uint32(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// 64-bit integers
+
+// CopyInt64In is a convenient wrapper for copying in an int64 from the task's
+// memory.
+func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error) {
+ var buf Int64
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int64(buf)
+ return n, nil
+}
+
+// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's
+// memory.
+func CopyInt64Out(task marshal.Task, addr usermem.Addr, src int64) (int, error) {
+ srcP := Int64(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's
+// memory.
+func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error) {
+ var buf Uint64
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint64(buf)
+ return n, nil
+}
+
+// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's
+// memory.
+func CopyUint64Out(task marshal.Task, addr usermem.Addr, src uint64) (int, error) {
+ srcP := Uint64(src)
+ return srcP.CopyOut(task, addr)
+}
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
index f27c5ce52..3b839799d 100644
--- a/tools/go_marshal/test/BUILD
+++ b/tools/go_marshal/test/BUILD
@@ -39,3 +39,17 @@ go_binary(
"//tools/go_marshal/marshal",
],
)
+
+go_test(
+ name = "marshal_test",
+ size = "small",
+ srcs = ["marshal_test.go"],
+ deps = [
+ ":test",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//tools/go_marshal/analysis",
+ "//tools/go_marshal/marshal",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go
index c79defe9e..224d308c7 100644
--- a/tools/go_marshal/test/benchmark_test.go
+++ b/tools/go_marshal/test/benchmark_test.go
@@ -176,3 +176,45 @@ func BenchmarkGoMarshalUnsafe(b *testing.B) {
panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
}
}
+
+func BenchmarkBinarySlice(b *testing.B) {
+ var s1, s2 [64]test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := binary.Size(s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, &s1)
+ binary.Unmarshal(buf, usermem.ByteOrder, &s2)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+func BenchmarkGoMarshalUnsafeSlice(b *testing.B) {
+ var s1, s2 [64]test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, (*test.Stat)(nil).SizeBytes()*len(s1))
+ test.MarshalUnsafeStatSlice(s1[:], buf)
+ test.UnmarshalUnsafeStatSlice(s2[:], buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
diff --git a/tools/go_marshal/test/external/external.go b/tools/go_marshal/test/external/external.go
index 4be3722f3..26fe8e0c8 100644
--- a/tools/go_marshal/test/external/external.go
+++ b/tools/go_marshal/test/external/external.go
@@ -21,3 +21,11 @@ package external
type External struct {
j int64
}
+
+// NotPacked is an unaligned Marshallable type for use in testing.
+//
+// +marshal
+type NotPacked struct {
+ a int32
+ b byte `marshal:"unaligned"`
+}
diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go
new file mode 100644
index 000000000..16829ee45
--- /dev/null
+++ b/tools/go_marshal/test/marshal_test.go
@@ -0,0 +1,515 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package marshal_test contains manual tests for the marshal interface. These
+// are intended to test behaviour not covered by the automatically generated
+// tests.
+package marshal_test
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "reflect"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/analysis"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+var simulatedErr error = syserror.EFAULT
+
+// mockTask implements marshal.Task.
+type mockTask struct {
+ taskMem usermem.BytesIO
+}
+
+// populate fills the task memory with the contents of val.
+func (t *mockTask) populate(val interface{}) {
+ var buf bytes.Buffer
+ // Use binary.Write so we aren't testing go-marshal against its own
+ // potentially buggy implementation.
+ if err := binary.Write(&buf, usermem.ByteOrder, val); err != nil {
+ panic(err)
+ }
+ t.taskMem.Bytes = buf.Bytes()
+}
+
+func (t *mockTask) setLimit(n int) {
+ if len(t.taskMem.Bytes) < n {
+ grown := make([]byte, n)
+ copy(grown, t.taskMem.Bytes)
+ t.taskMem.Bytes = grown
+ return
+ }
+ t.taskMem.Bytes = t.taskMem.Bytes[:n]
+}
+
+// CopyScratchBuffer implements marshal.Task.CopyScratchBuffer.
+func (t *mockTask) CopyScratchBuffer(size int) []byte {
+ return make([]byte, size)
+}
+
+// CopyOutBytes implements marshal.Task.CopyOutBytes. The implementation
+// completely ignores the target address and stores a copy of b in its
+// internal buffer, overwriting any previous contents.
+func (t *mockTask) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) {
+ return t.taskMem.CopyOut(nil, 0, b, usermem.IOOpts{})
+}
+
+// CopyInBytes implements marshal.Task.CopyInBytes. The implementation
+// completely ignores the source address and always fills b from the beginning of
+// its internal buffer.
+func (t *mockTask) CopyInBytes(_ usermem.Addr, b []byte) (int, error) {
+ return t.taskMem.CopyIn(nil, 0, b, usermem.IOOpts{})
+}
+
+// unsafeMemory returns the underlying memory for m. The returned slice is only
+// valid for the lifetime of m. The garbage collector isn't aware that the
+// returned slice is related to m, so the caller must ensure m lives long enough.
+func unsafeMemory(m marshal.Marshallable) []byte {
+ if !m.Packed() {
+ // We can't return a slice pointing to the underlying memory
+ // since the layout isn't packed. Allocate a temporary buffer
+ // and marshal instead.
+ var buf bytes.Buffer
+ if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil {
+ panic(err)
+ }
+ return buf.Bytes()
+ }
+
+ // reflect.ValueOf(m)
+ // .Elem() // Unwrap interface to inner concrete object
+ // .Addr() // Pointer value to object
+ // .Pointer() // Actual address from the pointer value
+ ptr := reflect.ValueOf(m).Elem().Addr().Pointer()
+
+ size := m.SizeBytes()
+
+ var mem []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
+ hdr.Data = ptr
+ hdr.Len = size
+ hdr.Cap = size
+
+ return mem
+}
+
+// unsafeMemorySlice returns the underlying memory for m. The returned slice is
+// only valid for the lifetime of m. The garbage collector isn't aware that the
+// returned slice is related to m, so the caller must ensure m lives long enough.
+//
+// Precondition: m must be a slice.
+func unsafeMemorySlice(m interface{}, elt marshal.Marshallable) []byte {
+ kind := reflect.TypeOf(m).Kind()
+ if kind != reflect.Slice {
+ panic("unsafeMemorySlice called on non-slice")
+ }
+
+ if !elt.Packed() {
+ // We can't return a slice pointing to the underlying memory
+ // since the layout isn't packed. Allocate a temporary buffer
+ // and marshal instead.
+ var buf bytes.Buffer
+ if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil {
+ panic(err)
+ }
+ return buf.Bytes()
+ }
+
+ v := reflect.ValueOf(m)
+ length := v.Len() * elt.SizeBytes()
+
+ var mem []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
+ hdr.Data = v.Pointer() // This is a pointer to the first elem for slices.
+ hdr.Len = length
+ hdr.Cap = length
+
+ return mem
+}
+
+func isZeroes(buf []byte) bool {
+ for _, b := range buf {
+ if b != 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// compareMemory compares the first n bytes of two chunks of memory represented
+// by expected and actual.
+func compareMemory(t *testing.T, expected, actual []byte, n int) {
+ t.Logf("Expected (%d): %v (%d) + (%d) %v\n", len(expected), expected[:n], n, len(expected)-n, expected[n:])
+ t.Logf("Actual (%d): %v (%d) + (%d) %v\n", len(actual), actual[:n], n, len(actual)-n, actual[n:])
+
+ if diff := cmp.Diff(expected[:n], actual[:n]); diff != "" {
+ t.Errorf("Memory buffers don't match:\n--- expected only\n+++ actual only\n%v", diff)
+ }
+}
+
+// limitedCopyIn populates task memory with src, then unmarshals task memory to
+// dst. The task signals an error at limit bytes during copy-in, which should
+// result in a truncated unmarshalling.
+func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) {
+ var task mockTask
+ task.populate(src)
+ task.setLimit(limit)
+
+ n, err := dst.CopyIn(&task, usermem.Addr(0))
+ if n != limit {
+ t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if err != simulatedErr {
+ t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := unsafeMemory(dst)
+ defer runtime.KeepAlive(dst)
+
+ compareMemory(t, expectedMem, actualMem, n)
+
+ // The bytes after the first n should be zero for actual, since actual was
+ // zero-initialized, and CopyIn shouldn't have touched those bytes. However
+ // we can only guarantee we didn't touch anything past the first n bytes if
+ // the layout is packed.
+ if dst.Packed() && !isZeroes(actualMem[n:]) {
+ t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", dst.SizeBytes()-n, actualMem)
+ }
+}
+
+// limitedCopyOut marshals src to task memory. The task signals an error at
+// limit bytes during copy-out, which should result in a truncated marshalling.
+func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) {
+ var task mockTask
+ task.setLimit(limit)
+
+ n, err := src.CopyOut(&task, usermem.Addr(0))
+ if n != limit {
+ t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if err != simulatedErr {
+ t.Errorf("CopyOut returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := task.taskMem.Bytes
+
+ compareMemory(t, expectedMem, actualMem, n)
+}
+
+// copyOutN marshals src to task memory, requesting the marshalling to be
+// limited to limit bytes.
+func copyOutN(t *testing.T, src marshal.Marshallable, limit int) {
+ var task mockTask
+ task.setLimit(limit)
+
+ n, err := src.CopyOutN(&task, usermem.Addr(0), limit)
+ if err != nil {
+ t.Errorf("CopyOutN returned unexpected error: %v", err)
+ }
+ if n != limit {
+ t.Errorf("CopyOutN copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := task.taskMem.Bytes
+
+ t.Logf("Expected: %v + %v\n", expectedMem[:n], expectedMem[n:])
+ t.Logf("Actual : %v + %v\n", actualMem[:n], actualMem[n:])
+
+ compareMemory(t, expectedMem, actualMem, n)
+}
+
+// TestLimitedMarshalling verifies marshalling/unmarshalling succeeds when the
+// underlying copy in/out operations partially succeed.
+func TestLimitedMarshalling(t *testing.T) {
+ types := []reflect.Type{
+ // Packed types.
+ reflect.TypeOf((*test.Type2)(nil)),
+ reflect.TypeOf((*test.Type3)(nil)),
+ reflect.TypeOf((*test.Timespec)(nil)),
+ reflect.TypeOf((*test.Stat)(nil)),
+ reflect.TypeOf((*test.InetAddr)(nil)),
+ reflect.TypeOf((*test.SignalSet)(nil)),
+ reflect.TypeOf((*test.SignalSetAlias)(nil)),
+ // Non-packed types.
+ reflect.TypeOf((*test.Type1)(nil)),
+ reflect.TypeOf((*test.Type4)(nil)),
+ reflect.TypeOf((*test.Type5)(nil)),
+ reflect.TypeOf((*test.Type6)(nil)),
+ reflect.TypeOf((*test.Type7)(nil)),
+ reflect.TypeOf((*test.Type8)(nil)),
+ }
+
+ for _, tyPtr := range types {
+ // Remove one level of pointer-indirection from the type. We get this
+ // back when we pass the type to reflect.New.
+ ty := tyPtr.Elem()
+
+ // Partial copy-in.
+ t.Run(fmt.Sprintf("PartialCopyIn_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ actual := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ limitedCopyIn(t, expected, actual, expected.SizeBytes()/2)
+ })
+
+ // Partial copy-out.
+ t.Run(fmt.Sprintf("PartialCopyOut_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ limitedCopyOut(t, expected, expected.SizeBytes()/2)
+ })
+
+ // Explicitly request partial copy-out.
+ t.Run(fmt.Sprintf("PartialCopyOutN_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ copyOutN(t, expected, expected.SizeBytes()/2)
+ })
+ }
+}
+
+// TestLimitedSliceMarshalling verifies marshalling/unmarshalling of slices of
+// marshallable types succeeds when the underlying copy in/out operations
+// partially succeed.
+func TestLimitedSliceMarshalling(t *testing.T) {
+ types := []struct {
+ arrayPtrType reflect.Type
+ copySliceIn func(task marshal.Task, addr usermem.Addr, dstSlice interface{}) (int, error)
+ copySliceOut func(task marshal.Task, addr usermem.Addr, srcSlice interface{}) (int, error)
+ unsafeMemory func(arrPtr interface{}) []byte
+ }{
+ // Packed types.
+ {
+ reflect.TypeOf((*[20]test.Stat)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[20]test.Stat)[:]
+ return test.CopyStatSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[20]test.Stat)[:]
+ return test.CopyStatSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[20]test.Stat)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[1]test.Stat)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[1]test.Stat)[:]
+ return test.CopyStatSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[1]test.Stat)[:]
+ return test.CopyStatSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[1]test.Stat)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[5]test.SignalSetAlias)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[5]test.SignalSetAlias)[:]
+ return test.CopySignalSetAliasSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[5]test.SignalSetAlias)[:]
+ return test.CopySignalSetAliasSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[5]test.SignalSetAlias)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ // Non-packed types.
+ {
+ reflect.TypeOf((*[20]test.Type1)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[20]test.Type1)[:]
+ return test.CopyType1SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[20]test.Type1)[:]
+ return test.CopyType1SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[20]test.Type1)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[1]test.Type1)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[1]test.Type1)[:]
+ return test.CopyType1SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[1]test.Type1)[:]
+ return test.CopyType1SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[1]test.Type1)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[7]test.Type8)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[7]test.Type8)[:]
+ return test.CopyType8SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[7]test.Type8)[:]
+ return test.CopyType8SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[7]test.Type8)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ }
+
+ for _, tt := range types {
+ // The body of this loop is generic over the type tt.arrayPtrType, with
+ // the help of reflection. To aid in readability, comments below show
+ // the equivalent go code assuming
+ // tt.arrayPtrType = typeof(*[20]test.Stat).
+
+ // Equivalent:
+ // var x *[20]test.Stat
+ // arrayTy := reflect.TypeOf(*x)
+ arrayTy := tt.arrayPtrType.Elem()
+
+ // Partial copy-in of slices.
+ t.Run(fmt.Sprintf("PartialCopySliceIn_%v", arrayTy), func(t *testing.T) {
+ // Equivalent:
+ // var x [20]test.Stat
+ // length := len(x)
+ length := arrayTy.Len()
+ if length < 1 {
+ panic("Test type can't be zero-length array")
+ }
+ // Equivalent:
+ // elem := new(test.Stat).(marshal.Marshallable)
+ elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable)
+
+ // Equivalent:
+ // var expected, actual interface{}
+ // expected = new([20]test.Stat)
+ // actual = new([20]test.Stat)
+ expected := reflect.New(arrayTy).Interface()
+ actual := reflect.New(arrayTy).Interface()
+
+ analysis.RandomizeValue(expected)
+
+ limit := (length * elem.SizeBytes()) / 2
+ // Also make sure the limit is partially inside one of the elements.
+ limit += elem.SizeBytes() / 2
+ analysis.RandomizeValue(expected)
+
+ var task mockTask
+ task.populate(expected)
+ task.setLimit(limit)
+
+ n, err := tt.copySliceIn(&task, usermem.Addr(0), actual)
+ if n != limit {
+ t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if n < length*elem.SizeBytes() && err != simulatedErr {
+ t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := tt.unsafeMemory(expected)
+ defer runtime.KeepAlive(expected)
+ actualMem := tt.unsafeMemory(actual)
+ defer runtime.KeepAlive(actual)
+
+ compareMemory(t, expectedMem, actualMem, n)
+
+ // The bytes after the first n should be zero for actual, since actual was
+ // zero-initialized, and CopyIn shouldn't have touched those bytes. However
+ // we can only guarantee we didn't touch anything past the first n bytes if
+ // the layout is packed.
+ if elem.Packed() && !isZeroes(actualMem[n:]) {
+ t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", (elem.SizeBytes()*length)-n, actualMem)
+ }
+ })
+
+ // Partial copy-out of slices.
+ t.Run(fmt.Sprintf("PartialCopySliceOut_%v", arrayTy), func(t *testing.T) {
+ // Equivalent:
+ // var x [20]test.Stat
+ // length := len(x)
+ length := arrayTy.Len()
+ if length < 1 {
+ panic("Test type can't be zero-length array")
+ }
+ // Equivalent:
+ // elem := new(test.Stat).(marshal.Marshallable)
+ elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable)
+
+ // Equivalent:
+ // var expected interface{}
+ // expected = new([20]test.Stat)
+ // (copy-out needs no separate destination object)
+ expected := reflect.New(arrayTy).Interface()
+
+ analysis.RandomizeValue(expected)
+
+ limit := (length * elem.SizeBytes()) / 2
+ // Also make sure the limit is partially inside one of the elements.
+ limit += elem.SizeBytes() / 2
+ analysis.RandomizeValue(expected)
+
+ var task mockTask
+ task.populate(expected)
+ task.setLimit(limit)
+
+ n, err := tt.copySliceOut(&task, usermem.Addr(0), expected)
+ if n != limit {
+ t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if n < length*elem.SizeBytes() && err != simulatedErr {
+ t.Errorf("CopyOut returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := tt.unsafeMemory(expected)
+ defer runtime.KeepAlive(expected)
+ actualMem := task.taskMem.Bytes
+
+ compareMemory(t, expectedMem, actualMem, n)
+ })
+ }
+}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index c829db6da..f75ca1b7f 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -23,7 +23,7 @@ import (
// Type1 is a test data type.
//
-// +marshal
+// +marshal slice:Type1Slice
type Type1 struct {
a Type2
x, y int64 // Multiple field names.
@@ -75,6 +75,34 @@ type Type5 struct {
m int64
}
+// Type6 is a test data type that ends mid-word.
+//
+// +marshal
+type Type6 struct {
+ a int64
+ b int64
+ // If c isn't marked unaligned, analysis fails (as it should, since
+ // the unsafe API corrupts Type7).
+ c byte `marshal:"unaligned"`
+}
+
+// Type7 is a test data type that contains a child struct that ends
+// mid-word.
+// +marshal
+type Type7 struct {
+ x Type6
+ y int64
+}
+
+// Type8 is a test data type which contains an external non-packed field.
+//
+// +marshal slice:Type8Slice
+type Type8 struct {
+ a int64
+ np ex.NotPacked
+ b int64
+}
+
// Timespec represents struct timespec in <time.h>.
//
// +marshal
@@ -85,7 +113,7 @@ type Timespec struct {
// Stat represents struct stat.
//
-// +marshal
+// +marshal slice:StatSlice
type Stat struct {
Dev uint64
Ino uint64
@@ -111,10 +139,38 @@ type InetAddr [4]byte
// SignalSet is an example marshallable newtype on a primitive.
//
-// +marshal
+// +marshal slice:SignalSetSlice:inner
type SignalSet uint64
// SignalSetAlias is an example newtype on another marshallable type.
//
-// +marshal
+// +marshal slice:SignalSetAliasSlice
type SignalSetAlias SignalSet
+
+const sizeA = 64
+const sizeB = 8
+
+// TestArray is a test data structure on an array with a constant length.
+//
+// +marshal
+type TestArray [sizeA]int32
+
+// TestArray2 is a newtype on an array with a simple arithmetic expression of
+// constants for the array length.
+//
+// +marshal
+type TestArray2 [sizeA * sizeB]int32
+
+// TestArray3 is a newtype on an array with a simple arithmetic expression of
+// mixed constants and literals for the array length.
+//
+// +marshal
+type TestArray3 [sizeA*sizeB + 12]int32
+
+// Type9 is a test data type containing an array with a non-literal length.
+//
+// +marshal
+type Type9 struct {
+ x int64
+ y [sizeA]int32
+}
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index 3437aa476..309ee9c21 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -206,7 +206,7 @@ func main() {
initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name))
}
emitZeroCheck := func(name string) {
- fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name)
+ fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { m.Failf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, name, name)
}
emitLoadValue := func(name, typName string) {
fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName)
diff --git a/tools/image_build.sh b/tools/image_build.sh
deleted file mode 100755
index 9b20a740d..000000000
--- a/tools/image_build.sh
+++ /dev/null
@@ -1,98 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# This script is responsible for building a new GCP image that: 1) has nested
-# virtualization enabled, and 2) has been completely set up with the
-# image_setup.sh script. This script should be idempotent, as we memoize the
-# setup script with a hash and check for that name.
-#
-# The GCP project name should be defined via a gcloud config.
-
-set -xeo pipefail
-
-# Parameters.
-declare -r ZONE=${ZONE:-us-central1-f}
-declare -r USERNAME=${USERNAME:-test}
-declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud}
-declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts}
-
-# Random names.
-declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z)
-declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z)
-declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z)
-
-# Hashes inputs.
-declare -r SETUP_BLOB=$(echo ${ZONE} ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && sha256sum "$@")
-declare -r SETUP_HASH=$(echo ${SETUP_BLOB} | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
-declare -r IMAGE_NAME=${IMAGE_NAME:-image-}${SETUP_HASH}
-
-# Does the image already exist? Skip the build.
-declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
-if ! [[ -z "${existing}" ]]; then
- echo "${existing}"
- exit 0
-fi
-
-# Set the zone for all actions.
-gcloud config set compute/zone "${ZONE}"
-
-# Start a unique instance. Note that this instance will have a unique persistent
-# disk as it's boot disk with the same name as the instance.
-gcloud compute instances create \
- --quiet \
- --image-project "${IMAGE_PROJECT}" \
- --image-family "${IMAGE_FAMILY}" \
- --boot-disk-size "200GB" \
- "${INSTANCE_NAME}"
-function cleanup {
- gcloud compute instances delete --quiet "${INSTANCE_NAME}"
-}
-trap cleanup EXIT
-
-# Wait for the instance to become available.
-declare attempts=0
-while [[ "${attempts}" -lt 30 ]]; do
- attempts=$((${attempts}+1))
- if gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- true; then
- break
- fi
-done
-if [[ "${attempts}" -ge 30 ]]; then
- echo "too many attempts: failed"
- exit 1
-fi
-
-# Run the install scripts provided.
-for arg; do
- gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}"
-done
-
-# Stop the instance; required before creating an image.
-gcloud compute instances stop --quiet "${INSTANCE_NAME}"
-
-# Create a snapshot of the instance disk.
-gcloud compute disks snapshot \
- --quiet \
- --zone="${ZONE}" \
- --snapshot-names="${SNAPSHOT_NAME}" \
- "${INSTANCE_NAME}"
-
-# Create the disk image.
-gcloud compute images create \
- --quiet \
- --source-snapshot="${SNAPSHOT_NAME}" \
- --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \
- "${IMAGE_NAME}"
diff --git a/tools/images/BUILD b/tools/images/BUILD
index 66ffd02aa..8d319e3e4 100644
--- a/tools/images/BUILD
+++ b/tools/images/BUILD
@@ -6,14 +6,9 @@ package(
licenses = ["notice"],
)
-genrule(
+sh_binary(
name = "zone",
- outs = ["zone.txt"],
- cmd = "gcloud config get-value compute/zone > \"$@\"",
- tags = [
- "local",
- "manual",
- ],
+ srcs = ["zone.sh"],
)
sh_binary(
diff --git a/tools/images/README.md b/tools/images/README.md
new file mode 100644
index 000000000..26c0f84f2
--- /dev/null
+++ b/tools/images/README.md
@@ -0,0 +1,42 @@
+# Images
+
+All commands in this directory require the `gcloud` project to be set.
+
+For example: `gcloud config set project gvisor-kokoro-testing`.
+
+Images can be generated by using the `vm_image` rule. This rule will generate a
+binary target that builds an image in an idempotent way, and can be referenced
+from other rules.
+
+For example:
+
+```
+vm_image(
+ name = "ubuntu",
+ project = "ubuntu-os-cloud",
+ family = "ubuntu-1604-lts",
+ scripts = [
+ "script.sh",
+ "other.sh",
+ ],
+)
+```
+
+These images can be built manually by executing the target. The output on
+`stdout` will be the image id (in the current project).
+
+Images are always named per the hash of all the hermetic input scripts. This
+allows images to be memoized quickly and easily.
+
+The `vm_test` rule can be used to execute a command remotely. This is still
+under development however, and will likely change over time.
+
+For example:
+
+```
+vm_test(
+ name = "mycommand",
+ image = ":ubuntu",
+ targets = [":test"],
+)
+```
diff --git a/tools/images/build.sh b/tools/images/build.sh
index f89f39cbd..f39f723b8 100755
--- a/tools/images/build.sh
+++ b/tools/images/build.sh
@@ -19,7 +19,7 @@
# image_setup.sh script. This script should be idempotent, as we memoize the
# setup script with a hash and check for that name.
-set -xeou pipefail
+set -eou pipefail
# Parameters.
declare -r USERNAME=${USERNAME:-test}
@@ -34,10 +34,10 @@ declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z)
# Hash inputs in order to memoize the produced image.
declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
-declare -r IMAGE_NAME=${IMAGE_FAMILY:-image-}${SETUP_HASH}
+declare -r IMAGE_NAME=${IMAGE_FAMILY:-image}-${SETUP_HASH}
# Does the image already exist? Skip the build.
-declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
+declare -r existing=$(set -x; gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
if ! [[ -z "${existing}" ]]; then
echo "${existing}"
exit 0
@@ -48,28 +48,30 @@ export PATH=${PATH:-/bin:/usr/bin:/usr/local/bin}
# Start a unique instance. Note that this instance will have a unique persistent
# disk as it's boot disk with the same name as the instance.
-gcloud compute instances create \
+(set -x; gcloud compute instances create \
--quiet \
--image-project "${IMAGE_PROJECT}" \
--image-family "${IMAGE_FAMILY}" \
--boot-disk-size "200GB" \
--zone "${ZONE}" \
- "${INSTANCE_NAME}" >/dev/null
+ "${INSTANCE_NAME}" >/dev/null)
function cleanup {
- gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}"
+ (set -x; gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}")
}
trap cleanup EXIT
# Wait for the instance to become available (up to 5 minutes).
+echo -n "Waiting for ${INSTANCE_NAME}"
declare timeout=300
declare success=0
declare internal=""
declare -r start=$(date +%s)
declare -r end=$((${start}+${timeout}))
while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do
- if gcloud compute ssh --zone "${internal}" "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then
+ echo -n "."
+ if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then
success=$((${success}+1))
- elif gcloud compute ssh --zone --internal-ip "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then
+ elif gcloud compute ssh --internal-ip --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then
success=$((${success}+1))
internal="--internal-ip"
fi
@@ -78,29 +80,34 @@ done
if [[ "${success}" -eq "0" ]]; then
echo "connect timed out after ${timeout} seconds."
exit 1
+else
+ echo "done."
fi
# Run the install scripts provided.
for arg; do
- gcloud compute ssh --zone "${internal}" "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}" >/dev/null
+ (set -x; gcloud compute ssh ${internal} \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ sudo bash - <"${arg}" >/dev/null)
done
# Stop the instance; required before creating an image.
-gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null
+(set -x; gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null)
# Create a snapshot of the instance disk.
-gcloud compute disks snapshot \
+(set -x; gcloud compute disks snapshot \
--quiet \
--zone "${ZONE}" \
--snapshot-names="${SNAPSHOT_NAME}" \
- "${INSTANCE_NAME}" >/dev/null
+ "${INSTANCE_NAME}" >/dev/null)
# Create the disk image.
-gcloud compute images create \
+(set -x; gcloud compute images create \
--quiet \
--source-snapshot="${SNAPSHOT_NAME}" \
--licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \
- "${IMAGE_NAME}" >/dev/null
+ "${IMAGE_NAME}" >/dev/null)
# Finish up.
echo "${IMAGE_NAME}"
diff --git a/tools/images/defs.bzl b/tools/images/defs.bzl
index de365d153..2847e1847 100644
--- a/tools/images/defs.bzl
+++ b/tools/images/defs.bzl
@@ -1,76 +1,49 @@
-"""Image configuration.
-
-Images can be generated by using the vm_image rule. For example,
-
- vm_image(
- name = "ubuntu",
- project = "...",
- family = "...",
- scripts = [
- "script.sh",
- "other.sh",
- ],
- )
-
-This will always create an vm_image in the current default gcloud project. The
-rule has a text file as its output containing the image name. This will enforce
-serialization for all dependent rules.
-
-Images are always named per the hash of all the hermetic input scripts. This
-allows images to be memoized quickly and easily.
-
-The vm_test rule can be used to execute a command remotely. For example,
-
- vm_test(
- name = "mycommand",
- image = ":myimage",
- targets = [":test"],
- )
-"""
+"""Image configuration. See README.md."""
load("//tools:defs.bzl", "default_installer")
-def _vm_image_impl(ctx):
+# vm_image_builder is a rule that will construct a shell script that actually
+# generates a given VM image. Note that this does not _run_ the shell script
+# (although it can be run manually). It is run automatically during generation
+# of the vm_image target itself. This level of indirection is used so that the
+# build system itself only runs the builder once when multiple targets depend
+# on it, avoiding a set of races and conflicts.
+def _vm_image_builder_impl(ctx):
+ # Generate a binary that actually builds the image.
+ builder = ctx.actions.declare_file(ctx.label.name)
script_paths = []
for script in ctx.files.scripts:
script_paths.append(script.short_path)
+ builder_content = "\n".join([
+ "#!/bin/bash",
+ "export ZONE=$(%s)" % ctx.files.zone[0].short_path,
+ "export USERNAME=%s" % ctx.attr.username,
+ "export IMAGE_PROJECT=%s" % ctx.attr.project,
+ "export IMAGE_FAMILY=%s" % ctx.attr.family,
+ "%s %s" % (ctx.files._builder[0].short_path, " ".join(script_paths)),
+ "",
+ ])
+ ctx.actions.write(builder, builder_content, is_executable = True)
- resolved_inputs, argv, runfiles_manifests = ctx.resolve_command(
- command = "USERNAME=%s ZONE=$(cat %s) IMAGE_PROJECT=%s IMAGE_FAMILY=%s %s %s > %s" %
- (
- ctx.attr.username,
- ctx.files.zone[0].path,
- ctx.attr.project,
- ctx.attr.family,
- ctx.executable.builder.path,
- " ".join(script_paths),
- ctx.outputs.out.path,
- ),
- tools = [ctx.attr.builder] + ctx.attr.scripts,
- )
-
- ctx.actions.run_shell(
- tools = resolved_inputs,
- outputs = [ctx.outputs.out],
- progress_message = "Building image...",
- execution_requirements = {"local": "true"},
- command = argv,
- input_manifests = runfiles_manifests,
- )
+ # Note that the scripts should only be files, and should not include any
+ # indirect transitive dependencies; otherwise the build script won't work.
return [DefaultInfo(
- files = depset([ctx.outputs.out]),
- runfiles = ctx.runfiles(files = [ctx.outputs.out]),
+ executable = builder,
+ runfiles = ctx.runfiles(
+ files = ctx.files.scripts + ctx.files._builder + ctx.files.zone,
+ ),
)]
-_vm_image = rule(
+vm_image_builder = rule(
attrs = {
- "builder": attr.label(
+ "_builder": attr.label(
executable = True,
default = "//tools/images:builder",
cfg = "host",
),
"username": attr.string(default = "$(whoami)"),
"zone": attr.label(
+ executable = True,
default = "//tools/images:zone",
cfg = "host",
),
@@ -78,20 +51,55 @@ _vm_image = rule(
"project": attr.string(mandatory = True),
"scripts": attr.label_list(allow_files = True),
},
- outputs = {
- "out": "%{name}.txt",
+ executable = True,
+ implementation = _vm_image_builder_impl,
+)
+
+# See vm_image_builder above.
+def _vm_image_impl(ctx):
+ # Run the builder to generate our output.
+ echo = ctx.actions.declare_file(ctx.label.name)
+ resolved_inputs, argv, runfiles_manifests = ctx.resolve_command(
+ command = "echo -ne \"#!/bin/bash\\necho $(%s)\\n\" > %s && chmod 0755 %s" % (
+ ctx.files.builder[0].path,
+ echo.path,
+ echo.path,
+ ),
+ tools = [ctx.attr.builder],
+ )
+ ctx.actions.run_shell(
+ tools = resolved_inputs,
+ outputs = [echo],
+ progress_message = "Building image...",
+ execution_requirements = {"local": "true"},
+ command = argv,
+ input_manifests = runfiles_manifests,
+ )
+
+ # Return just the echo command. All of the builder runfiles have been
+ # resolved and consumed in the generation of the trivial echo script.
+ return [DefaultInfo(executable = echo)]
+
+_vm_image = rule(
+ attrs = {
+ "builder": attr.label(
+ executable = True,
+ cfg = "host",
+ ),
},
+ executable = True,
implementation = _vm_image_impl,
)
-def vm_image(**kwargs):
- _vm_image(
- tags = [
- "local",
- "manual",
- ],
+def vm_image(name, **kwargs):
+ vm_image_builder(
+ name = name + "_builder",
**kwargs
)
+ _vm_image(
+ name = name,
+ builder = ":" + name + "_builder",
+ )
def _vm_test_impl(ctx):
runner = ctx.actions.declare_file("%s-executer" % ctx.label.name)
diff --git a/tools/images/zone.sh b/tools/images/zone.sh
new file mode 100755
index 000000000..79569fb19
--- /dev/null
+++ b/tools/images/zone.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+exec gcloud config get-value compute/zone
diff --git a/tools/nogo.json b/tools/nogo.json
index 83cb76b93..ae969409e 100644
--- a/tools/nogo.json
+++ b/tools/nogo.json
@@ -9,27 +9,6 @@
"/external/": "allowed: not subject to unsafe naming rules"
}
},
- "copylocks": {
- "exclude_files": {
- ".*_state_autogen.go": "fix: m.Failf copies by value",
- "/pkg/log/json.go": "fix: Emit passes lock by value: gvisor.dev/gvisor/pkg/log.JSONEmitter contains gvisor.dev/gvisor/pkg/log.Writer contains gvisor.dev/gvisor/pkg/sync.Mutex",
- "/pkg/log/log_test.go": "fix: call of fmt.Printf copies lock value: gvisor.dev/gvisor/pkg/log.Writer contains gvisor.dev/gvisor/pkg/sync.Mutex",
- "/pkg/sentry/fs/host/socket_test.go": "fix: call of t.Errorf copies lock value: gvisor.dev/gvisor/pkg/sentry/fs/host.ConnectedEndpoint contains gvisor.dev/gvisor/pkg/refs.AtomicRefCount contains gvisor.dev/gvisor/pkg/sync.Mutex",
- "/pkg/sentry/fs/proc/sys_net.go": "fix: Truncate passes lock by value: gvisor.dev/gvisor/pkg/sentry/fs/proc.tcpMemInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.SimpleFileInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.InodeSimpleAttributes contains gvisor.dev/gvisor/pkg/sync.RWMutex",
- "/pkg/sentry/fs/proc/sys_net.go": "fix: Truncate passes lock by value: gvisor.dev/gvisor/pkg/sentry/fs/proc.tcpSack contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.SimpleFileInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.InodeSimpleAttributes contains gvisor.dev/gvisor/pkg/sync.RWMutex",
- "/pkg/sentry/fs/tty/slave.go": "fix: Truncate passes lock by value: gvisor.dev/gvisor/pkg/sentry/fs/tty.slaveInodeOperations contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.SimpleFileInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.InodeSimpleAttributes contains gvisor.dev/gvisor/pkg/sync.RWMutex",
- "/pkg/sentry/kernel/time/time.go": "fix: Readiness passes lock by value: gvisor.dev/gvisor/pkg/sentry/kernel/time.ClockEventsQueue contains gvisor.dev/gvisor/pkg/waiter.Queue contains gvisor.dev/gvisor/pkg/sync.RWMutex",
- "/pkg/sentry/kernel/syscalls_state.go": "fix: assignment copies lock value to *s: gvisor.dev/gvisor/pkg/sentry/kernel.SyscallTable contains gvisor.dev/gvisor/pkg/sentry/kernel.SyscallFlagsTable contains gvisor.dev/gvisor/pkg/sync.Mutex"
- }
- },
- "lostcancel": {
- "exclude_files": {
- "/pkg/tcpip/network/arp/arp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak",
- "/pkg/tcpip/stack/ndp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak",
- "/pkg/tcpip/transport/udp/udp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak",
- "/pkg/tcpip/transport/tcp/testing/context/context.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak"
- }
- },
"nilness": {
"exclude_files": {
"/com_github_vishvananda_netlink/route_linux.go": "allowed: false positive",
@@ -40,37 +19,6 @@
"/external/io_opencensus_go/tag/map_codec.go": "allowed: false positive"
}
},
- "printf": {
- "exclude_files": {
- ".*_abi_autogen_test.go": "fix: Sprintf format has insufficient args",
- "/pkg/segment/test/segment_test.go": "fix: Errorf format %d arg seg.Start is a func value, not called",
- "/pkg/tcpip/tcpip_test.go": "fix: Error call has possible formatting directive %q",
- "/pkg/tcpip/header/eth_test.go": "fix: Fatalf format %s reads arg #3, but call has 2 args",
- "/pkg/tcpip/header/ndp_test.go": "fix: Errorf format %d reads arg #1, but call has 0 args",
- "/pkg/eventchannel/event_test.go": "fix: Fatal call has possible formatting directive %v",
- "/pkg/tcpip/stack/ndp.go": "fix: Fatalf format %s has arg protocolAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress",
- "/pkg/sentry/fs/fdpipe/pipe_test.go": "fix: Errorf format %s has arg flags of wrong type gvisor.dev/gvisor/pkg/sentry/fs.FileFlags",
- "/pkg/sentry/fs/fdpipe/pipe_test.go": "fix: Errorf format %d arg f.FD is a func value, not called",
- "/pkg/tcpip/link/fdbased/endpoint.go": "fix: Sprintf format %v with arg p causes recursive String method call",
- "/pkg/tcpip/transport/udp/udp_test.go": "fix: Fatalf format %s has arg h.srcAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.FullAddress",
- "/pkg/tcpip/transport/tcp/tcp_test.go": "fix: Fatalf format %s has arg tcpTW of wrong type gvisor.dev/gvisor/pkg/tcpip.TCPTimeWaitTimeoutOption",
- "/pkg/tcpip/transport/tcp/tcp_test.go": "fix: Errorf call needs 1 arg but has 2 args",
- "/pkg/tcpip/stack/ndp_test.go": "fix: Errorf format %s reads arg #3, but call has 2 args",
- "/pkg/tcpip/stack/ndp_test.go": "fix: Fatalf format %s reads arg #5, but call has 4 args",
- "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg protoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress",
- "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg nic1ProtoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress",
- "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg nic2ProtoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress",
- "/pkg/tcpip/stack/stack_test.go": "fix: Fatal call has possible formatting directive %t",
- "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf call has arguments but no formatting directives",
- "/pkg/tcpip/link/fdbased/endpoint.go": "fix: Sprintf format %v with arg p causes recursive String method call",
- "/pkg/sentry/fsimpl/tmpfs/stat_test.go": "fix: Errorf format %v reads arg #1, but call has 0 args",
- "/runsc/container/test_app/test_app.go": "fix: Fatal call has possible formatting directive %q",
- "/test/root/cgroup_test.go": "fix: Errorf format %s has arg gots of wrong type []int",
- "/test/root/cgroup_test.go": "fix: Fatalf format %v reads arg #3, but call has 2 args",
- "/test/runtimes/runner.go": "fix: Skip call has possible formatting directive %q",
- "/test/runtimes/blacklist_test.go": "fix: Errorf format %q has arg blacklistFile of wrong type *string"
- }
- },
"structtag": {
"exclude_files": {
"/external/": "allowed: may use arbitrary tags"
@@ -83,16 +31,9 @@
"/pkg/gohacks/gohacks_unsafe.go": "allowed: special case",
"/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go": "allowed: special case",
"/pkg/sentry/platform/kvm/(bluepill|machine)_unsafe.go": "allowed: special case",
- "/pkg/sentry/platform/kvm/machine_arm64_unsafe.go": "fix: gvisor.dev/issue/22464",
"/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go": "allowed: special case",
"/pkg/sentry/platform/safecopy/safecopy_unsafe.go": "allowed: special case",
"/pkg/sentry/vfs/mount_unsafe.go": "allowed: special case"
}
- },
- "unusedresult": {
- "exclude_files": {
- "/pkg/sentry/fsimpl/proc/task_net.go": "fix: result of fmt.Sprintf call not used",
- "/pkg/sentry/fsimpl/proc/tasks_net.go": "fix: result of fmt.Sprintf call not used"
- }
}
}